Commit c506ba6

[v0.11.0] [Bugfix] [MoE] Fix error in DeepSeek when using allgather (#3827)
### What this PR does / why we need it?
After refactoring vllm_ascend/models and FusedMoE, we can no longer pass `gate` from deepseekv2.py to `AscendFusedMoE.forward`, which results in an error when running DeepSeek V3/R1 with allgather. This PR therefore removes the `gate`-related computations from the FusedMoE module in eager/aclgraph mode.

### Does this PR introduce _any_ user-facing change?
`rm_router_logits` is deprecated in eager/aclgraph mode.

### How was this patch tested?
e2e & ut

Signed-off-by: Pr0Wh1teGivee <[email protected]>
1 parent 211d4b9 commit c506ba6
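For orientation, here is a minimal sketch of the `prepare()` interface after this change, taken from the new abstract signature in the diff below. The comment marks where the removed `gate` parameter used to sit; the recompute path survives only in the torchair code (see `vllm_ascend/torchair/utils.py` below). This is a sketch, not the full class.

```python
from abc import ABC, abstractmethod

import torch


class FusedMoEPrepareAndFinalize(ABC):
    """Sketch of the abstract prepare/finalize interface after this commit."""

    @abstractmethod
    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        enable_shared_expert_dp: bool = False,
        replace_allreduce: bool = False
        # `gate=None` used to be the final parameter here, so router_logits
        # could be recomputed after the DP gather; that path is removed from
        # the eager/aclgraph classes and kept only in torchair.
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        ...
```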

File tree

7 files changed: +98 -115 lines

tests/ut/ops/test_fused_moe_prepare_and_finalize.py

Lines changed: 2 additions & 14 deletions

@@ -191,13 +191,7 @@ def mock_all_gather_func(tensor, dim):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
 
-        # Mock the gate function for rm_router_logits=False case
-        mock_gate = MagicMock()
-        mock_gate.return_value = (router_logits.repeat(2, 1), None)
-
-        h_out, r_out, _ = layer.prepare(hidden_states,
-                                        router_logits,
-                                        gate=mock_gate)
+        h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
 
         # After all-gather with DP=2, should double the batch size
         self.assertEqual(h_out.shape[0], 12)
@@ -258,14 +252,8 @@ def mock_all_reduce(tensor):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
 
-        # Mock gate for router logits recomputation
-        mock_gate = MagicMock()
-        mock_gate.return_value = (torch.randn(7, 2), None)
-
         # Run prepare
-        h_out, r_out, _ = layer.prepare(hidden_states,
-                                        router_logits,
-                                        gate=mock_gate)
+        h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
 
         # Should be global tensor: [7, 8] and [7, 2]
         self.assertEqual(h_out.shape, (7, 8))

tests/ut/ops/test_moe_comm_method.py

Lines changed: 3 additions & 3 deletions

@@ -63,7 +63,7 @@ def test_all_gather_comm_impl(self, mock_token_dispatcher,
 
         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False, None)
+            hidden_states, router_logits, False, False)
 
         # Test finalize method
         comm_impl.finalize(h_out, reduce_results=True)
@@ -108,7 +108,7 @@ def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
 
         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False, None)
+            hidden_states, router_logits, False, False)
 
         # Test finalize method
         comm_impl.finalize(h_out, reduce_results=True)
@@ -153,7 +153,7 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher,
 
         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False, None)
+            hidden_states, router_logits, False, False)
 
     @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")

vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py

Lines changed: 49 additions & 57 deletions

@@ -26,7 +26,7 @@
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
-from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
+from vllm_ascend.utils import enable_sp
 
 
 class FusedMoEPrepareAndFinalize(ABC):
@@ -43,31 +43,26 @@ class FusedMoEPrepareAndFinalize(ABC):
 
     def __init__(self, moe_config: FusedMoEConfig):
         self.moe_config = moe_config
-        is_deepseek_v3_r1 = self.moe_config.original_num_experts == 256
-        self.rm_router_logits = get_rm_router_logits_state(
-            self.moe_config.ep_size, self.moe_config.dp_size,
-            is_deepseek_v3_r1)
 
     @abstractmethod
-    def prepare(self,
-                hidden_states: torch.Tensor,
-                router_logits: torch.Tensor,
-                enable_shared_expert_dp: bool = False,
-                replace_allreduce: bool = False,
-                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def prepare(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        enable_shared_expert_dp: bool = False,
+        replace_allreduce: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Prepare tensors before MoE computation. May involve:
         - Padding to align communication boundaries
         - Slicing across tensor-parallel ranks
         - Broadcasting across data-parallel ranks
-        - Recomputing router logits if needed
 
         Args:
            hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
            router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
            enable_shared_expert_dp (bool): Skip DP communication for shared experts
            replace_allreduce (bool): Bypass default all-reduce behavior
-            gate (nn.Module, optional): Gate network to recompute router_logits if needed
 
        Returns:
            Tuple of:
@@ -116,12 +111,13 @@ def _restore_tp_across_dp(self):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
 
-    def prepare(self,
-                hidden_states: torch.Tensor,
-                router_logits: torch.Tensor,
-                enable_shared_expert_dp: bool = False,
-                replace_allreduce: bool = False,
-                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def prepare(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        enable_shared_expert_dp: bool = False,
+        replace_allreduce: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Preparation steps:
         1. Fetch `mc2_mask` and target padding length from forward context.
@@ -214,12 +210,13 @@ def _restore_tp_across_dp(self):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
 
-    def prepare(self,
-                hidden_states: torch.Tensor,
-                router_logits: torch.Tensor,
-                enable_shared_expert_dp: bool = False,
-                replace_allreduce: bool = False,
-                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def prepare(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        enable_shared_expert_dp: bool = False,
+        replace_allreduce: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Preparation steps:
         1. Pad hidden_states and router_logits to next multiple of TP size.
@@ -307,12 +304,13 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
     TP AG → Attn → TP RS → EP AG → MoE → EP RS
     """
 
-    def prepare(self,
-                hidden_states: torch.Tensor,
-                router_logits: torch.Tensor,
-                enable_shared_expert_dp: bool = False,
-                replace_allreduce: bool = False,
-                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def prepare(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        enable_shared_expert_dp: bool = False,
+        replace_allreduce: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Preparation steps:
         AllGather hidden_states and router_logits to form global tensors.
@@ -325,7 +323,7 @@ def prepare(self,
 
         return self._prepare_with_dp_group(hidden_states, router_logits,
                                            enable_shared_expert_dp,
-                                           replace_allreduce, gate)
+                                           replace_allreduce)
 
     def _prepare_with_ep_group(
             self,
@@ -340,12 +338,12 @@ def _prepare_with_ep_group(
         return hidden_states, router_logits, None
 
     def _prepare_with_dp_group(
-            self,
-            hidden_states: torch.Tensor,
-            router_logits: torch.Tensor,
-            enable_shared_expert_dp: bool = False,
-            replace_allreduce: bool = False,
-            gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        enable_shared_expert_dp: bool = False,
+        replace_allreduce: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Preparation steps:
         1. Fetch max token count across DP group from forward context.
@@ -365,18 +363,14 @@ def _prepare_with_dp_group(
         if pad_size > 0:
             hidden_states = nn.functional.pad(hidden_states,
                                               (0, 0, 0, pad_size))
-            if not self.rm_router_logits:
-                router_logits = nn.functional.pad(router_logits,
-                                                  (0, 0, 0, pad_size))
+            router_logits = nn.functional.pad(router_logits,
+                                              (0, 0, 0, pad_size))
 
         # All-gather across DP group
         hidden_states = self.moe_config.dp_group.all_gather(
             hidden_states, 0)
-        if self.rm_router_logits:
-            router_logits, _ = gate(hidden_states)  # Recompute globally
-        else:
-            router_logits = self.moe_config.dp_group.all_gather(
-                router_logits, 0)
+        router_logits = self.moe_config.dp_group.all_gather(
+            router_logits, 0)
 
         return hidden_states, router_logits, None
 
@@ -472,12 +466,13 @@ def _naive_multicast(self, x: torch.Tensor,
             get_dp_group().broadcast(buffer[start:end, :], idx)
         return buffer
 
-    def prepare(self,
-                hidden_states: torch.Tensor,
-                router_logits: torch.Tensor,
-                enable_shared_expert_dp: bool = False,
-                replace_allreduce: bool = False,
-                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def prepare(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        enable_shared_expert_dp: bool = False,
+        replace_allreduce: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Preparation steps:
         1. Fetch cumulative token boundaries from forward context.
@@ -493,11 +488,8 @@ def prepare(self,
         ).dp_metadata.cu_tokens_across_sp(1)
         hidden_states = self._naive_multicast(hidden_states,
                                               self.cu_tokens_across_dp_cpu)
-        if self.rm_router_logits:
-            router_logits, _ = gate(hidden_states)
-        else:
-            router_logits = self._naive_multicast(
-                router_logits, self.cu_tokens_across_dp_cpu)
+        router_logits = self._naive_multicast(router_logits,
+                                              self.cu_tokens_across_dp_cpu)
 
         return hidden_states, router_logits, None
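The behavioral effect of the `_prepare_with_dp_group` change can be checked with a standalone sketch: both `hidden_states` and `router_logits` are padded to the DP-wide max token count and all-gathered, instead of recomputing the logits from `gate` on the gathered hidden states. `dp_all_gather` and `pad_tokens` below are stand-ins for `moe_config.dp_group.all_gather` and the padding call, not vllm-ascend APIs; shapes are illustrative.

```python
import torch
import torch.nn as nn


def dp_all_gather(per_rank: list[torch.Tensor]) -> torch.Tensor:
    # Stand-in for moe_config.dp_group.all_gather(x, 0):
    # concatenate per-rank tensors along dim 0.
    return torch.cat(per_rank, dim=0)


def pad_tokens(x: torch.Tensor, target: int) -> torch.Tensor:
    # Pad dim 0 up to `target`, mirroring nn.functional.pad(x, (0, 0, 0, pad)).
    pad = target - x.shape[0]
    return nn.functional.pad(x, (0, 0, 0, pad)) if pad > 0 else x


# Two DP ranks with different token counts, padded to the DP-wide max (4).
hidden = [torch.randn(3, 8), torch.randn(4, 8)]
logits = [torch.randn(3, 2), torch.randn(4, 2)]
max_tokens = 4

# New behavior: hidden_states AND router_logits are padded and all-gathered.
# (Previously, with rm_router_logits=True, router_logits were instead
# recomputed as gate(gathered_hidden_states).)
hidden_global = dp_all_gather([pad_tokens(h, max_tokens) for h in hidden])
logits_global = dp_all_gather([pad_tokens(r, max_tokens) for r in logits])

assert hidden_global.shape == (8, 8)  # 2 ranks * 4 padded tokens
assert logits_global.shape == (8, 2)
```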

vllm_ascend/ops/moe/moe_comm_method.py

Lines changed: 8 additions & 7 deletions

@@ -63,15 +63,16 @@ def __init__(self, moe_config: FusedMoEConfig):
         self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
         )
 
-    def prepare(self,
-                hidden_states: torch.Tensor,
-                router_logits: torch.Tensor,
-                enable_shared_expert_dp: bool = False,
-                replace_allreduce: bool = False,
-                gate=None) -> tuple[torch.Tensor, torch.Tensor]:
+    def prepare(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        enable_shared_expert_dp: bool = False,
+        replace_allreduce: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
             hidden_states, router_logits, enable_shared_expert_dp,
-            replace_allreduce, gate)
+            replace_allreduce)
         self.mc2_mask = mc2_mask
         return hidden_states, router_logits
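A minimal standalone illustration of the call contract that the unit tests above now assert: the comm-method wrapper forwards exactly four positional arguments to the prepare/finalize object, with no trailing `gate`. `CommMethodSketch` is a stand-in, not the actual `MoECommMethod` class.

```python
from unittest.mock import MagicMock

import torch


class CommMethodSketch:
    """Stand-in for a MoECommMethod-style wrapper delegating to prepare/finalize."""

    def __init__(self, prepare_finalize):
        self.prepare_finalize = prepare_finalize
        self.mc2_mask = None

    def prepare(self, hidden_states, router_logits,
                enable_shared_expert_dp=False, replace_allreduce=False):
        hidden_states, router_logits, mc2_mask = self.prepare_finalize.prepare(
            hidden_states, router_logits, enable_shared_expert_dp,
            replace_allreduce)
        self.mc2_mask = mc2_mask
        return hidden_states, router_logits


pf = MagicMock()
h, r = torch.randn(3, 8), torch.randn(3, 2)
pf.prepare.return_value = (h, r, None)

comm = CommMethodSketch(pf)
comm.prepare(h, r)
# The delegate now receives exactly four positional arguments (no gate).
pf.prepare.assert_called_once_with(h, r, False, False)
```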

vllm_ascend/torchair/ops/torchair_fused_moe.py

Lines changed: 4 additions & 4 deletions

@@ -48,12 +48,12 @@
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
-from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+from vllm_ascend.torchair.utils import (get_all_reduce_merge_state,
+                                        get_rm_router_logits_state,
+                                        npu_stream_switch, npu_wait_tensor,
                                         super_kernel)
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
-                               get_all_reduce_merge_state,
-                               get_ascend_soc_version,
-                               get_rm_router_logits_state, is_310p,
+                               get_ascend_soc_version, is_310p,
                                is_hierarchical_communication_enabled)

vllm_ascend/torchair/utils.py

Lines changed: 32 additions & 0 deletions

@@ -15,6 +15,8 @@
 except ImportError:
     from torchair.ops import NpuStreamSwitch as _npu_stream_switch
     from torchair.ops import npu_wait_tensor as _npu_wait_tensor
+
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
 
 KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
@@ -241,3 +243,33 @@ def torchair_ops_patch():
 
 def super_kernel(prefix: str, option: str, enabled: bool = True):
     return _super_kernel(prefix, option) if enabled else nullcontext()
+
+
+# TODO(ttanzhiqiang): rm_router_logits
+# dp>1 will trigger
+# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
+def get_rm_router_logits_state(ep_size: int, dp_size: int,
+                               is_deepseek_v3_r1: bool):
+    # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
+    # only supports deepseek v3/r1
+    if dp_size > 1:
+        if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+                and is_deepseek_v3_r1):
+            return True
+        elif ep_size == 1 and is_deepseek_v3_r1:
+            return True
+    return False
+
+
+# TODO(ttanzhiqiang): all_reduce merge
+# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
+# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
+def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
+    # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
+    # only supports deepseek v3/r1
+    if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
+            and is_deepseek_v3_r1):
+        return True
+    elif ep_size == 1 and is_deepseek_v3_r1:
+        return True
+    return False
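The decision logic that moves into `torchair/utils.py` can be sanity-checked in isolation. This sketch copies the `dp_size`/`ep_size` branching from the diff above, with `envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP` modeled as a plain boolean (assumed off here); it is not an import of the real module.

```python
# Local copy of the branching from the torchair helper above, for illustration.
ALLGATHER_EP_ENABLED = False  # assumption: fused-experts allgather-EP switch is off


def get_rm_router_logits_state(ep_size: int, dp_size: int,
                               is_deepseek_v3_r1: bool) -> bool:
    if dp_size > 1:
        if ALLGATHER_EP_ENABLED and ep_size > 1 and is_deepseek_v3_r1:
            return True
        elif ep_size == 1 and is_deepseek_v3_r1:
            return True
    return False


# DeepSeek V3/R1 with dp_size > 1 and ep_size == 1 keeps the torchair-only
# rm_router_logits path on; the other combinations below stay off.
assert get_rm_router_logits_state(ep_size=1, dp_size=2, is_deepseek_v3_r1=True)
assert not get_rm_router_logits_state(ep_size=1, dp_size=2, is_deepseek_v3_r1=False)
assert not get_rm_router_logits_state(ep_size=1, dp_size=1, is_deepseek_v3_r1=True)
```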

vllm_ascend/utils.py

Lines changed: 0 additions & 30 deletions

@@ -520,36 +520,6 @@ def pop_captured_sync(self) -> dict:
         return durations
 
 
-# TODO(ttanzhiqiang): rm_router_logits
-# dp>1 will trigger
-# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
-def get_rm_router_logits_state(ep_size: int, dp_size: int,
-                               is_deepseek_v3_r1: bool):
-    # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
-    # only supports deepseek v3/r1
-    if dp_size > 1:
-        if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
-                and is_deepseek_v3_r1):
-            return True
-        elif ep_size == 1 and is_deepseek_v3_r1:
-            return True
-    return False
-
-
-# TODO(ttanzhiqiang): all_reduce merge
-# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
-# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
-def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
-    # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
-    # only supports deepseek v3/r1
-    if (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
-            and is_deepseek_v3_r1):
-        return True
-    elif ep_size == 1 and is_deepseek_v3_r1:
-        return True
-    return False
-
-
 def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
     """Register Ascend CustomOP