Skip to content

Commit a5025a2

Browse files
authored
Fix BMM style MoE export in fp8_pc_pt recipe (#515)
## What does this PR do? **Type of change:** Bug fix **Overview:** The Llama-4-Scout-17B-16E-Instruct model uses Llama4TextExperts, which stores expert weights in a BMM (batch matrix multiply) layout: (num_experts, input_dim, output_dim). This differs from the layout used by standard MoE models. The FP8_PC_PT (FP8 per-channel per-token) quantization code did not handle this layout, causing shape mismatches during export. ## Usage <!-- You can potentially add a usage example below. --> ```python python3 hf_ptq.py --pyt_ckpt_path /home/scratch.omniml_data_2/models/Llama-4-Scout-17B-16E-Instruct --qformat fp8_pc_pt --export_path /home/scratch.omniml_data_2/zhiyuc/checkpoints/llama4-scout-fp8_pc_pt --trust_remote_code ``` ## Testing <!-- Mention how you have tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes <!--- If No, explain why. --> - **Did you write any new necessary tests?**: No - **Did you add or update any necessary documentation?**: No - **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> --------- Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent 93f5bbf commit a5025a2

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

modelopt/torch/export/quant_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,36 @@ def to_quantized_weight(
779779
)[0]._quantized_data
780780

781781
if quantization == QUANTIZATION_FP8_PC_PT:
782-
return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
782+
if weight.dim() == 3:
783+
# Handle different scale tensor shapes
784+
if weights_scaling_factor.dim() == 1:
785+
# Per-expert scaling only: (num_experts,) -> (num_experts, 1, 1)
786+
return (weight / weights_scaling_factor[:, None, None]).to(torch.float8_e4m3fn)
787+
elif weights_scaling_factor.dim() == 2:
788+
# Per-channel scaling: check which dimension matches
789+
if weights_scaling_factor.shape[0] != weight.shape[0]:
790+
raise ValueError(
791+
f"First dimension (num_experts) mismatch for FP8_PC_PT quantization. "
792+
f"weight shape: {weight.shape}, scale shape: {weights_scaling_factor.shape}"
793+
)
794+
if weight.shape[-1] == weight.shape[-2]:
795+
raise ValueError(
796+
f"Ambiguous scaling dimension for FP8_PC_PT quantization with square weight matrix. "
797+
f"weight shape: {weight.shape}, scale shape: {weights_scaling_factor.shape}. "
798+
f"Cannot determine if scaling should be applied to input_dim or output_dim."
799+
)
800+
if weights_scaling_factor.shape[-1] == weight.shape[-1]:
801+
# (num_experts, input_dim) -> (num_experts, 1, input_dim), BMM-style
802+
return (weight / weights_scaling_factor.unsqueeze(-2)).to(torch.float8_e4m3fn)
803+
elif weights_scaling_factor.shape[-1] == weight.shape[-2]:
804+
# (num_experts, output_dim) -> (num_experts, output_dim, 1), Standard MoE case
805+
return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
806+
else:
807+
raise ValueError(
808+
f"Cannot determine correct unsqueeze dimension for FP8_PC_PT quantization. "
809+
f"weight shape: {weight.shape}, scale shape: {weights_scaling_factor.shape}"
810+
)
811+
return (weight / weights_scaling_factor[:, None]).to(torch.float8_e4m3fn)
783812

784813
if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ]:
785814
return pack_int4_in_uint8(weight, weights_scaling_factor)

modelopt/torch/export/unified_export_hf.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
KV_CACHE_NVFP4_AFFINE,
5151
QUANTIZATION_FP8,
5252
QUANTIZATION_FP8_PB_REAL,
53+
QUANTIZATION_FP8_PC_PT,
5354
QUANTIZATION_NONE,
5455
QUANTIZATION_NVFP4,
5556
QUANTIZATION_NVFP4_AWQ,
@@ -327,13 +328,15 @@ def _export_quantized_weight(
327328
weight_scale_2: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale_2, None)
328329

329330
# Transpose weight for bmm-style expert quantization (llama4, gpt-oss)
331+
# Check if this is a BMM-style expert weight that needs transposition
332+
is_bmm_expert_weight = weight.dim() == 3 and any(
333+
expert_type in type(sub_module).__name__
334+
for expert_type in ["Llama4TextExperts", "GptOssExperts"]
335+
)
336+
330337
if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
331338
# Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
332339
# for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
333-
is_bmm_expert_weight = weight.dim() == 3 and any(
334-
expert_type in type(sub_module).__name__
335-
for expert_type in ["Llama4TextExperts", "GptOssExperts"]
336-
)
337340
weight, _ = maybe_transpose_expert_weight_dimensions(
338341
weight, is_bmm_expert_weight=is_bmm_expert_weight
339342
)
@@ -354,6 +357,24 @@ def _export_quantized_weight(
354357
quantized_weight, weight_scale = maybe_transpose_expert_weight_dimensions(
355358
quantized_weight, weight_scale, is_bmm_expert_weight=is_bmm_expert_weight
356359
)
360+
elif quantization_format == QUANTIZATION_FP8_PC_PT and is_bmm_expert_weight:
361+
# For FP8_PC_PT with BMM-style experts, transpose only the weight (not weight_scale)
362+
weight, _ = maybe_transpose_expert_weight_dimensions(
363+
weight, is_bmm_expert_weight=is_bmm_expert_weight
364+
)
365+
366+
quantized_weight = to_quantized_weight(
367+
weight.to(dtype),
368+
weight_scale,
369+
quantization_format,
370+
weight_scale_2,
371+
block_size,
372+
)
373+
374+
# Transpose back to original BMM format
375+
quantized_weight, _ = maybe_transpose_expert_weight_dimensions(
376+
quantized_weight, is_bmm_expert_weight=is_bmm_expert_weight
377+
)
357378
else:
358379
quantized_weight = to_quantized_weight(
359380
weight.to(dtype),

0 commit comments

Comments (0)