Skip to content

Commit 749c91c

Browse files
committed
Register weight scale in create params; there is still an issue with reloading from disk
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 1576409 commit 749c91c

File tree

1 file changed

+18
-8
lines changed
  • vllm/model_executor/layers/quantization

1 file changed

+18
-8
lines changed

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,18 @@ def create_weights(
483483
else:
484484
layer.register_parameter("input_scale", None)
485485

486+
# create per-tensor qparams populated by process_weights_after_loading
487+
else:
488+
scale = create_fp8_scale_parameter(
489+
PerTensorScaleParameter,
490+
output_partition_sizes,
491+
input_size_per_partition,
492+
None,
493+
weight_loader,
494+
)
495+
set_weight_attrs(scale, {"scale_type": "weight_scale"})
496+
layer.register_parameter("weight_scale", scale)
497+
486498
def process_weights_after_loading(self, layer: Module) -> None:
487499
size_k_first = True
488500
input_scale = None
@@ -494,8 +506,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
494506
weight, weight_scale = process_fp8_weight_block_strategy(
495507
layer.weight, layer.weight_scale_inv
496508
)
497-
# Delete the weight_scale_inv parameter to avoid confusion
498-
# with the weight_scale parameter
509+
# Rename weight_scale_inv parameter for consistency
510+
layer.weight_scale = layer.weight_scale_inv
499511
del layer.weight_scale_inv
500512

501513
# If checkpoint not serialized fp8, quantize the weights.
@@ -755,12 +767,10 @@ def create_weights(
755767
if self.block_quant
756768
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
757769
)
758-
# If loading fp8 checkpoint, pass the weight loaders.
759-
# If loading an fp16 checkpoint, do not (we will quantize in
760-
# process_weights_after_loading())
761-
if self.quant_config.is_checkpoint_fp8_serialized:
762-
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
763-
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
770+
771+
# add weight loaders to support loading (and reloading)
772+
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
773+
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
764774

765775
# INPUT_SCALES
766776
if self.quant_config.activation_scheme == "static":

0 commit comments

Comments
 (0)