
Commit dbf10e5

no new parameters
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 684f254 commit dbf10e5

File tree

2 files changed: +33 −48 lines

  vllm/model_executor/layers/quantization/fp8.py
  vllm/model_executor/layers/quantization/kv_cache.py

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 26 additions & 48 deletions
@@ -7,7 +7,6 @@
 
 import torch
 from torch.nn import Module
-from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -525,13 +524,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
             weight = weight.t()
 
         # Update layer with new values.
-        layer.weight = Parameter(weight.data, requires_grad=False)
-        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
-        layer.input_scale = (
-            Parameter(input_scale, requires_grad=False)
-            if input_scale is not None
-            else None
-        )
+        layer.weight.copy_(weight.data)
+        layer.weight_scale.copy_(weight_scale.data)
+        if input_scale is not None:
+            layer.input_scale.copy_(input_scale)
 
         if self.use_marlin:
             prepare_fp8_layer_for_marlin(layer, size_k_first)
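For readers unfamiliar with the distinction, here is a minimal stand-alone PyTorch sketch (a toy nn.Linear, not vLLM code) of why the copy_ form leaves the module with no new parameter objects, while the old Parameter(...) rebinding creates fresh ones:

import torch
from torch import nn
from torch.nn.parameter import Parameter

layer = nn.Linear(4, 4, bias=False)
original = layer.weight  # external reference, e.g. held by some wrapper

new_values = torch.randn(4, 4)

# Rebinding: layer.weight now points at a brand-new Parameter; `original`
# still points at the old storage and silently goes stale.
layer.weight = Parameter(new_values.clone(), requires_grad=False)
print(layer.weight is original)  # False

# In-place copy: the existing Parameter keeps its identity and storage,
# only its values change, so prior references stay valid.
layer = nn.Linear(4, 4, bias=False)
original = layer.weight
with torch.no_grad():
    layer.weight.copy_(new_values)
print(layer.weight is original)  # True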
@@ -827,22 +823,18 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 w2_weight_scale_inv = layer.w2_weight_scale_inv
 
             # torch.compile() cannot use Parameter subclasses.
-            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
-            layer.w13_weight_scale_inv = Parameter(
-                w13_weight_scale_inv, requires_grad=False
-            )
-            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
-            layer.w2_weight_scale_inv = Parameter(
-                w2_weight_scale_inv, requires_grad=False
-            )
+            layer.w13_weight.copy_(w13_weight)
+            layer.w13_weight_scale_inv.copy_(w13_weight_scale_inv)
+            layer.w2_weight.copy_(w2_weight)
+            layer.w2_weight_scale_inv.copy_(w2_weight_scale_inv)
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight.data, layer.w2_weight.data
                 )
 
-                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                layer.w13_weight.copy_(shuffled_w13)
+                layer.w2_weight.copy_(shuffled_w2)
 
             # DeepGemm scales need to be transposed and aligned. We try to do
             # it ahead of time for performance reasons.
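One constraint this pattern relies on (an observation about torch.Tensor.copy_, not something stated in the diff): the destination parameter must already exist with a shape the source can broadcast into, so steps like the aiter weight shuffle have to stay shape-preserving for the result to be written back in place. A small hypothetical sketch:

import torch

# Hypothetical stand-in for a layout-preparation step such as a kernel-specific
# weight shuffle; the key assumption is that it keeps the tensor's shape, so the
# result can be written back into the same buffer with copy_.
def shuffle_experts(w: torch.Tensor) -> torch.Tensor:
    perm = torch.randperm(w.shape[0])
    return w[perm].contiguous()

w13 = torch.nn.Parameter(torch.randn(8, 64, 128), requires_grad=False)
shuffled = shuffle_experts(w13.data)

assert shuffled.shape == w13.shape  # copy_ needs a broadcast-compatible source
with torch.no_grad():
    w13.copy_(shuffled)  # same Parameter object, new values in place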
@@ -864,7 +856,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
-            layer.w13_weight_scale = torch.nn.Parameter(
+            layer.w13_weight_scale.copy_(
                 torch.ones(
                     layer.local_num_experts,
                     dtype=torch.float32,
@@ -879,16 +871,16 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                     ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                 )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            layer.w13_weight.copy_(w13_weight)
+            layer.w2_weight.copy_(w2_weight)
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                     layer.w13_weight, layer.w2_weight
                 )
 
-                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                layer.w13_weight.copy_(shuffled_w13)
+                layer.w2_weight.copy_(shuffled_w2)
         # If checkpoint is fp8, we need to handle that the
         # MoE kernels require single activation scale and single weight
         # scale for w13 per expert.
@@ -909,12 +901,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     "fp8 MoE layer. Using the maximum across experts "
                     "for each layer."
                 )
-                layer.w13_input_scale = torch.nn.Parameter(
-                    layer.w13_input_scale.max(), requires_grad=False
-                )
-                layer.w2_input_scale = torch.nn.Parameter(
-                    layer.w2_input_scale.max(), requires_grad=False
-                )
+                layer.w13_input_scale.copy_(layer.w13_input_scale.max())
+                layer.w2_input_scale.copy_(layer.w2_input_scale.max())
             if current_platform.is_fp8_fnuz():
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
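Note the behavioral nuance in the input-scale hunk: the old code replaced the per-expert scale tensor with a 0-dim Parameter holding the max, while copy_ broadcasts that 0-dim max into the existing buffer, so the parameter keeps its per-expert shape but every entry now holds the shared value. A tiny illustration with hypothetical values, not vLLM code:

import torch

# Pretend per-expert input scales, one value per expert.
w13_input_scale = torch.nn.Parameter(
    torch.tensor([0.5, 0.7, 0.9]), requires_grad=False
)

# max() materializes a 0-dim tensor first, then copy_ broadcasts it into the
# existing buffer: same shape as before, every entry now holds the max.
with torch.no_grad():
    w13_input_scale.copy_(w13_input_scale.max())

print(w13_input_scale)  # all three entries are now 0.9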
@@ -928,22 +916,14 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     )
                 )
                 # Reset the parameter
-                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-                layer.w13_weight_scale = torch.nn.Parameter(
-                    w13_weight_scale, requires_grad=False
-                )
+                layer.w13_weight.copy_(w13_weight)
+                layer.w13_weight_scale.copy_(w13_weight_scale)
                 if w13_input_scale is not None:
-                    layer.w13_input_scale = torch.nn.Parameter(
-                        w13_input_scale, requires_grad=False
-                    )
-                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-                layer.w2_weight_scale = torch.nn.Parameter(
-                    w2_weight_scale, requires_grad=False
-                )
+                    layer.w13_input_scale.copy_(w13_input_scale)
+                layer.w2_weight.copy_(w2_weight)
+                layer.w2_weight_scale.copy_(w2_weight_scale)
                 if w2_input_scale is not None:
-                    layer.w2_input_scale = torch.nn.Parameter(
-                        w2_input_scale, requires_grad=False
-                    )
+                    layer.w2_input_scale.copy_(w2_input_scale)
 
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
@@ -967,12 +947,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     layer.w13_weight, layer.w2_weight
                 )
 
-                layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+                layer.w13_weight.copy_(shuffled_w13)
+                layer.w2_weight.copy_(shuffled_w2)
 
-            layer.w13_weight_scale = torch.nn.Parameter(
-                max_w13_scales, requires_grad=False
-            )
+            layer.w13_weight_scale.copy_(max_w13_scales)
 
         if self.flashinfer_moe_backend is not None:
             # NOTE: weights have to be swapped since the activation is

vllm/model_executor/layers/quantization/kv_cache.py

Lines changed: 7 additions & 0 deletions
@@ -45,6 +45,13 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # skip if there are no weights to process (for example, weight reloading)
+        if not hasattr(layer, "q_scale"):
+            assert not hasattr(layer, "k_scale")
+            assert not hasattr(layer, "v_scale")
+            assert not hasattr(layer, "prob_scale")
+            return
+
         # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
         # regardless whether the kv-scale is available in the checkpoint.
         # No need to process kv scales after loading if we are going to
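A minimal sketch of the guard this hunk adds, using a stand-in module rather than vLLM's attention layer (the reload flow shown is assumed for illustration only):

import torch

class FakeAttention(torch.nn.Module):
    """Stand-in for an attention layer; may or may not carry kv-cache scales."""

def process_weights_after_loading(layer: torch.nn.Module) -> None:
    # skip if there are no weights to process (for example, weight reloading)
    if not hasattr(layer, "q_scale"):
        return
    # ... normal k/v/prob scale processing would follow here ...

fresh = FakeAttention()
fresh.q_scale = torch.tensor(1.0)        # first load: scale present, gets processed
process_weights_after_loading(fresh)

reloaded = FakeAttention()               # reload pass: scales already consumed
process_weights_after_loading(reloaded)  # returns early instead of raising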
