Commit 60c5bb3

top9
2 parents d628411 + 29cbae3 commit 60c5bb3

File tree

6 files changed: +368 -18 lines changed

vllm_ascend/ascend_config.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ class AscendConfig:
 
     def __init__(self, vllm_config):
         additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
+        self.mix_placement = additional_config.get("mix_placement",False)
         torchair_graph_config = additional_config.get("torchair_graph_config",
                                                       {})
 
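The new flag lands in AscendConfig via vLLM's additional_config, so it would presumably be switched on from the engine arguments. A minimal usage sketch, assuming the usual vllm-ascend pattern of passing plugin options through additional_config (the model name is only a placeholder, not part of this commit):

# Usage sketch (assumption, not from this commit): enable the new flag
# through vLLM's additional_config; the model name is a placeholder.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",            # placeholder model
    additional_config={"mix_placement": True},  # read by AscendConfig.__init__ above
)
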
vllm_ascend/ops/fused_moe/experts_selector.py

Lines changed: 15 additions & 0 deletions
@@ -33,6 +33,8 @@ def select_experts(hidden_states: torch.Tensor,
                    routed_scaling_factor=1.0,
                    e_score_correction_bias: Optional[torch.Tensor] = None,
                    indices_type: Optional[torch.dtype] = None,
+                   mix_placement: Optional[bool] = False,
+                   num_logical_experts: int = -1,
                    global_num_experts: int = -1):
     """
     Fused experts with select experts.
@@ -95,6 +97,19 @@ def select_experts(hidden_states: torch.Tensor,
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts,
         )
+    if mix_placement:
+        pad_shared_expert_ids = torch.full((topk_ids.shape[0], 1),
+                                           num_logical_experts,
+                                           dtype=topk_ids.dtype,
+                                           device=topk_ids.device)
+
+        pad_shared_expert_weights = torch.full((topk_weights.shape[0], 1),
+                                               0.4,
+                                               dtype=topk_weights.dtype,
+                                               device=topk_weights.device)
+        topk_ids = torch.cat([topk_ids, pad_shared_expert_ids], dim=1)
+        topk_weights = torch.cat([topk_weights, pad_shared_expert_weights],
+                                 dim=1)
     return topk_weights, topk_ids
 
 
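To see what the mix_placement branch above does in isolation, here is a small self-contained sketch on dummy tensors. The token count, top-k, and expert count are made up; the fixed 0.4 pad weight and the num_logical_experts pad id mirror the diff:

# Self-contained sketch of the padding step: append one extra "shared expert"
# column to every token's top-k ids/weights. All tensor values are dummies.
import torch

num_tokens, top_k = 4, 8
num_logical_experts = 256   # id reserved for the padded shared-expert slot

topk_ids = torch.randint(0, num_logical_experts, (num_tokens, top_k), dtype=torch.int32)
topk_weights = torch.rand(num_tokens, top_k)

pad_ids = torch.full((num_tokens, 1), num_logical_experts, dtype=topk_ids.dtype)
pad_weights = torch.full((num_tokens, 1), 0.4, dtype=topk_weights.dtype)

topk_ids = torch.cat([topk_ids, pad_ids], dim=1)              # shape becomes (4, 9)
topk_weights = torch.cat([topk_weights, pad_weights], dim=1)  # shape becomes (4, 9)
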
vllm_ascend/ops/fused_moe/fused_moe.py

Lines changed: 32 additions & 18 deletions
@@ -172,10 +172,10 @@ def __init__(self, *args, **kwargs):
         self.moe_config.dp_group = get_dp_group()
         self.moe_config.ep_group = get_ep_group()
         self.moe_config.mc2_group = get_mc2_group()
-        ascend_config = get_ascend_config()
-        self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
-        self.expert_map_path = ascend_config.expert_map_path
-        self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+        self.ascend_config = get_ascend_config()
+        self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
+        self.expert_map_path = self.ascend_config.expert_map_path
+        self.global_redundant_expert_num = self.ascend_config.init_redundancy_expert
         self.global_num_experts = num_experts + self.global_redundant_expert_num
         if self.custom_routing_function is None and self.e_score_correction_bias is not None:
             vllm_config = get_current_vllm_config()
@@ -195,8 +195,8 @@ def __init__(self, *args, **kwargs):
             self.expert_load_balancer = ExpertLoadBalancer(
                 self.expert_map_path, num_experts)
             self.expert_load_balancer.check_expert_map_tensor()
-            self.global_redundant_expert_num = (
-                self.expert_load_balancer.get_global_redundant_expert_num())
+            # self.global_redundant_expert_num = (
+            #     self.expert_load_balancer.get_global_redundant_expert_num())
             self.global_num_experts = num_experts + self.global_redundant_expert_num
             try:
                 self.local_num_experts, self.expert_map = (
@@ -254,7 +254,7 @@ def __init__(self, *args, **kwargs):
             moe_quant_params["intermediate_size_full"] = intermediate_size
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
-        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
+        self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
 
         setup_moe_comm_method(self.moe_config)
         self.quant_type = self._get_quant_type()
@@ -460,8 +460,8 @@ def __init__(
         self._shared_experts = shared_experts
         self.use_overlapped = use_overlapped
         self.shared_expert_stream = None
-        ascend_config = get_ascend_config()
-        self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
+        self.ascend_config = get_ascend_config()
+        self.multistream_overlap_shared_expert = self.ascend_config.multistream_overlap_shared_expert
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -489,11 +489,19 @@ def forward(
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        shared_out, fused_out = AscendFusedMoE.forward(
-            self,
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-        )
+        if self._shared_experts is None:
+            fused_out = AscendFusedMoE.forward(
+                self,
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+            shared_out = None
+        else:
+            shared_out, fused_out = AscendFusedMoE.forward(
+                self,
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
         return shared_out, fused_out
 
     def forward_impl(self, hidden_states: torch.Tensor,
@@ -507,7 +515,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
             # Use a separate stream to run shared experts.
            # Note that currently we only support calculations in separate streams with aclgraph.
            # Communication operations in another stream might cause unknown errors.
-            shared_out = self._shared_experts(hidden_states)
+            if self._shared_experts is None:
+                shared_out = None
+            else:
+                shared_out = self._shared_experts(hidden_states)
 
         fused_output = AscendFusedMoE.forward_impl(
             self,
@@ -521,7 +532,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
-                and not shared_expert_dp_enabled():
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                and not shared_expert_dp_enabled() and shared_out is not None:
             shared_out = tensor_model_parallel_all_reduce(shared_out)
-        return shared_out, fused_output
+        if shared_out is None:
+            return fused_output
+        else:
+            return shared_out, fused_output
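Note that after the last hunk, forward_impl returns either a (shared_out, fused_output) tuple or a bare fused_output tensor when no shared experts are configured, while forward keeps the tuple shape but with shared_out set to None. A call-site sketch of the resulting contract; the variable names and the way the two outputs are combined are placeholders, not code from this commit:

# Hypothetical caller handling both return shapes of forward_impl;
# moe_layer, hidden_states, and router_logits are placeholders.
out = moe_layer.forward_impl(hidden_states, router_logits)
if isinstance(out, tuple):
    shared_out, fused_out = out     # shared experts present
    final = shared_out + fused_out  # placeholder combination, model-specific in practice
else:
    final = out                     # no shared experts: routed-expert output only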

vllm_ascend/ops/fused_moe/moe_mlp.py

Lines changed: 1 addition & 0 deletions
@@ -299,3 +299,4 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
                                   group_list_type=group_list_type,
                                   topk_scales=topk_scales,
                                   need_trans=need_trans)
+

vllm_ascend/patch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -138,3 +138,4 @@
 # Future Plan:
 # Remove this patch when adapted vllm version contains the above PR.
 #
+from vllm_ascend.patch.worker import patch_deepseekv3
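For context, the vllm_ascend.patch modules apply their changes as a side effect of being imported, which is why the package __init__ only needs the bare import above. The general shape of such a module, as a rough sketch; the TargetModel class below is a stand-in defined for the example, not real vLLM or patch_deepseekv3 code:

# Illustrative import-time monkey-patch pattern; TargetModel is a stand-in
# defined here so the sketch runs on its own.


class TargetModel:
    def forward(self, x):
        return x


_original_forward = TargetModel.forward


def _patched_forward(self, x):
    # plugin-specific adjustment would go here, then delegate to the original
    return _original_forward(self, x)


# Applied when this module is imported.
TargetModel.forward = _patched_forward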
