
Commit bc7a268

SunMarc, YangKai0616, and MekkCyber authored
Fix fp8 + some enhancement (#42455)
* Fix fp8 + some enhancement
* style
* Add coauthor

Co-authored-by: Yang Kai <[email protected]>

* fix
* style
* fix tests
* style
* assertin
* style
* fix
* fix
* Apply suggestions from code review

Co-authored-by: Mohamed Mekkouri <[email protected]>

---------

Co-authored-by: Yang Kai <[email protected]>
Co-authored-by: Mohamed Mekkouri <[email protected]>
1 parent 29e8522 commit bc7a268

8 files changed, +138 -271 lines changed


src/transformers/core_model_loading.py

Lines changed: 1 addition & 0 deletions
@@ -826,6 +826,7 @@ def convert_and_load_state_dict_in_model(
         if hf_quantizer and hf_quantizer.pre_quantized and original_key != renamed_key:
             # if the key was renamed as it is not available in the state dict otherwise, it means that we are deserializing it,
             # so we need to make sure to load the tensor with the same dtype from the checkpoint
+            # TODO: make the condition more strict for native fp8 models such as qwen2moe fp8
             _dtype = None
         elif dtype_plan != {} and dtype_policy_alt.search(renamed_key):
             matched_dtype_pattern = dtype_policy_alt.search(renamed_key)
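Note: as a quick illustration of why the added branch leaves _dtype as None, the sketch below (not the transformers loading path; the tensor is hypothetical) shows that a cast with dtype=None keeps the fp8 storage of a pre-quantized checkpoint weight:

import torch

# Hypothetical fp8 weight as it would arrive from a pre-quantized checkpoint (needs torch >= 2.1)
checkpoint_weight = torch.zeros(4, 4, dtype=torch.float8_e4m3fn)

_dtype = None  # what the branch above selects for renamed, pre-quantized keys
loaded = checkpoint_weight.to(dtype=_dtype)  # .to(dtype=None) is a no-op cast, so the fp8 dtype is preserved
assert loaded.dtype == torch.float8_e4m3fn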

src/transformers/integrations/finegrained_fp8.py

Lines changed: 61 additions & 134 deletions
@@ -13,10 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
-from typing import Optional
-
 from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import should_convert_module
 from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
 
 
@@ -307,44 +305,38 @@ def w8a8_block_fp8_matmul_compile(
 
 
 class FP8Linear(nn.Linear):
-    dtype = torch.float8_e4m3fn
-
     def __init__(
         self,
         in_features: int,
         out_features: int,
         bias: bool = False,
-        dtype=None,
+        dtype=torch.float8_e4m3fn,
         block_size: tuple[int, int] | None = None,
-        device=None,
         activation_scheme="dynamic",
     ):
         super().__init__(in_features, out_features)
-        self.in_features = in_features
-        self.out_features = out_features
 
+        # If block_size is not passed, it means that we are doing per-tensor quantization
         if block_size is not None:
             self.block_size = block_size
         else:
             self.block_size = (out_features, in_features)
 
-        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
+        self.activation_scheme = activation_scheme
 
-        if self.weight.element_size() == 1:
-            scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0]
-            scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1]
-            if scale_out_features * scale_in_features == 1:
-                self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
-            else:
-                self.weight_scale_inv = nn.Parameter(
-                    torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
-                )
+        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
+        scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
+        scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
+
+        if scale_out_features * scale_in_features == 1:
+            self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
         else:
-            self.register_parameter("weight_scale_inv", None)
-        self.activation_scheme = activation_scheme
+            self.weight_scale_inv = nn.Parameter(
+                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
+            )
 
         if self.activation_scheme == "static":
-            self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+            self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
 
         if bias:
             self.bias = nn.Parameter(torch.empty(self.out_features))
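Note: a standalone sketch of the scale-tile arithmetic the new constructor uses; the layer dimensions and the (128, 128) block size below are illustrative assumptions, not values from the PR:

def ceil_div(a: int, b: int) -> int:
    # same ceiling division as (x + block - 1) // block in the constructor above
    return (a + b - 1) // b

out_features, in_features = 4096, 11008
block_size = (128, 128)

scale_out_features = ceil_div(out_features, block_size[0])  # 32
scale_in_features = ceil_div(in_features, block_size[1])    # 86
print(scale_out_features, scale_in_features)  # weight_scale_inv becomes a (32, 86) float32 parameter

# Per-tensor case: the default block size is (out_features, in_features), so the
# tile grid collapses to 1 x 1 and weight_scale_inv is stored as a scalar.
assert ceil_div(out_features, out_features) * ceil_div(in_features, in_features) == 1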
@@ -400,9 +392,7 @@ def _ceil_div(a, b):
 
 
 class FP8Expert(nn.Module):
-    dtype = torch.float8_e4m3fn
-
-    def __init__(self, config, block_size, device):
+    def __init__(self, config, block_size, dtype=torch.float8_e4m3fn):
         super().__init__()
 
         from ..activations import ACT2FN
@@ -415,34 +405,24 @@ def __init__(self, config, block_size, device):
         Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
         Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim
 
-        self.gate_up_proj = nn.Parameter(
-            torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
-        )
-        self.down_proj = nn.Parameter(
-            torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
-        )
+        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=dtype))
+        self.down_proj = nn.Parameter(torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=dtype))
 
-        # Create inverse scale tiles only when using 1-byte types (fp8)
-        if self.gate_up_proj.element_size() == 1:
-            bo, bi = self.block_size
+        bo, bi = self.block_size
 
-            # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
-            gu_scale_o = _ceil_div(Wg_out, bo)
-            gu_scale_i = _ceil_div(Wg_in, bi)
-            self.gate_up_proj_scale_inv = nn.Parameter(
-                torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
-            )
+        # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
+        gu_scale_o = _ceil_div(Wg_out, bo)
+        gu_scale_i = _ceil_div(Wg_in, bi)
+        self.gate_up_proj_scale_inv = nn.Parameter(
+            torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32)
+        )
 
-            # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
-            dp_scale_o = _ceil_div(Wd_out, bo)
-            dp_scale_i = _ceil_div(Wd_in, bi)
-            self.down_proj_scale_inv = nn.Parameter(
-                torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
-            )
-        else:
-            # Match FP8Linear behavior when not using 1-byte weights
-            self.register_parameter("gate_up_proj_scale_inv", None)
-            self.register_parameter("down_proj_scale_inv", None)
+        # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
+        dp_scale_o = _ceil_div(Wd_out, bo)
+        dp_scale_i = _ceil_div(Wd_in, bi)
+        self.down_proj_scale_inv = nn.Parameter(
+            torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32)
+        )
 
         # (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default
         self.register_parameter("gate_up_bias", None)
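Note: the expert variant applies the same tiling with a leading num_experts dimension; a hedged sketch with made-up MoE dimensions (none of these values come from the PR):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

num_experts, hidden_dim, intermediate_dim = 64, 2048, 1408  # illustrative only
bo, bi = 128, 128

Wg_out, Wg_in = 2 * intermediate_dim, hidden_dim  # fused gate_up projection
Wd_out, Wd_in = hidden_dim, intermediate_dim      # down projection

gate_up_scale_shape = (num_experts, ceil_div(Wg_out, bo), ceil_div(Wg_in, bi))  # (64, 22, 16)
down_scale_shape = (num_experts, ceil_div(Wd_out, bo), ceil_div(Wd_in, bi))     # (64, 16, 11)
print(gate_up_scale_shape, down_scale_shape)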
@@ -508,90 +488,46 @@ def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: to
         return output.to(dtype=input.dtype)
 
 
-# TODO: we do need this.... but not recursive...
-def _replace_with_fp8_linear(
-    model,
-    tp_plan=None,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    has_been_replaced=False,
-):
-    iterator = list(model.named_parameters()).copy()
-    for name, empty_tensor in iterator:
-        current_key_name = name
-        name = name.rsplit(".", 1)[0] if "." in name else name
-        module = model.get_submodule(name)
-
-        current_key_name_str = re.sub(r"\d+", "*", current_key_name)
-        if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
-            with init_empty_weights():
-                if (
-                    "gate_up_proj" in current_key_name
-                    or "down_proj" in current_key_name
-                    and "experts" in current_key_name
-                ):  # Experts!
-                    in_features = empty_tensor.size(-2)
-                    out_features = empty_tensor.size(-1)
-                    model.set_submodule(
-                        name,
-                        FP8Expert(
-                            config=model.config,
-                            block_size=quantization_config.weight_block_size,
-                            device=empty_tensor.device,
-                        ),
-                    )
-
-                elif isinstance(module, nn.Linear):
-                    in_features = module.in_features
-                    out_features = module.out_features
-                    model.set_submodule(
-                        name,
-                        FP8Linear(
-                            in_features=in_features,
-                            out_features=out_features,
-                            bias=module.bias is not None,
-                            device=module.weight.device,
-                            dtype=module.weight.dtype,
-                            activation_scheme=quantization_config.activation_scheme,
-                            block_size=quantization_config.weight_block_size,
-                        ),
-                    )
-                    has_been_replaced = True
-    # when changing a layer the TP PLAN for that layer should be updated. TODO
-
-    return model, has_been_replaced
-
-
 def replace_with_fp8_linear(
     model,
     modules_to_not_convert=None,
     quantization_config=None,
+    pre_quantized=False,
 ):
     """Helper function to replace model layers with FP8 versions."""
     if quantization_config.dequantize:
         return model
 
-    if modules_to_not_convert is None:
-        modules_to_not_convert = []
-    modules_to_not_convert += ["lm_head"]
-
-    if quantization_config.modules_to_not_convert is not None:
-        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
-    modules_to_not_convert = list(set(modules_to_not_convert))
-    model, has_been_replaced = _replace_with_fp8_linear(
-        model,
-        tp_plan=model._tp_plan,
-        modules_to_not_convert=modules_to_not_convert,
-        quantization_config=quantization_config,
-    )
+    has_been_replaced = False
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+        # we need this to correctly materialize the weights during quantization
+        module_kwargs = {} if pre_quantized else {"dtype": None}
+        new_module = None
+        with init_empty_weights():
+            if "gate_up_proj" in module_name or "down_proj" in module_name and "experts" in module_name:
+                new_module = FP8Expert(
+                    config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs
+                )
+            elif isinstance(module, nn.Linear):
+                new_module = FP8Linear(
+                    in_features=module.in_features,
+                    out_features=module.out_features,
+                    bias=module.bias is not None,
+                    activation_scheme=quantization_config.activation_scheme,
+                    block_size=quantization_config.weight_block_size,
+                    **module_kwargs,
+                )
+        if new_module is not None:
+            model.set_submodule(module_name, new_module)
+            has_been_replaced = True
 
     if not has_been_replaced:
         logger.warning(
             "You are loading your model using fp8 but no linear modules were found in your model."
             " Please double check your model architecture."
         )
-
     return model
 
 
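Note: a hedged usage sketch of the rewritten helper, calling it directly on a toy module. TinyMLP is invented for illustration, and the FineGrainedFP8Config fields relied on here (weight_block_size, activation_scheme, dequantize, modules_to_not_convert) are assumed to match the quantization_config object the quantizer normally passes in; in practice the fp8 quantizer drives this call for you.

from torch import nn
from transformers import FineGrainedFP8Config
from transformers.integrations.finegrained_fp8 import replace_with_fp8_linear

class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_in = nn.Linear(256, 512, bias=False)
        self.fc_out = nn.Linear(512, 256, bias=False)
        self.lm_head = nn.Linear(256, 1000, bias=False)

quant_config = FineGrainedFP8Config()  # assumed defaults: (128, 128) blocks, dynamic activation scaling

model = replace_with_fp8_linear(
    TinyMLP(),
    modules_to_not_convert=["lm_head"],  # the quantizer normally excludes the output head
    quantization_config=quant_config,
    pre_quantized=True,                  # keep the fp8 dtype when deserializing an fp8 checkpoint
)
print(type(model.fc_in).__name__, model.fc_in.weight.dtype)  # FP8Linear torch.float8_e4m3fn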

@@ -606,7 +542,7 @@ def __init__(self, hf_quantizer):
     def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]:
         # Unpack single key/value (value may be wrapped in a list)
         target_keys, value = tuple(input_dict.items())[0]
-        value = value[0] if isinstance(value, list) else value
+        value = value[0]
 
         # Resolve block size (support dict-like or attr-like quant_config)
         block_size = None
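Note: this quantize op packs a weight together with its inverse scales; the block-wise idea behind it looks roughly like the sketch below (a hedged stand-in, not the kernel the integration actually calls, and it assumes tile-divisible shapes for brevity):

import torch

def quantize_blockwise(weight: torch.Tensor, block_size=(128, 128)):
    # Scale each (b0, b1) tile so its abs-max maps onto the fp8 e4m3 range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    b0, b1 = block_size
    rows, cols = weight.shape
    tiles = weight.reshape(rows // b0, b0, cols // b1, b1)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = fp8_max / amax                                   # multiply by this to fill the fp8 range
    quantized = (tiles * scale).to(torch.float8_e4m3fn).reshape(rows, cols)
    weight_scale_inv = (1.0 / scale).squeeze(1).squeeze(-1)  # per-tile inverse scales, (rows//b0, cols//b1)
    return quantized, weight_scale_inv

q, s = quantize_blockwise(torch.randn(256, 256))
print(q.dtype, s.shape)  # torch.float8_e4m3fn torch.Size([2, 2])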
@@ -681,24 +617,15 @@ def __init__(self, hf_quantizer):
     def convert(
         self,
         input_dict: dict[str, torch.Tensor],
-        model: Optional[torch.nn.Module] = None,
         full_layer_name: str | None = None,
-        missing_keys=None,
         **kwargs,
     ) -> dict[str, torch.Tensor]:
         if len(input_dict) < 2:
-            # in case of no scales, the weights are not quantized, so we return the weights as is
-            return {
-                full_layer_name: input_dict["weight$"][0]
-                if isinstance(input_dict["weight$"], list)
-                else input_dict["weight$"]
-            }
-        quantized = input_dict["weight$"][0] if isinstance(input_dict["weight$"], list) else input_dict["weight$"]
-        scales = (
-            input_dict["weight_scale_inv"][0]
-            if isinstance(input_dict["weight_scale_inv"], list)
-            else input_dict["weight_scale_inv"]
-        )
+            # case where we only got weights, need to check for "weight$"
+            return {full_layer_name: input_dict["weight$"]}
+
+        quantized = input_dict["weight$"][0]
+        scales = input_dict["weight_scale_inv"][0]
 
         rows, cols = quantized.shape[-2:]
         block_size = self.hf_quantizer.quantization_config.weight_block_size
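Note: for the dequantize path above, the inverse operation expands the per-tile scales back over the rows x cols weight and multiplies; a hedged sketch of that idea (not the exact transformers implementation):

import torch

def dequantize_blockwise(quantized: torch.Tensor, scales: torch.Tensor, block_size=(128, 128)) -> torch.Tensor:
    # Broadcast the (ceil(rows/b0), ceil(cols/b1)) inverse-scale grid back to (rows, cols).
    rows, cols = quantized.shape[-2:]
    b0, b1 = block_size
    expanded = scales.repeat_interleave(b0, dim=-2)[..., :rows, :]
    expanded = expanded.repeat_interleave(b1, dim=-1)[..., :cols]
    return quantized.to(torch.float32) * expanded

q = torch.randn(256, 256).to(torch.float8_e4m3fn)
s = torch.rand(2, 2)
print(dequantize_blockwise(q, s).shape)  # torch.Size([256, 256])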
