
Commit 94dd832

[MoE] [Refactor] Combine common_fused_moe and fused_moe (#3176)
### What this PR does / why we need it?
1. Move the additional functionality from fused_moe.py into common_fused_moe.py and remove fused_moe.py.
2. Remove unnecessary custom classes from qwen3_moe.py; the file will be removed entirely after vllm-ascend v0.11.0 is released.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tested with Qwen3-30B-A3B / Qwen3-30B-A3B-W8A8 / DeepSeek-V3-W4A8-Pruing / deepseek-mtp / pangu-pro-moe-pruing:
1. Enable/Disable EP
2. Aclgraph & eager
3. SP

- vLLM version: v0.11.0

Signed-off-by: Pr0Wh1teGivee <[email protected]>
Co-authored-by: weijinqian0 <[email protected]>
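For downstream code the practical effect is an import-path change: everything that previously came from vllm_ascend.ops.fused_moe now lives in vllm_ascend.ops.common_fused_moe. A minimal migration sketch, using only names that appear in this diff:

```python
# Before this refactor (vllm_ascend.ops.fused_moe is deleted by this PR):
# from vllm_ascend.ops.fused_moe import AscendFusedMoE, AscendUnquantizedFusedMoEMethod

# After this refactor:
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
                                              AscendUnquantizedFusedMoEMethod)
```

Tests that patch helpers by module path follow the same rename, e.g. vllm_ascend.ops.fused_moe.get_forward_context becomes vllm_ascend.ops.common_fused_moe.get_forward_context.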
1 parent a36e3da · commit 94dd832

17 files changed: +184 / -1119 lines

tests/ut/models/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def mock_distributed():
         patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \
         patch("vllm_ascend.models.deepseek_v2.get_pp_group",
               return_value=Mock(is_first_rank=False, is_last_rank=False)), \
-        patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
+        patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
         patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
         patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
         patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,

tests/ut/models/test_qwen3_moe.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

tests/ut/ops/test_fused_moe_prepare_and_finalize.py

Lines changed: 1 addition & 2 deletions
@@ -21,6 +21,7 @@ def setUp(self):
         self.moe_config.tp_size = 1
         self.moe_config.ep_size = 1
         self.moe_config.dp_group = MagicMock()
+        self.moe_config.original_num_experts = 8

     @patch(
         "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
@@ -196,7 +197,6 @@ def mock_all_gather_func(tensor, dim):

         h_out, r_out, _ = layer.prepare(hidden_states,
                                         router_logits,
-                                        rm_router_logits=False,
                                         gate=mock_gate)

         # After all-gather with DP=2, should double the batch size
@@ -265,7 +265,6 @@ def mock_all_reduce(tensor):
         # Run prepare
         h_out, r_out, _ = layer.prepare(hidden_states,
                                         router_logits,
-                                        rm_router_logits=False,
                                         gate=mock_gate)

         # Should be global tensor: [7, 8] and [7, 2]
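Condensed from the setUp() hunk above, the MoE config mock now needs the new original_num_experts attribute; a minimal sketch (the value 8 and the attribute names come from this diff, the MagicMock scaffolding mirrors the existing test):

```python
from unittest.mock import MagicMock

# Sketch of the updated setUp(): the prepare/finalize code under test now
# appears to read original_num_experts from the MoE config, so the mock
# must provide it alongside the existing parallelism fields.
moe_config = MagicMock()
moe_config.tp_size = 1
moe_config.ep_size = 1
moe_config.dp_group = MagicMock()
moe_config.original_num_experts = 8
```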

tests/ut/ops/test_fused_ops.py

Lines changed: 9 additions & 200 deletions
@@ -24,8 +24,7 @@

 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import MoECommType
-from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
-                                       AscendUnquantizedFusedMoEMethod)
+from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
 from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
 from vllm_ascend.utils import AscendSocVersion, adapt_patch
@@ -70,7 +69,7 @@ def setup_vllm_config_mock(mocker: MockerFixture):
     mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
     mock_vllm_config.model_config.max_model_len = 2048

-    mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
+    mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
                  return_value=mock_vllm_config)
     mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
                  return_value=mock_vllm_config)
@@ -104,24 +103,24 @@ def mock_finalize(hidden_states, **kwargs):

     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
-        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
         patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
               return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.fused_moe.get_ascend_config',
+        patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
               return_value=MagicMock(
                   torchair_graph_config=MagicMock(enabled=False),
                   enable_multistream_moe=False,
                   expert_map_path=None
               )), \
-        patch('vllm_ascend.ops.fused_moe.determine_expert_map',
+        patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
              return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
-        patch('vllm_ascend.ops.fused_moe.get_forward_context',
+        patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
              return_value=mock_forward_context_obj), \
         patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
              return_value=mock_forward_context_obj), \
@@ -252,196 +251,6 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module):
         pass


-class TestAscendFusedMoe:
-
-    def test_init_no_quant(self, mock_dist_env, default_moe_config):
-        layer = AscendFusedMoE(**default_moe_config)
-
-        layer.w13_weight = nn.Parameter(
-            torch.randn(default_moe_config['num_experts'],
-                        default_moe_config['intermediate_size'] * 2,
-                        default_moe_config['hidden_size']))
-        layer.w2_weight = nn.Parameter(
-            torch.randn(default_moe_config['num_experts'],
-                        default_moe_config['hidden_size'],
-                        default_moe_config['intermediate_size']))
-
-        assert layer.num_experts == default_moe_config['num_experts']
-        assert layer.top_k == default_moe_config['top_k']
-        assert hasattr(layer, 'w13_weight')
-        assert hasattr(layer, 'w2_weight')
-
-        with pytest.raises(AssertionError):
-            error_config = default_moe_config.copy()
-            error_config['use_grouped_topk'] = True
-            layer = AscendFusedMoE(**error_config)
-
-        with pytest.raises(ValueError):
-            error_config = default_moe_config.copy()
-            error_config['scoring_func'] = "random"
-            layer = AscendFusedMoE(**error_config)
-
-    def test_init_with_quant(self, mock_dist_env, default_moe_config):
-        mock_quant_config = MagicMock()
-        mock_quant_method = MockFusedMoEMethod()
-        mock_quant_config.get_quant_method.return_value = mock_quant_method
-
-        moe = AscendFusedMoE(**default_moe_config,
-                             quant_config=mock_quant_config)
-
-        assert moe.quant_method is not None
-        assert moe.quant_method == mock_quant_method
-
-    @pytest.mark.parametrize(
-        "others_param",
-        [[None,
-          MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
-         [2, None, False, 5, None], [None, None, True, 5, None],
-         [None, None, False, 1, None], [None, None, True, 5, 1],
-         [None, None, False, 5, 1]])
-    def test_forward(self, mock_dist_env, default_moe_config, others_param):
-
-        top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
-        inputs = torch.randn(num_tokens, 32)
-        router_logits = torch.randn(num_tokens, 8)
-        moe = AscendFusedMoE(**default_moe_config)
-
-        if ep_size == 1:
-            moe.moe_parallel_config.ep_size = 1
-
-        moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
-        forward_context = mock_dist_env['mock_forward_context_obj']
-        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
-                   return_value=forward_context):
-            output = moe.forward(inputs,
-                                 router_logits,
-                                 is_prefill=is_prefill,
-                                 top_k=top_k,
-                                 shared_experts=shared_experts)
-
-        moe.quant_method.apply.assert_called_once()
-
-        if shared_experts:
-            assert output[0].shape == (num_tokens, 32)
-            assert output[1].shape == (num_tokens, 10)
-        else:
-            assert output.shape == (num_tokens, 32)
-
-    def test_forward_ms_fused_moe_comp(self, mock_dist_env,
-                                       default_moe_config):
-        inputs = torch.randn(5, 32)
-        router_logits = torch.randn(5, 8)
-        moe = AscendFusedMoE(**default_moe_config)
-
-        moe.quant_method = MockQuantMethod(None, 5)
-        output = moe._forward_ms_fused_moe_comp(inputs,
-                                                router_logits,
-                                                is_prefill=False,
-                                                real_top_k=1)
-
-        moe.quant_method.apply.assert_called_once()
-
-        assert output.shape == (5, 32)
-
-
-class TestAscendUnquantizedFusedMoEMethod:
-
-    def test_process_weights_after_loading(self, moe_method, mock_dist_env):
-        layer = MagicMock()
-        layer.w13_weight.data = torch.randn(16, 32)
-        layer.w2_weight.data = torch.randn(16, 32)
-
-        with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \
-             patch('vllm_ascend.utils.is_310p', return_value=False):
-            moe_method.process_weights_after_loading(layer)
-
-        assert isinstance(layer.w13_weight, torch.nn.Parameter)
-        assert isinstance(layer.w2_weight, torch.nn.Parameter)
-        assert not layer.w13_weight.requires_grad
-        assert not layer.w2_weight.requires_grad
-
-    @pytest.mark.parametrize("others_param",
-                             [[256, 4], [128, 1], [128, 1], [128, 4]])
-    def test_apply_without_expert_map(self, moe_method, mock_dist_env,
-                                      mock_moe_env, others_param):
-        global_num_experts, ep_size = others_param
-        is_prefill = False
-
-        forward_context = mock_dist_env['mock_forward_context_obj']
-
-        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
-                   return_value=forward_context):
-            moe_method.ep_size = ep_size
-            x = torch.randn(8, 2, 2)
-            router_logits = torch.randn(8, 8)
-            layer = MagicMock()
-            local_num_experts = 2
-            hidden_size = 2
-            intermediate_size_per_partition = 4
-
-            layer.w13_weight = torch.randn(local_num_experts,
-                                           intermediate_size_per_partition * 2,
-                                           hidden_size)
-            layer.w2_weight = torch.randn(local_num_experts, hidden_size,
-                                          intermediate_size_per_partition)
-
-            result = moe_method.apply(layer=layer,
-                                      x=x,
-                                      router_logits=router_logits,
-                                      top_k=2,
-                                      renormalize=True,
-                                      global_num_experts=global_num_experts,
-                                      is_prefill=is_prefill)
-
-        mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
-        mock_moe_comm_method.fused_experts.assert_called_once()
-
-        expected_shape = (16, 2)
-        assert result.shape == expected_shape
-
-    @pytest.mark.parametrize("others_param", [16, 1, 4])
-    def test_apply_with_expert_map(self, moe_method, mock_dist_env,
-                                   mock_moe_env, others_param):
-        ep_size = others_param
-        is_prefill = False
-
-        forward_context = mock_dist_env['mock_forward_context_obj']
-
-        with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
-             patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
-            expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
-            moe_method.ep_size = ep_size
-            x = torch.randn(8, 2, 2)
-            if ep_size == 1:
-                x = x.view(-1, 2)
-            router_logits = torch.randn(8, 8)
-            layer = MagicMock()
-
-            local_num_experts = 2
-            hidden_size = 2
-            intermediate_size_per_partition = 4
-            layer.w13_weight = torch.randn(local_num_experts,
-                                           intermediate_size_per_partition * 2,
-                                           hidden_size)
-            layer.w2_weight = torch.randn(local_num_experts, hidden_size,
-                                          intermediate_size_per_partition)
-
-            result = moe_method.apply(layer=layer,
-                                      x=x,
-                                      router_logits=router_logits,
-                                      top_k=2,
-                                      renormalize=True,
-                                      global_num_experts=128,
-                                      expert_map=expert_map,
-                                      is_prefill=is_prefill)
-
-        mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
-        mock_moe_comm_method.fused_experts.assert_called_once()
-
-        expected_shape = (16, 2)
-        assert result.shape == expected_shape
-
-
 class TestExpertsSelector:

     @pytest.mark.parametrize("global_num_experts", [[256], [128]])
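The nine surviving additions in this file all make the same move: mock patch targets switch from vllm_ascend.ops.fused_moe to vllm_ascend.ops.common_fused_moe. A condensed sketch of one affected fixture, with the patch targets and mocked attributes copied from the hunks above (the bare-function form here is illustrative; the real test wraps this in a pytest fixture):

```python
from unittest.mock import MagicMock

from pytest_mock import MockerFixture


def setup_vllm_config_mock(mocker: MockerFixture):
    # Stub the get_current_vllm_config lookup used by the fused-MoE code under
    # test; after the refactor that lookup lives in common_fused_moe.
    mock_vllm_config = MagicMock()
    mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
    mock_vllm_config.model_config.max_model_len = 2048

    mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
                 return_value=mock_vllm_config)
    mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
                 return_value=mock_vllm_config)
```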

tests/ut/ops/test_moe_comm_method.py

Lines changed: 3 additions & 3 deletions
@@ -63,7 +63,7 @@ def test_all_gather_comm_impl(self, mock_token_dispatcher,

         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False, False, None)
+            hidden_states, router_logits, False, False, None)

         # Test finalize method
         comm_impl.finalize(h_out, reduce_results=True)
@@ -108,7 +108,7 @@ def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,

         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False, False, None)
+            hidden_states, router_logits, False, False, None)

         # Test finalize method
         comm_impl.finalize(h_out, reduce_results=True)
@@ -153,7 +153,7 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher,

         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False, False, None)
+            hidden_states, router_logits, False, False, None)

     @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")

vllm_ascend/models/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -45,10 +45,6 @@ def register_model():
         "DeepSeekMTPModel",
         "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")

-    ModelRegistry.register_model(
-        "Qwen3MoeForCausalLM",
-        "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
-
     # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
     # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
     ModelRegistry.register_model(
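With the Qwen3MoeForCausalLM entry gone, the Ascend plugin no longer overrides that architecture, so vLLM's in-tree implementation serves it; the remaining registrations keep the pattern visible in the context lines. A short sketch of that pattern (the live call is copied from the surviving code, the commented-out call is the one this PR deletes):

```python
from vllm import ModelRegistry

# Still registered by vllm_ascend.models.register_model():
ModelRegistry.register_model(
    "DeepSeekMTPModel",
    "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")

# Removed by this PR: Qwen3MoeForCausalLM now resolves to vLLM's own class.
# ModelRegistry.register_model(
#     "Qwen3MoeForCausalLM",
#     "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
```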

vllm_ascend/models/deepseek_v2.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@
 from vllm_ascend.models.layers.mla import AscendMLAModules
 from vllm_ascend.models.layers.sfa import (AscendSFAModules,
                                            AscendSparseFlashAttention, Indexer)
-from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.ops.common_fused_moe import AscendFusedMoE


 class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
