@@ -483,6 +483,18 @@ def create_weights(
             else:
                 layer.register_parameter("input_scale", None)
 
+        # create per-tensor qparams populated by process_weights_after_loading
+        else:
+            scale = create_fp8_scale_parameter(
+                PerTensorScaleParameter,
+                output_partition_sizes,
+                input_size_per_partition,
+                None,
+                weight_loader,
+            )
+            set_weight_attrs(scale, {"scale_type": "weight_scale"})
+            layer.register_parameter("weight_scale", scale)
+
     def process_weights_after_loading(self, layer: Module) -> None:
         size_k_first = True
         input_scale = None
@@ -494,8 +506,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
             weight, weight_scale = process_fp8_weight_block_strategy(
                 layer.weight, layer.weight_scale_inv
             )
-            # Delete the weight_scale_inv parameter to avoid confusion
-            # with the weight_scale parameter
+            # Rename weight_scale_inv parameter for consistency
+            layer.weight_scale = layer.weight_scale_inv
             del layer.weight_scale_inv
 
         # If checkpoint not serialized fp8, quantize the weights.
@@ -755,12 +767,10 @@ def create_weights(
             if self.block_quant
             else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
         )
-        # If loading fp8 checkpoint, pass the weight loaders.
-        # If loading an fp16 checkpoint, do not (we will quantize in
-        # process_weights_after_loading()
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # add weight loaders to support loading (and reloading)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.quant_config.activation_scheme == "static":
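The block-quant hunk above keeps the loaded scale reachable under the canonical `weight_scale` name instead of only deleting `weight_scale_inv`. Below is a minimal sketch of that rename pattern on a bare `nn.Module`; the `DummyLayer` class and `rename_scale` helper are illustrative stand-ins, not part of this change or of vLLM.

```python
import torch
from torch import nn


class DummyLayer(nn.Module):
    """Stand-in for a quantized linear layer holding a block-quant scale."""

    def __init__(self) -> None:
        super().__init__()
        # Checkpoint-style name as produced by block-quantized fp8 checkpoints.
        self.weight_scale_inv = nn.Parameter(
            torch.ones(4, 4), requires_grad=False
        )


def rename_scale(layer: nn.Module) -> None:
    # Re-register the tensor under the canonical name, then drop the old one,
    # mirroring the `layer.weight_scale = layer.weight_scale_inv` /
    # `del layer.weight_scale_inv` pair in the hunk above.
    layer.weight_scale = layer.weight_scale_inv
    del layer.weight_scale_inv


layer = DummyLayer()
rename_scale(layer)
assert hasattr(layer, "weight_scale")
assert not hasattr(layer, "weight_scale_inv")
```

Both names point at the same storage until the `del`, so no copy is made and anything a weight loader already wrote into `weight_scale_inv` remains visible through `weight_scale`.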