@@ -7,7 +7,6 @@

 import torch
 from torch.nn import Module
-from torch.nn.parameter import Parameter

 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -525,13 +524,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
             weight = weight.t()

         # Update layer with new values.
-        layer.weight = Parameter(weight.data, requires_grad=False)
-        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
-        layer.input_scale = (
-            Parameter(input_scale, requires_grad=False)
-            if input_scale is not None
-            else None
-        )
+        layer.weight.copy_(weight.data)
+        layer.weight_scale.copy_(weight_scale.data)
+        if input_scale is not None:
+            layer.input_scale.copy_(input_scale)

         if self.use_marlin:
             prepare_fp8_layer_for_marlin(layer, size_k_first)
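The change in this hunk is the pattern repeated throughout the diff: instead of rebinding the attribute to a fresh `nn.Parameter`, the processed weights are written into the already-registered parameter with `Tensor.copy_()`, so the tensor object itself (and any reference other code may already hold to it) stays the same. A minimal sketch of the difference, using a hypothetical toy `nn.Linear` rather than a vLLM layer:

```python
import torch
from torch import nn

# Toy layer for illustration only; the vLLM layers touched in this diff
# register their quantized weights as parameters in the same way.
layer = nn.Linear(4, 4, bias=False)
captured = layer.weight  # e.g. a reference captured elsewhere

with torch.no_grad():
    # Rebinding registers a brand-new tensor; `captured` still points at
    # the old storage and will not see the update.
    layer.weight = nn.Parameter(torch.randn(4, 4), requires_grad=False)
    assert captured is not layer.weight

    # In-place copy keeps the same parameter object; every existing
    # reference observes the new values.
    before = layer.weight
    layer.weight.copy_(torch.randn(4, 4))
    assert before is layer.weight
```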
@@ -827,22 +823,18 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 w2_weight_scale_inv = layer.w2_weight_scale_inv

             # torch.compile() cannot use Parameter subclasses.
-            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
-            layer.w13_weight_scale_inv = Parameter(
-                w13_weight_scale_inv, requires_grad=False
-            )
-            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
-            layer.w2_weight_scale_inv = Parameter(
-                w2_weight_scale_inv, requires_grad=False
-            )
+            layer.w13_weight.copy_(w13_weight)
+            layer.w13_weight_scale_inv.copy_(w13_weight_scale_inv)
+            layer.w2_weight.copy_(w2_weight)
+            layer.w2_weight_scale_inv.copy_(w2_weight_scale_inv)
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight.data, layer.w2_weight.data
                 )

-                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                layer.w13_weight.copy_(shuffled_w13)
+                layer.w2_weight.copy_(shuffled_w2)

             # DeepGemm scales need to be transposed and aligned. We try to do
             # it ahead of time for performance reasons.
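One consequence of switching to `copy_()` is that the destination parameter must already have a compatible shape (the source has to be broadcastable to it); dtype and device conversion is handled by `copy_()` itself. The sketch below shows that constraint in isolation and assumes, as the hunk above does for `rocm_aiter_ops.shuffle_weights`, that the processed tensor keeps the original shape:

```python
import torch

# Destination stands in for an already-registered fp8 parameter.
dst = torch.empty(8, 16, dtype=torch.float8_e4m3fn)

# Same shape: copy_() accepts a float32 source and casts it to fp8 in place.
dst.copy_(torch.randn(8, 16))

# Non-broadcastable shape: copy_() raises, so the in-place updates in this
# diff rely on each parameter having been allocated at its final shape.
try:
    dst.copy_(torch.randn(16, 8))
except RuntimeError as err:
    print(f"shape mismatch rejected: {err}")
```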
@@ -864,7 +856,7 @@ def process_weights_after_loading(self, layer: Module) -> None:

             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
-            layer.w13_weight_scale = torch.nn.Parameter(
+            layer.w13_weight_scale.copy_(
                 torch.ones(
                     layer.local_num_experts,
                     dtype=torch.float32,
@@ -879,16 +871,16 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                     ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                 )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            layer.w13_weight.copy_(w13_weight)
+            layer.w2_weight.copy_(w2_weight)
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight, layer.w2_weight
                 )

-                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                layer.w13_weight.copy_(shuffled_w13)
+                layer.w2_weight.copy_(shuffled_w2)
         # If checkpoint is fp8, we need to handle that the
         # MoE kernels require single activation scale and single weight
         # scale for w13 per expert.
@@ -909,12 +901,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     "fp8 MoE layer. Using the maximum across experts "
                     "for each layer."
                 )
-                layer.w13_input_scale = torch.nn.Parameter(
-                    layer.w13_input_scale.max(), requires_grad=False
-                )
-                layer.w2_input_scale = torch.nn.Parameter(
-                    layer.w2_input_scale.max(), requires_grad=False
-                )
+                layer.w13_input_scale.copy_(layer.w13_input_scale.max())
+                layer.w2_input_scale.copy_(layer.w2_input_scale.max())
             if current_platform.is_fp8_fnuz():
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
@@ -928,22 +916,14 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     )
                 )
                 # Reset the parameter
-                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-                layer.w13_weight_scale = torch.nn.Parameter(
-                    w13_weight_scale, requires_grad=False
-                )
+                layer.w13_weight.copy_(w13_weight)
+                layer.w13_weight_scale.copy_(w13_weight_scale)
                 if w13_input_scale is not None:
-                    layer.w13_input_scale = torch.nn.Parameter(
-                        w13_input_scale, requires_grad=False
-                    )
-                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-                layer.w2_weight_scale = torch.nn.Parameter(
-                    w2_weight_scale, requires_grad=False
-                )
+                    layer.w13_input_scale.copy_(w13_input_scale)
+                layer.w2_weight.copy_(w2_weight)
+                layer.w2_weight_scale.copy_(w2_weight_scale)
                 if w2_input_scale is not None:
-                    layer.w2_input_scale = torch.nn.Parameter(
-                        w2_input_scale, requires_grad=False
-                    )
+                    layer.w2_input_scale.copy_(w2_input_scale)

             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
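The comment above ("take the max then dequant and requant") refers to collapsing the per-shard w13 scales into a single scale per expert. A hedged sketch of that idea in isolation; the shapes and names are invented, and the real code goes through `ops.scaled_fp8_quant` rather than manual casts:

```python
import torch

fp8 = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8).max

# Two shards of one expert's w13 weight, each quantized with its own scale.
shards = [torch.randn(4, 8), torch.randn(4, 8)]
scales = [shard.abs().max() / fp8_max for shard in shards]
q_shards = [(shard / scale).to(fp8) for shard, scale in zip(shards, scales)]

# Collapse to one scale per expert: take the max, then dequantize each shard
# with its own scale and requantize it against the shared maximum.
max_scale = max(scales)
requant = [
    (q.to(torch.float32) * scale / max_scale).to(fp8)
    for q, scale in zip(q_shards, scales)
]
# `requant` together with `max_scale` is the single-scale representation
# the fused MoE kernel expects for w13.
```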
@@ -967,12 +947,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     layer.w13_weight, layer.w2_weight
                 )

-                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                layer.w13_weight.copy_(shuffled_w13)
+                layer.w2_weight.copy_(shuffled_w2)

-            layer.w13_weight_scale = torch.nn.Parameter(
-                max_w13_scales, requires_grad=False
-            )
+            layer.w13_weight_scale.copy_(max_w13_scales)

         if self.flashinfer_moe_backend is not None:
             # NOTE: weights have to be swapped since the activation is