# See the License for the specific language governing permissions and
# limitations under the License.

-import re
-from typing import Optional
-
from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging


@@ -307,44 +305,38 @@ def w8a8_block_fp8_matmul_compile(


class FP8Linear(nn.Linear):
-    dtype = torch.float8_e4m3fn
-
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
-        dtype=None,
+        dtype=torch.float8_e4m3fn,
        block_size: tuple[int, int] | None = None,
-        device=None,
        activation_scheme="dynamic",
    ):
        super().__init__(in_features, out_features)
-        self.in_features = in_features
-        self.out_features = out_features

+        # If block_size is not passed, it means we are doing per-tensor quantization
        if block_size is not None:
            self.block_size = block_size
        else:
            self.block_size = (out_features, in_features)

-        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
+        self.activation_scheme = activation_scheme

-        if self.weight.element_size() == 1:
-            scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0]
-            scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1]
-            if scale_out_features * scale_in_features == 1:
-                self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
-            else:
-                self.weight_scale_inv = nn.Parameter(
-                    torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
-                )
+        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
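+        # one float32 inverse scale per (block_size[0], block_size[1]) weight tile, rounding up at the edges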
+        scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0]
+        scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1]
+
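+        # a 1x1 scale grid collapses to a single scalar, i.e. per-tensor quantization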
+        if scale_out_features * scale_in_features == 1:
+            self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        else:
-            self.register_parameter("weight_scale_inv", None)
-        self.activation_scheme = activation_scheme
+            self.weight_scale_inv = nn.Parameter(
+                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
+            )

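+        # "static" keeps a fixed activation_scale parameter; "dynamic" computes activation scales at runtime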
        if self.activation_scheme == "static":
-            self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+            self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features))
@@ -400,9 +392,7 @@ def _ceil_div(a, b):


class FP8Expert(nn.Module):
-    dtype = torch.float8_e4m3fn
-
-    def __init__(self, config, block_size, device):
+    def __init__(self, config, block_size, dtype=torch.float8_e4m3fn):
        super().__init__()

        from ..activations import ACT2FN
@@ -415,34 +405,24 @@ def __init__(self, config, block_size, device):
        Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
        Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim

-        self.gate_up_proj = nn.Parameter(
-            torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
-        )
-        self.down_proj = nn.Parameter(
-            torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
-        )
+        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=dtype))
+        self.down_proj = nn.Parameter(torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=dtype))

-        # Create inverse scale tiles only when using 1-byte types (fp8)
-        if self.gate_up_proj.element_size() == 1:
-            bo, bi = self.block_size
+        bo, bi = self.block_size
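+        # per-expert inverse scales: one float32 value per (bo, bi) tile of each expert matrix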

-            # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
-            gu_scale_o = _ceil_div(Wg_out, bo)
-            gu_scale_i = _ceil_div(Wg_in, bi)
-            self.gate_up_proj_scale_inv = nn.Parameter(
-                torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
-            )
+        # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
+        gu_scale_o = _ceil_div(Wg_out, bo)
+        gu_scale_i = _ceil_div(Wg_in, bi)
+        self.gate_up_proj_scale_inv = nn.Parameter(
+            torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32)
+        )

-            # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
-            dp_scale_o = _ceil_div(Wd_out, bo)
-            dp_scale_i = _ceil_div(Wd_in, bi)
-            self.down_proj_scale_inv = nn.Parameter(
-                torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
-            )
-        else:
-            # Match FP8Linear behavior when not using 1-byte weights
-            self.register_parameter("gate_up_proj_scale_inv", None)
-            self.register_parameter("down_proj_scale_inv", None)
+        # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
+        dp_scale_o = _ceil_div(Wd_out, bo)
+        dp_scale_i = _ceil_div(Wd_in, bi)
+        self.down_proj_scale_inv = nn.Parameter(
+            torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32)
+        )

        # (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default
        self.register_parameter("gate_up_bias", None)
@@ -508,90 +488,46 @@ def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: to
        return output.to(dtype=input.dtype)


-# TODO: we do need this.... but not recursive...
-def _replace_with_fp8_linear(
-    model,
-    tp_plan=None,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    has_been_replaced=False,
-):
-    iterator = list(model.named_parameters()).copy()
-    for name, empty_tensor in iterator:
-        current_key_name = name
-        name = name.rsplit(".", 1)[0] if "." in name else name
-        module = model.get_submodule(name)
-
-        current_key_name_str = re.sub(r"\d+", "*", current_key_name)
-        if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
-            with init_empty_weights():
-                if (
-                    "gate_up_proj" in current_key_name
-                    or "down_proj" in current_key_name
-                    and "experts" in current_key_name
-                ):  # Experts!
-                    in_features = empty_tensor.size(-2)
-                    out_features = empty_tensor.size(-1)
-                    model.set_submodule(
-                        name,
-                        FP8Expert(
-                            config=model.config,
-                            block_size=quantization_config.weight_block_size,
-                            device=empty_tensor.device,
-                        ),
-                    )
-
-                elif isinstance(module, nn.Linear):
-                    in_features = module.in_features
-                    out_features = module.out_features
-                    model.set_submodule(
-                        name,
-                        FP8Linear(
-                            in_features=in_features,
-                            out_features=out_features,
-                            bias=module.bias is not None,
-                            device=module.weight.device,
-                            dtype=module.weight.dtype,
-                            activation_scheme=quantization_config.activation_scheme,
-                            block_size=quantization_config.weight_block_size,
-                        ),
-                    )
-                    has_been_replaced = True
-        # when changing a layer the TP PLAN for that layer should be updated. TODO
-
-    return model, has_been_replaced
-
-
def replace_with_fp8_linear(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
+    pre_quantized=False,
):
    """Helper function to replace model layers with FP8 versions."""
    if quantization_config.dequantize:
        return model

-    if modules_to_not_convert is None:
-        modules_to_not_convert = []
-    modules_to_not_convert += ["lm_head"]
-
-    if quantization_config.modules_to_not_convert is not None:
-        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
-    modules_to_not_convert = list(set(modules_to_not_convert))
-    model, has_been_replaced = _replace_with_fp8_linear(
-        model,
-        tp_plan=model._tp_plan,
-        modules_to_not_convert=modules_to_not_convert,
-        quantization_config=quantization_config,
-    )
+    has_been_replaced = False
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+        # we need this to correctly materialize the weights during quantization
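+        # when not pre-quantized, dtype=None leaves the weight in the default floating dtype so the original weights can be loaded and then quantized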
+        module_kwargs = {} if pre_quantized else {"dtype": None}
+        new_module = None
+        with init_empty_weights():
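+            # fused expert weights (gate_up_proj / down_proj under "experts") become FP8Expert, plain nn.Linear layers become FP8Linear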
509+ if "gate_up_proj" in module_name or "down_proj" in module_name and "experts" in module_name :
+                new_module = FP8Expert(
+                    config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs
+                )
+            elif isinstance(module, nn.Linear):
+                new_module = FP8Linear(
+                    in_features=module.in_features,
+                    out_features=module.out_features,
+                    bias=module.bias is not None,
+                    activation_scheme=quantization_config.activation_scheme,
+                    block_size=quantization_config.weight_block_size,
+                    **module_kwargs,
+                )
+        if new_module is not None:
+            model.set_submodule(module_name, new_module)
+            has_been_replaced = True

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using fp8 but no linear modules were found in your model."
            " Please double check your model architecture."
        )
-
    return model

@@ -606,7 +542,7 @@ def __init__(self, hf_quantizer):
    def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]:
        # Unpack single key/value (value may be wrapped in a list)
        target_keys, value = tuple(input_dict.items())[0]
-        value = value[0] if isinstance(value, list) else value
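+        # the conversion pipeline is expected to always wrap values in single-element lists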
+        value = value[0]

        # Resolve block size (support dict-like or attr-like quant_config)
        block_size = None
@@ -681,24 +617,15 @@ def __init__(self, hf_quantizer):
    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
-        model: Optional[torch.nn.Module] = None,
        full_layer_name: str | None = None,
-        missing_keys=None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        if len(input_dict) < 2:
-            # in case of no scales, the weights are not quantized, so we return the weights as is
-            return {
-                full_layer_name: input_dict["weight$"][0]
-                if isinstance(input_dict["weight$"], list)
-                else input_dict["weight$"]
-            }
-        quantized = input_dict["weight$"][0] if isinstance(input_dict["weight$"], list) else input_dict["weight$"]
-        scales = (
-            input_dict["weight_scale_inv"][0]
-            if isinstance(input_dict["weight_scale_inv"], list)
-            else input_dict["weight_scale_inv"]
-        )
+            # only the weights were provided (no scales): they are not quantized, so return them as-is
+            return {full_layer_name: input_dict["weight$"][0]}
+
+        quantized = input_dict["weight$"][0]
+        scales = input_dict["weight_scale_inv"][0]

        rows, cols = quantized.shape[-2:]
        block_size = self.hf_quantizer.quantization_config.weight_block_size