
Commit 847d12a

[BugFix] Fix moe load problems in torchair when using dynamic eplb (#3381)

### What this PR does / why we need it?
When using dynamic EPLB, the MoE load statistics are not updated. We fix this by modifying the return value of the hidden states in torchair.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
DeepseekV3 on A3.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: daishixun <[email protected]>
1 parent cd69385 commit 847d12a
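For readers skimming the diffs below: with dynamic EPLB enabled, the quantized MC2 path now returns a 3-tuple (hidden_states, group_list_type, expert_token_nums), and the torchair fused-MoE forward unpacks it to accumulate per-expert load. The following is a minimal, self-contained sketch of that producer/consumer contract, not the actual vllm-ascend API; function names and the standalone helpers are hypothetical stand-ins, only the tuple shapes and the accumulation logic are taken from the patch.

    from typing import Tuple, Union
    import torch

    def mc2_experts_sketch(hidden_states: torch.Tensor,
                           expert_token_nums: torch.Tensor,
                           dynamic_eplb: bool = False
                           ) -> Union[torch.Tensor, Tuple]:
        # Producer side (sketch): with dynamic EPLB the expert output is
        # returned together with load statistics, mirroring
        # `return (hidden_states, 1, expert_token_nums)` in the patch.
        if dynamic_eplb:
            return (hidden_states, 1, expert_token_nums)
        return hidden_states

    def update_moe_load_sketch(result, moe_load: torch.Tensor) -> torch.Tensor:
        # Consumer side (sketch): a 3-tuple carries (hidden, group_list_type,
        # expert_tokens); group_list_type != 0 means the counts are already
        # per-expert, otherwise they are cumulative and need differencing.
        if isinstance(result, tuple) and len(result) == 3:
            _, group_list_type, expert_tokens = result
            moe_load += expert_tokens if group_list_type else \
                torch.cat([expert_tokens[:1],
                           expert_tokens[1:] - expert_tokens[:-1]])
        return moe_load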

File tree

2 files changed: +12 -4 lines changed

vllm_ascend/torchair/ops/torchair_fused_moe.py
vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py

vllm_ascend/torchair/ops/torchair_fused_moe.py

Lines changed: 5 additions & 3 deletions
@@ -1279,13 +1279,15 @@ def forward(self,
         )
 
         if shared_experts:
-            if isinstance(e_hidden_states, tuple):
+            if isinstance(e_hidden_states,
+                          tuple) and len(e_hidden_states) == 2:
                 e_hidden_states, shared_hidden_states = e_hidden_states
 
         if self.dynamic_eplb and isinstance(
                 e_hidden_states, tuple) and len(e_hidden_states) == 3:
-            self.moe_load += e_hidden_states[2] if e_hidden_states[1] == 0 else \
-                torch.cat(e_hidden_states[2][:1], e_hidden_states[2][1:] - e_hidden_states[2][:-1])
+            e_hidden_states, group_list_type, expert_tokens = e_hidden_states
+            self.moe_load += expert_tokens if group_list_type else \
+                torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
 
         if (fused_moe_state not in [
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
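The key line above converts a cumulative per-expert token count into per-expert loads when group_list_type is 0, and uses the counts directly otherwise; the added `len(e_hidden_states) == 2` guard keeps the shared-experts unpacking from consuming the 3-tuple meant for this EPLB branch. Below is a small standalone sketch of just that conversion with a worked example; the helper name is hypothetical.

    import torch

    def per_expert_load(expert_tokens: torch.Tensor,
                        group_list_type: int) -> torch.Tensor:
        # group_list_type != 0: counts are already per-expert.
        if group_list_type:
            return expert_tokens
        # group_list_type == 0: counts are cumulative, so first-order
        # differences recover the per-expert token counts.
        return torch.cat([expert_tokens[:1],
                          expert_tokens[1:] - expert_tokens[:-1]])

    # Cumulative counts [3, 5, 9] -> per-expert loads [3, 2, 4].
    print(per_expert_load(torch.tensor([3, 5, 9]), group_list_type=0))
    # Per-expert counts are passed through unchanged.
    print(per_expert_load(torch.tensor([3, 2, 4]), group_list_type=1))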

vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py

Lines changed: 7 additions & 1 deletion
@@ -220,6 +220,7 @@ def torchair_fused_experts_with_mc2(
         shared_dequant_scale: Optional[Any] = None,
         w1_scale_bias: torch.Tensor = None,
         w2_scale_bias: torch.Tensor = None,
+        dynamic_eplb: bool = False,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     assert mc2_mask is not None
     if log2phy is not None:
@@ -354,6 +355,9 @@ def torchair_fused_experts_with_mc2(
     ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
         **kwargs_mc2)
 
+    if dynamic_eplb:
+        return (hidden_states, 1, expert_token_nums)
+
     if shared_experts is None:
         return hidden_states
     else:
@@ -832,6 +836,7 @@ def __init__(self):
         self.ep_group = get_ep_group()
 
         ascend_config = get_ascend_config()
+        self.dynamic_eplb = ascend_config.dynamic_eplb
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
@@ -994,7 +999,8 @@ def apply(
                 is_torchair=self.torchair_graph_enabled,
                 mc2_mask=kwargs.get("mc2_mask", None),
                 shared_gate_up=shared_gate_up,
-                shared_dequant_scale=shared_dequant_scale)
+                shared_dequant_scale=shared_dequant_scale,
+                dynamic_eplb=self.dynamic_eplb)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
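This second file threads a dynamic_eplb flag from the Ascend config into the quantization method and on into the MC2 wrapper, which then returns the extra load statistics. A minimal sketch of that plumbing, using hypothetical stubs in place of get_ascend_config() and torchair_fused_experts_with_mc2; only the flag forwarding and the 3-tuple return mirror the patch.

    from dataclasses import dataclass
    import torch

    @dataclass
    class AscendConfigStub:
        # Hypothetical stand-in for the object returned by get_ascend_config().
        dynamic_eplb: bool = False

    def fused_experts_with_mc2_stub(hidden_states: torch.Tensor,
                                    expert_token_nums: torch.Tensor,
                                    dynamic_eplb: bool = False):
        # Mirrors the patched early return: with dynamic EPLB the wrapper
        # hands back load statistics, with group_list_type fixed to 1 here.
        if dynamic_eplb:
            return (hidden_states, 1, expert_token_nums)
        return hidden_states

    class W8A8QuantMethodStub:
        def __init__(self, ascend_config: AscendConfigStub):
            # Mirrors `self.dynamic_eplb = ascend_config.dynamic_eplb`.
            self.dynamic_eplb = ascend_config.dynamic_eplb

        def apply(self, hidden_states: torch.Tensor,
                  expert_token_nums: torch.Tensor):
            # Mirrors forwarding the flag as `dynamic_eplb=self.dynamic_eplb`.
            return fused_experts_with_mc2_stub(hidden_states,
                                               expert_token_nums,
                                               dynamic_eplb=self.dynamic_eplb)

    # Usage: with the flag enabled, callers receive the load statistics too.
    method = W8A8QuantMethodStub(AscendConfigStub(dynamic_eplb=True))
    out, group_list_type, expert_tokens = method.apply(
        torch.randn(4, 8), torch.tensor([1, 3]))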
