Commit bb09a30 (parent: ac0769c)

[Quantization] fix dequant when block size is none & static quantization (#42545)

* fix
* style

In short: block-wise FP8 dequantization now falls back to a per-tensor block size when weight_block_size is None, the quantized weights are cast to the scale dtype before the multiply, and activation scales from statically quantized checkpoints are collected (but left unused) so that such checkpoints still take the dequantization path.

File tree: 2 files changed (+8 / -4 lines)


src/transformers/integrations/finegrained_fp8.py

Lines changed: 6 additions & 3 deletions
@@ -686,7 +686,7 @@ def convert(
         missing_keys=None,
         **kwargs,
     ) -> dict[str, torch.Tensor]:
-        if len(input_dict) != 2:
+        if len(input_dict) < 2:
             # in case of no scales, the weights are not quantized, so we return the weights as is
             return {
                 full_layer_name: input_dict["weight$"][0]
@@ -702,15 +702,18 @@ def convert(

         rows, cols = quantized.shape[-2:]
         block_size = self.hf_quantizer.quantization_config.weight_block_size
+        if block_size is None:
+            block_size = (quantized.shape[-2], quantized.shape[-1])

         block_m, block_n = block_size
+
         if rows % block_m != 0 or cols % block_n != 0:
             raise ValueError(
                 f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
             )
-
+        quantized = quantized.to(scales.dtype)
         reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
-        expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
+        expanded_scales = scales.reshape(-1, rows // block_m, cols // block_n)
         expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
         dequantized = reshaped * expanded_scales

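For reference, here is a minimal standalone sketch of the block-wise dequantization this hunk patches, including the new per-tensor fallback when weight_block_size is None and the cast of the quantized weights to the scale dtype (elementwise ops are not defined on float8 tensors). The helper name and the final reshape are illustrative assumptions, not the exact transformers code:

import torch

def dequantize_fp8_blockwise(quantized, scales, block_size=None):
    # Hypothetical helper mirroring the patched convert() logic.
    rows, cols = quantized.shape[-2:]
    if block_size is None:
        # Per-tensor scale: treat the whole matrix as a single block.
        block_size = (rows, cols)
    block_m, block_n = block_size
    if rows % block_m != 0 or cols % block_n != 0:
        raise ValueError(f"({rows}, {cols}) not divisible by ({block_m}, {block_n})")
    # Cast first: FP8 tensors cannot be multiplied directly (the commit's dtype fix).
    dequantized = quantized.to(scales.dtype)
    reshaped = dequantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
    # One scale per (block_m, block_n) tile, broadcast over the tile dimensions.
    expanded = scales.reshape(-1, rows // block_m, cols // block_n)
    expanded = expanded.unsqueeze(-1).unsqueeze(2)
    return (reshaped * expanded).reshape(quantized.shape)

# Example: a 4x4 FP8 weight with 2x2 blocks, one scale per block.
w = torch.randn(4, 4).to(torch.float8_e4m3fn)
s = torch.rand(2, 2) + 0.5
assert dequantize_fp8_blockwise(w, s, block_size=(2, 2)).shape == (4, 4)
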
src/transformers/quantizers/quantizer_finegrained_fp8.py

Lines changed: 2 additions & 1 deletion
@@ -246,8 +246,9 @@ def get_weight_conversions(self):
         if self.pre_quantized and self.quantization_config.dequantize:
             return [
                 # either use the dollar sign, or permute the source patterns to start matching against the scales first
+                # We also collect the activation scales, they will not be used
                 WeightConverter(
-                    source_patterns=["weight$", "weight_scale_inv"],
+                    source_patterns=["weight$", "weight_scale_inv", "activation_scale"],
                     target_patterns="weight",
                     operations=[Fp8Dequantize(self)],
                 )
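
This change is also why the first hunk relaxes the input check: with static quantization a checkpoint carries per-layer activation scales as well, so once "activation_scale" is added to the source patterns, convert() can receive three matched groups instead of two. The old len(input_dict) != 2 test would then wrongly take the "weights are not quantized" early return, whereas len(input_dict) < 2 triggers only when the scales are genuinely absent. A rough illustration of the grouping, assuming simple regex search over checkpoint keys (the real WeightConverter matching in transformers is more involved):

import re

patterns = ["weight$", "weight_scale_inv", "activation_scale"]
checkpoint_keys = [
    "model.layers.0.mlp.weight",
    "model.layers.0.mlp.weight_scale_inv",
    "model.layers.0.mlp.activation_scale",  # only present with static quantization
]

# Group keys by the first pattern they match; the "$" anchor keeps a plain
# "weight" pattern from also swallowing the scale tensors (hence the source
# comment about the dollar sign vs. pattern ordering).
input_dict = {}
for key in checkpoint_keys:
    for pattern in patterns:
        if re.search(pattern, key):
            input_dict.setdefault(pattern, []).append(key)
            break

print(len(input_dict))  # 3 -> the old `!= 2` check would have skipped dequantization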
