
Commit 7e39bb5

[main][bugfix] bugfix for qwen3 moe quantization
Signed-off-by: Wang Kunpeng <[email protected]>
1 parent: 6d7db0f
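
Summary (inferred from the diff below): AscendFusedMoEMethod.__init__ dereferences layer.moe_config, so its layer parameter can no longer default to None. The torchair fused-MoE path constructed this method without a layer, which broke Qwen3 MoE quantization; it now gets a dedicated TorchairAscendFusedMoEMethod subclass that resolves the quant method directly and skips the base constructor.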

2 files changed: +17 additions, -8 deletions

vllm_ascend/quantization/quant_config.py

Lines changed: 3 additions & 5 deletions

@@ -408,11 +408,9 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
         quant_config: The Ascend quantization config.
     """
 
-    def __init__(self,
-                 quant_config: AscendQuantConfig,
-                 prefix: str,
-                 packed_modules_mapping: Dict[str, Any],
-                 layer: torch.nn.Module = None):
+    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
+                 packed_modules_mapping: Dict[str,
+                                              Any], layer: torch.nn.Module):
         super().__init__(layer.moe_config)
         self.quant_method = get_quant_method(quant_config.quant_description,
                                              prefix,
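
Why the default had to go, as a minimal runnable sketch (simplified stand-in names, not the actual vllm-ascend classes): the constructor dereferences layer immediately, so the old layer=None default only moved the failure into __init__ as an AttributeError on layer.moe_config.

import torch


class FusedMoEMethodBaseSketch:
    """Stand-in for the real FusedMoEMethodBase; it just stores the config."""

    def __init__(self, moe_config):
        self.moe_config = moe_config


class AscendMethodSketch(FusedMoEMethodBaseSketch):

    def __init__(self, layer: torch.nn.Module):
        # `layer` is required here: with the old `layer=None` default, a
        # missing layer failed on `layer.moe_config` below instead of at
        # the call site.
        super().__init__(layer.moe_config)


layer = torch.nn.Linear(4, 4)
layer.moe_config = {"num_experts": 8}  # toy config attached for the demo
print(AscendMethodSketch(layer).moe_config)  # {'num_experts': 8}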

vllm_ascend/torchair/ops/torchair_fused_moe.py

Lines changed: 14 additions & 3 deletions

@@ -16,7 +16,7 @@
 # Adapted from vllm/tests/kernels/test_moe.py
 
 import os
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -45,7 +45,9 @@
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
-from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
+from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod,
+                                                   AscendQuantConfig)
+from vllm_ascend.quantization.utils import get_quant_method
 from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
 from vllm_ascend.torchair.utils import (get_all_reduce_merge_state,
                                         get_rm_router_logits_state,
@@ -936,6 +938,15 @@ def apply(
             ep_group=get_ep_group())
 
 
+class TorchairAscendFusedMoEMethod(AscendFusedMoEMethod):
+
+    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
+                 packed_modules_mapping: Dict[str, Any]):
+        self.quant_method = get_quant_method(quant_config.quant_description,
+                                             prefix, "moe",
+                                             packed_modules_mapping)
+
+
 class TorchairAscendFusedMoE(FusedMoE):
 
     # The moe_counter parameter is required during the initialization of EPLB
@@ -1115,7 +1126,7 @@ def __init__(
             self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
                 self.moe)
         else:
-            self.quant_method = AscendFusedMoEMethod(
+            self.quant_method = TorchairAscendFusedMoEMethod(
                 quant_config, prefix, quant_config.packed_modules_mapping)
 
         assert self.quant_method is not None
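
The new subclass in miniature, as a hedged, self-contained sketch (hypothetical names; get_quant_method_stub stands in for vllm_ascend.quantization.utils.get_quant_method): on the torchair path no layer exists when the method is built, so the subclass deliberately skips the base __init__ and only resolves the quant method, which is exactly the shape of TorchairAscendFusedMoEMethod above.

from typing import Any, Dict


class BaseMethodSketch:
    """Stand-in for AscendFusedMoEMethod: its __init__ needs a live layer."""

    def __init__(self, layer):
        self.moe_config = layer.moe_config


def get_quant_method_stub(desc: Dict[str, Any], prefix: str, kind: str,
                          mapping: Dict[str, Any]) -> str:
    # Hypothetical stand-in for the real get_quant_method helper.
    return f"{kind}-method-for-{prefix}"


class TorchairMethodSketch(BaseMethodSketch):

    def __init__(self, desc: Dict[str, Any], prefix: str,
                 mapping: Dict[str, Any]):
        # No super().__init__: there is no layer (and hence no
        # layer.moe_config) on this construction path yet.
        self.quant_method = get_quant_method_stub(desc, prefix, "moe", mapping)


m = TorchairMethodSketch({"w_bits": 8}, "model.layers.0.mlp.experts", {})
print(m.quant_method)  # moe-method-for-model.layers.0.mlp.experts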
