Commit f0060fc
[Pangu][MoE] Remove PanguProMoEV1 related code (#5088)
### What this PR does / why we need it?

PanguProMoEV1 is no longer supported in vllm-ascend; remove the related code.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

e2e & ut

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: weichen <[email protected]>
1 parent 3f7a2fb commit f0060fc

5 files changed (+9, -108 lines)


tests/ut/ops/test_fused_moe.py

Lines changed: 0 additions & 3 deletions
@@ -72,9 +72,6 @@ def setup_vllm_config_mock(mocker: MockerFixture):
 
     mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
                  return_value=mock_vllm_config)
-    mocker.patch(
-        'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
-        return_value=mock_vllm_config)
 
 
 @pytest.fixture

tests/ut/ops/test_moe_comm_method.py

Lines changed: 4 additions & 24 deletions
@@ -26,7 +26,6 @@ def setUp(self):
         self.moe_config.dp_group = MagicMock()
         self.moe_config.num_global_redundant_experts = 0
 
-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
@@ -36,11 +35,7 @@ def setUp(self):
     )
     def test_all_gather_comm_impl(self, mock_token_dispatcher,
                                   mock_prepare_finalize,
-                                  mock_get_forward_context,
-                                  mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                                  mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "all_gather"
@@ -76,17 +71,12 @@ def test_all_gather_comm_impl(self, mock_token_dispatcher,
             context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
 
-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2")
     def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
-                           mock_get_forward_context,
-                           mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                           mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "mc2"
@@ -124,7 +114,6 @@ def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
             context_metadata=context_metadata)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
 
-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All"
@@ -134,11 +123,7 @@ def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
     )
     def test_alltoall_comm_impl(self, mock_token_dispatcher,
                                 mock_prepare_finalize,
-                                mock_get_forward_context,
-                                mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                                mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "alltoall"
@@ -168,7 +153,6 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher,
         mock_pf_instance.prepare.assert_called_once_with(
             hidden_states, router_logits, False, False, QuantType.NONE)
 
-    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
     @patch(
         "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather"
@@ -179,11 +163,7 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher,
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
     def test_fused_experts_method(self, mock_unified_apply_mlp,
                                   mock_token_dispatcher, mock_prepare_finalize,
-                                  mock_get_forward_context,
-                                  mock_get_current_vllm_config):
-        # Mock vLLM config
-        mock_get_current_vllm_config.return_value = MagicMock()
-
+                                  mock_get_forward_context):
         # Mock forward context
         mock_context = MagicMock()
         mock_context.moe_comm_method = "all_gather"

vllm_ascend/ascend_forward_context.py

Lines changed: 0 additions & 4 deletions
@@ -240,7 +240,6 @@ def select_moe_comm_method(num_tokens: int,
     quant_type = getattr(
         vllm_config.model_config.hf_config, 'moe_quantize',
         getattr(vllm_config.model_config.hf_config, 'quantize', None))
-    model_type = vllm_config.model_config.hf_config.model_type
 
     if not vllm_config.parallel_config.enable_expert_parallel:
         moe_comm_type = MoECommType.ALLGATHER
@@ -267,7 +266,4 @@
                          if fused_all2all_enable else MoECommType.ALLTOALL)
     else:
         raise ValueError(f"Unsupported soc_version: {soc_version}")
-    # PanguProMoE only supports allgather
-    if model_type == "PanguProMoE":
-        moe_comm_type = MoECommType.ALLGATHER
     return moe_comm_type
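The net effect of the second hunk: select_moe_comm_method no longer carries a model-specific override at the end, so every model keeps the value chosen by the expert-parallel and SoC branches. A condensed, hypothetical sketch of the behaviour that was removed (the stand-in enum and the collapsed middle branch are illustrative placeholders, not the real vllm-ascend code; only the MoECommType member names and the PanguProMoE check come from this diff):

```python
import enum


class MoECommType(enum.Enum):
    # Stand-in for the real vllm-ascend enum referenced in this diff.
    ALLGATHER = "allgather"
    MC2 = "mc2"
    ALLTOALL = "alltoall"


def select_comm_type_before_this_commit(enable_expert_parallel: bool,
                                        fused_all2all_enable: bool,
                                        model_type: str) -> MoECommType:
    if not enable_expert_parallel:
        moe_comm_type = MoECommType.ALLGATHER
    else:
        # SoC-specific branches elided; the path visible in the diff picks MC2
        # when fused all-to-all is enabled and plain ALLTOALL otherwise.
        moe_comm_type = (MoECommType.MC2
                         if fused_all2all_enable else MoECommType.ALLTOALL)
    # The final override is what this commit deletes ("PanguProMoE only
    # supports allgather"); without it the branch result is returned as-is.
    if model_type == "PanguProMoE":
        moe_comm_type = MoECommType.ALLGATHER
    return moe_comm_type
```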

vllm_ascend/ops/fused_moe/moe_comm_method.py

Lines changed: 5 additions & 14 deletions
@@ -19,7 +19,6 @@
 from typing import Any, Dict, Optional
 
 import torch
-from vllm.config import get_current_vllm_config
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
@@ -30,7 +29,7 @@
                                         PrepareAndFinalizeWithMC2, QuantType)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
-    TokenDispatcherWithMC2, TokenDispatcherWithMoge)
+    TokenDispatcherWithMC2)
 
 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
 
@@ -52,8 +51,6 @@ class MoECommMethod(ABC):
     """Base class for MoE communication methods."""
 
     def __init__(self, moe_config: FusedMoEConfig):
-        self.model_type = get_current_vllm_config(
-        ).model_config.hf_config.model_type
         self.moe_config = moe_config
 
         self.token_dispatcher = self._get_token_dispatcher()
@@ -198,16 +195,10 @@ class AllGatherCommImpl(MoECommMethod):
     """
 
     def _get_token_dispatcher(self):
-        if self.model_type == "PanguProMoE":
-            return TokenDispatcherWithMoge(
-                top_k=self.moe_config.experts_per_token,
-                num_experts=self.moe_config.num_experts,
-                num_local_experts=self.moe_config.num_local_experts)
-        else:
-            return TokenDispatcherWithAllGather(
-                top_k=self.moe_config.experts_per_token,
-                num_experts=self.moe_config.num_experts,
-                num_local_experts=self.moe_config.num_local_experts)
+        return TokenDispatcherWithAllGather(
+            top_k=self.moe_config.experts_per_token,
+            num_experts=self.moe_config.num_experts,
+            num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
         return PrepareAndFinalizeWithAllGather(self.moe_config)

vllm_ascend/ops/fused_moe/token_dispatcher.py

Lines changed: 0 additions & 63 deletions
@@ -422,69 +422,6 @@ def token_combine(self,
         return final_hidden_states
 
 
-# mypy: disable-error-code="override"
-class TokenDispatcherWithMoge(MoETokenDispatcher):
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.apply_router_weight_on_input = False
-        self.local_num_experts = self.num_experts // self.ep_size
-        self.local_num_group = self.top_k // self.ep_size
-        self.bsz = None
-
-    def token_dispatch(self,
-                       hidden_states: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
-                       global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
-                       mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False,
-                       dynamic_eplb: bool = False,
-                       pertoken_scale: Optional[torch.Tensor] = None):
-        self.bsz, _ = hidden_states.shape
-        flatten_topk_ids = topk_ids.view(-1)
-        self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
-        self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
-        sorted_hidden_states = hidden_states.index_select(
-            0, self.sorted_topk_ids // self.local_num_group)
-
-        experts_id = torch.arange(0,
-                                  self.local_num_experts,
-                                  dtype=topk_ids.dtype,
-                                  device=topk_ids.device)
-        num_tokens_per_expert = (
-            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
-                torch.float32).sum(0)
-        topk_scales = topk_weights.view(-1).index_select(
-            0, self.sorted_topk_ids).unsqueeze(-1)
-        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
-        group_list_type = 0
-        return {
-            "group_list_type": group_list_type,
-            "hidden_states": sorted_hidden_states,
-            "group_list": group_list,
-            "topk_scales": topk_scales
-        }
-
-    def token_combine(self,
-                      hidden_states: torch.Tensor,
-                      context_metadata: dict,
-                      bias: torch.Tensor = None):
-        unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
-            torch.int32)
-        unsorted_hidden_states = hidden_states.index_select(
-            0, unsorted_topk_ids)
-        final_hidden_states = unsorted_hidden_states.reshape(
-            self.bsz, self.top_k // self.ep_size, -1).sum(1)
-        return final_hidden_states
-
-
 class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
     """
     The implementation of the AlltoAll-based token dispatcher, which handles token
0 commit comments