|
16 | 16 | # Adapted from vllm/tests/kernels/test_moe.py |
17 | 17 |
|
18 | 18 | import os |
19 | | -from typing import Any, Callable, Optional, Tuple, Union |
| 19 | +from typing import Any, Callable, Dict, Optional, Tuple, Union |
20 | 20 |
|
21 | 21 | import torch |
22 | 22 | import torch.distributed as dist |
|
45 | 45 | from vllm_ascend.distributed.parallel_state import get_mc2_group |
46 | 46 | from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map |
47 | 47 | from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer |
48 | | -from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod |
| 48 | +from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod, |
| 49 | + AscendQuantConfig) |
| 50 | +from vllm_ascend.quantization.utils import get_quant_method |
49 | 51 | from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding |
50 | 52 | from vllm_ascend.torchair.utils import (get_all_reduce_merge_state, |
51 | 53 | get_rm_router_logits_state, |
@@ -936,6 +938,15 @@ def apply( |
936 | 938 | ep_group=get_ep_group()) |
937 | 939 |
|
938 | 940 |
|
class TorchairAscendFusedMoEMethod(AscendFusedMoEMethod):
    """Fused-MoE quantization method that resolves its backend on construction.

    The constructor looks up the concrete quantization implementation through
    ``get_quant_method`` using the layer type ``"moe"`` and stores the result
    on ``self.quant_method``.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]):
        """Select and store the quantization method for one MoE layer.

        Args:
            quant_config: Quantization config; only its ``quant_description``
                attribute is read here.
            prefix: Forwarded unchanged to ``get_quant_method``
                (presumably the layer's qualified name; verify against
                the caller).
            packed_modules_mapping: Forwarded unchanged to
                ``get_quant_method``.
        """
        # NOTE(review): the parent ``__init__`` is not invoked; this
        # constructor only sets ``quant_method``. Confirm the base class needs
        # no additional state.
        description = quant_config.quant_description
        self.quant_method = get_quant_method(
            description,
            prefix,
            "moe",
            packed_modules_mapping,
        )
| 948 | + |
| 949 | + |
939 | 950 | class TorchairAscendFusedMoE(FusedMoE): |
940 | 951 |
|
941 | 952 | # The moe_counter parameter is required during the initialization of EPLB |
@@ -1115,7 +1126,7 @@ def __init__( |
1115 | 1126 | self.quant_method = TorchairAscendUnquantizedFusedMoEMethod( |
1116 | 1127 | self.moe) |
1117 | 1128 | else: |
1118 | | - self.quant_method = AscendFusedMoEMethod( |
| 1129 | + self.quant_method = TorchairAscendFusedMoEMethod( |
1119 | 1130 | quant_config, prefix, quant_config.packed_modules_mapping) |
1120 | 1131 |
|
1121 | 1132 | assert self.quant_method is not None |
|
0 commit comments