6 changes: 3 additions & 3 deletions vllm_ascend/ascend_config.py
@@ -57,7 +57,7 @@ class AscendConfig:

def __init__(self, vllm_config):
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}

self.mix_placement = additional_config.get("mix_placement", False)
xlite_graph_config = additional_config.get("xlite_graph_config", {})
self.xlite_graph_config = XliteGraphConfig(xlite_graph_config,
vllm_config)
@@ -193,12 +193,12 @@ class AscendCompilationConfig:
def __init__(self, fuse_norm_quant: bool = True, **kwargs):
"""
Initialize the configuration.

Args:
fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
When set to True, the system will optimize norm and quant operations.
Default: True

**kwargs: Additional optional parameters for forward compatibility and configuration extension.
"""
self.fuse_norm_quant = fuse_norm_quant
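Reviewer note: a minimal sketch of how the new mix_placement flag is consumed, mirroring the additional_config lookup in the hunk above; the stand-in vllm_config object below is invented for the demo.

class _FakeVllmConfig:
    # Hypothetical stand-in: only the field this diff touches.
    additional_config = {"mix_placement": True}

vllm_config = _FakeVllmConfig()
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
mix_placement = additional_config.get("mix_placement", False)  # -> True; defaults to False when the key is absent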
16 changes: 16 additions & 0 deletions vllm_ascend/ops/fused_moe/experts_selector.py
@@ -33,6 +33,8 @@ def select_experts(hidden_states: torch.Tensor,
routed_scaling_factor=1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
indices_type: Optional[torch.dtype] = None,
mix_placement: Optional[bool] = False,
num_logical_experts: int = -1,
global_num_experts: int = -1):
"""
Fused experts with select experts.
@@ -95,6 +97,20 @@ def select_experts(hidden_states: torch.Tensor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
if mix_placement:
shared_expert_routing_factor = 0.4
pad_shared_expert_ids = torch.full((topk_ids.shape[0], 1),
num_logical_experts,
dtype=topk_ids.dtype,
device=topk_ids.device)

pad_shared_expert_weights = torch.full((topk_weights.shape[0], 1),
shared_expert_routing_factor,
dtype=topk_weights.dtype,
device=topk_weights.device)
topk_ids = torch.cat([topk_ids, pad_shared_expert_ids], dim=1)
topk_weights = torch.cat([topk_weights, pad_shared_expert_weights],
dim=1)
return topk_weights, topk_ids


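Reviewer note: a standalone sketch of the mix_placement padding added above. It appends one extra "shared expert" column to topk_ids/topk_weights, reusing num_logical_experts as the shared expert's id and 0.4 as its routing weight exactly as in the diff; the tensor shapes below are invented for the demo.

import torch

num_tokens, top_k = 4, 2
num_logical_experts = 8                # id assigned to the padded shared expert
shared_expert_routing_factor = 0.4     # fixed routing weight from the diff

topk_ids = torch.randint(0, num_logical_experts, (num_tokens, top_k))
topk_weights = torch.rand(num_tokens, top_k)

pad_shared_expert_ids = torch.full((num_tokens, 1), num_logical_experts, dtype=topk_ids.dtype)
pad_shared_expert_weights = torch.full((num_tokens, 1), shared_expert_routing_factor, dtype=topk_weights.dtype)

topk_ids = torch.cat([topk_ids, pad_shared_expert_ids], dim=1)              # (num_tokens, top_k + 1)
topk_weights = torch.cat([topk_weights, pad_shared_expert_weights], dim=1)  # (num_tokens, top_k + 1)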
46 changes: 31 additions & 15 deletions vllm_ascend/ops/fused_moe/fused_moe.py
@@ -161,10 +161,10 @@ def __init__(self, *args, **kwargs):
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
ascend_config = get_ascend_config()
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
self.expert_map_path = ascend_config.expert_map_path
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.ascend_config = get_ascend_config()
self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
self.expert_map_path = self.ascend_config.expert_map_path
self.global_redundant_expert_num = self.ascend_config.init_redundancy_expert
self.global_num_experts = num_experts + self.global_redundant_expert_num
if self.custom_routing_function is None and self.e_score_correction_bias is not None:
vllm_config = get_current_vllm_config()
@@ -177,6 +177,8 @@ def __init__(self, *args, **kwargs):
# TODO: Temporary flag to indicate if static EPLB is enabled. This is a
# workaround to bypass a quantization check that fails with float weights.
init_eplb_enable = False
num_experts += 1 if getattr(self.ascend_config, "mix_placement",
False) else 0
# static eplb initializing with expert_map_path
if self.expert_map_path and os.path.exists(
self.expert_map_path) and os.access(self.expert_map_path,
@@ -243,7 +245,7 @@ def __init__(self, *args, **kwargs):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)

self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp

setup_moe_comm_method(self.moe_config)
self.quant_type = self._get_quant_type()
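Reviewer note: a toy check of the expert-count bump in the __init__ hunks above. With mix_placement enabled, one extra local slot is reserved for the shared expert after global_num_experts has already been computed; the config object and values below are invented.

class _Cfg:
    mix_placement = True
    init_redundancy_expert = 2

ascend_config = _Cfg()
num_experts = 64
global_redundant_expert_num = ascend_config.init_redundancy_expert
global_num_experts = num_experts + global_redundant_expert_num        # 66, computed before the bump
num_experts += 1 if getattr(ascend_config, "mix_placement", False) else 0
assert num_experts == 65                                              # shared expert gets its own slot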
@@ -401,8 +403,8 @@ def __init__(
self._shared_experts = shared_experts
self.use_overlapped = use_overlapped
self.shared_expert_stream = None
ascend_config = get_ascend_config()
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
self.ascend_config = get_ascend_config()
self.multistream_overlap_shared_expert = self.ascend_config.multistream_overlap_shared_expert
if enable_sp():
logger.info_once(
"Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -430,11 +432,19 @@ def forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out, fused_out = AscendFusedMoE.forward(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
if self._shared_experts is None:
fused_out = AscendFusedMoE.forward(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
shared_out = None
else:
shared_out, fused_out = AscendFusedMoE.forward(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
return shared_out, fused_out

def forward_impl(self, hidden_states: torch.Tensor,
@@ -448,7 +458,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
# Use a separate stream to run shared experts.
# Note that currently we only support calculations in separate streams with aclgraph.
# Communication operations in another stream might cause unknown errors.
shared_out = self._shared_experts(hidden_states)
if self._shared_experts is None:
shared_out = None
else:
shared_out = self._shared_experts(hidden_states)

fused_output = AscendFusedMoE.forward_impl(
self,
@@ -463,6 +476,9 @@ def forward_impl(self, hidden_states: torch.Tensor,
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
and not shared_expert_dp_enabled():
and not shared_expert_dp_enabled() and shared_out is not None:
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out, fused_output
if shared_out is None:
return fused_output
else:
return shared_out, fused_output
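Reviewer note: after this change forward_impl may return either a bare fused tensor (no shared experts) or a (shared_out, fused_out) tuple. A caller-side sketch that normalizes the two shapes; unpack_moe_output is a hypothetical helper, not part of the PR.

import torch

def unpack_moe_output(out):
    # forward_impl returns a tuple only when shared experts exist.
    if isinstance(out, tuple):
        shared_out, fused_out = out
    else:
        shared_out, fused_out = None, out
    return shared_out, fused_out

# e.g. unpack_moe_output(torch.ones(2, 4)) -> (None, tensor of ones)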
1 change: 1 addition & 0 deletions vllm_ascend/patch/__init__.py
@@ -253,3 +253,4 @@
# Future Plan:
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
#
from vllm_ascend.patch.worker import patch_deepseekv3

Check failure on line 256 in vllm_ascend/patch/__init__.py (GitHub Actions / lint / pre-commit, Ruff F401):
vllm_ascend/patch/__init__.py:256:38: F401 `vllm_ascend.patch.worker.patch_deepseekv3` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
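Reviewer note: the failure above is Ruff's unused-import check, and the import is kept only for its patching side effects. A sketch of two common fixes: an explicit noqa marker, or the `__all__` re-export that Ruff itself suggests.

# Option A: keep the side-effect import and silence F401 explicitly.
from vllm_ascend.patch.worker import patch_deepseekv3  # noqa: F401

# Option B: declare it as an intentional re-export instead.
# from vllm_ascend.patch.worker import patch_deepseekv3
# __all__ = ["patch_deepseekv3"]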