Merged
9 changes: 4 additions & 5 deletions vllm_ascend/quantization/quant_config.py
@@ -408,11 +408,10 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
        quant_config: The Ascend quantization config.
    """

-    def __init__(self,
-                 quant_config: AscendQuantConfig,
-                 prefix: str,
-                 packed_modules_mapping: Dict[str, Any],
-                 layer: torch.nn.Module = None):
+    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
+                 packed_modules_mapping: Dict[str,
+                                              Any], layer: torch.nn.Module):

Review comment on the new signature: Suggest formatting it.

        super().__init__(layer.moe_config)
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix,
                                             "moe",
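Beyond the formatting churn, the substantive change in this hunk is that layer loses its None default: the constructor can now read layer.moe_config unconditionally and forward it to the base class. Below is a minimal, self-contained sketch of that pattern; the class and variable names are hypothetical stand-ins, not the vllm-ascend classes.

import torch

class MoEConfig:
    """Stand-in for the moe_config object carried by the real MoE layer."""

class BaseMoEMethod:
    def __init__(self, moe_config: MoEConfig):
        self.moe_config = moe_config

class EagerMoEMethod(BaseMoEMethod):
    # layer is a required argument (no None default), so moe_config can be
    # read off it at construction time, mirroring the new signature above.
    def __init__(self, layer: torch.nn.Module):
        super().__init__(layer.moe_config)

layer = torch.nn.Linear(8, 8)
layer.moe_config = MoEConfig()   # attach the config the way the real layer would carry it
method = EagerMoEMethod(layer)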
17 changes: 14 additions & 3 deletions vllm_ascend/torchair/ops/torchair_fused_moe.py
@@ -16,7 +16,7 @@
# Adapted from vllm/tests/kernels/test_moe.py

import os
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -45,7 +45,9 @@
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
-from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
+from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod,
+                                                   AscendQuantConfig)
+from vllm_ascend.quantization.utils import get_quant_method
from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.torchair.utils import (get_all_reduce_merge_state,
get_rm_router_logits_state,
@@ -936,6 +938,15 @@ def apply(
ep_group=get_ep_group())


+class TorchairAscendFusedMoEMethod(AscendFusedMoEMethod):
+
+    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
+                 packed_modules_mapping: Dict[str, Any]):
+        self.quant_method = get_quant_method(quant_config.quant_description,
+                                             prefix, "moe",
+                                             packed_modules_mapping)
+
+
class TorchairAscendFusedMoE(FusedMoE):

# The moe_counter parameter is required during the initialization of EPLB
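A plausible reading of why this new subclass exists (an assumption; the diff only shows the override, not the motivation): the torchair call site in the next hunk does not pass a layer, so the override cannot go through the base __init__, which now dereferences layer.moe_config; it therefore only resolves self.quant_method. Continuing the hypothetical sketch from the previous file, the same shape looks roughly like this:

class DeferredMoEMethod(BaseMoEMethod):   # BaseMoEMethod from the sketch above
    def __init__(self, quant_description: dict, prefix: str):
        # No layer is available on this path, so the base __init__ (which
        # needs layer.moe_config) is deliberately not called; only the pieces
        # that are already available get stored here.
        self.quant_description = quant_description
        self.prefix = prefix

method = DeferredMoEMethod({}, prefix="mlp.experts")   # dict contents and prefix are illustrative only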
@@ -1115,7 +1126,7 @@ def __init__(
            self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
                self.moe)
        else:
-            self.quant_method = AscendFusedMoEMethod(
+            self.quant_method = TorchairAscendFusedMoEMethod(
                quant_config, prefix, quant_config.packed_modules_mapping)

assert self.quant_method is not None