Commit 14ffe4a

Remove PanguProMoEV1 related code
Signed-off-by: weichen <[email protected]>
1 parent: 4144376 · commit: 14ffe4a

3 files changed: +5 additions, −81 deletions


vllm_ascend/ascend_forward_context.py

Lines changed: 0 additions & 4 deletions
@@ -240,7 +240,6 @@ def select_moe_comm_method(num_tokens: int,
     quant_type = getattr(
         vllm_config.model_config.hf_config, 'moe_quantize',
         getattr(vllm_config.model_config.hf_config, 'quantize', None))
-    model_type = vllm_config.model_config.hf_config.model_type
 
     if not vllm_config.parallel_config.enable_expert_parallel:
         moe_comm_type = MoECommType.ALLGATHER
@@ -267,7 +266,4 @@ def select_moe_comm_method(num_tokens: int,
             if fused_all2all_enable else MoECommType.ALLTOALL)
     else:
         raise ValueError(f"Unsupported soc_version: {soc_version}")
-    # PanguProMoE only supports allgather
-    if model_type == "PanguProMoE":
-        moe_comm_type = MoECommType.ALLGATHER
     return moe_comm_type
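
For readers tracing the behavior change: before this commit, select_moe_comm_method ended with a model-type override that forced allgather for PanguProMoE. A minimal sketch of that override, using only names visible in the diff (the real function takes more parameters and also branches on expert parallelism and soc_version):

from enum import Enum

class MoECommType(Enum):
    # Stand-in for the real enum in vllm_ascend; only the members used here.
    ALLGATHER = "allgather"
    ALLTOALL = "alltoall"

def apply_pangu_override(moe_comm_type: MoECommType, model_type: str) -> MoECommType:
    # The branch this commit removes: PanguProMoE always fell back to
    # allgather, regardless of the method chosen earlier in the function.
    if model_type == "PanguProMoE":
        return MoECommType.ALLGATHER
    return moe_comm_type

With the override gone, every model type goes through the same selection path, and the function no longer reads model_type from the HF config at all.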

vllm_ascend/ops/fused_moe/moe_comm_method.py

Lines changed: 5 additions & 14 deletions
@@ -19,7 +19,6 @@
 from typing import Any, Dict, Optional
 
 import torch
-from vllm.config import get_current_vllm_config
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
@@ -30,7 +29,7 @@
     PrepareAndFinalizeWithMC2, QuantType)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
-    TokenDispatcherWithMC2, TokenDispatcherWithMoge)
+    TokenDispatcherWithMC2)
 
 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
 
@@ -52,8 +51,6 @@ class MoECommMethod(ABC):
     """Base class for MoE communication methods."""
 
     def __init__(self, moe_config: FusedMoEConfig):
-        self.model_type = get_current_vllm_config(
-        ).model_config.hf_config.model_type
         self.moe_config = moe_config
 
         self.token_dispatcher = self._get_token_dispatcher()
@@ -198,16 +195,10 @@ class AllGatherCommImpl(MoECommMethod):
     """
 
     def _get_token_dispatcher(self):
-        if self.model_type == "PanguProMoE":
-            return TokenDispatcherWithMoge(
-                top_k=self.moe_config.experts_per_token,
-                num_experts=self.moe_config.num_experts,
-                num_local_experts=self.moe_config.num_local_experts)
-        else:
-            return TokenDispatcherWithAllGather(
-                top_k=self.moe_config.experts_per_token,
-                num_experts=self.moe_config.num_experts,
-                num_local_experts=self.moe_config.num_local_experts)
+        return TokenDispatcherWithAllGather(
+            top_k=self.moe_config.experts_per_token,
+            num_experts=self.moe_config.num_experts,
+            num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
         return PrepareAndFinalizeWithAllGather(self.moe_config)
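
The change above has a knock-on simplification in the base class: MoECommMethod.__init__ no longer reaches into get_current_vllm_config() to look up model_type, so a comm method is fully determined by the FusedMoEConfig it is given. A rough sketch of the resulting shape, using only names from the diff (the real base class defines several more methods):

from abc import ABC, abstractmethod

class MoECommMethodSketch(ABC):
    """Illustration of the post-commit constructor; not the real class."""

    def __init__(self, moe_config):
        # No global-config lookup anymore: only the MoE config is needed.
        self.moe_config = moe_config
        self.token_dispatcher = self._get_token_dispatcher()

    @abstractmethod
    def _get_token_dispatcher(self):
        ...

AllGatherCommImpl correspondingly drops its PanguProMoE branch and always builds TokenDispatcherWithAllGather from the three config fields shown in the diff.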

vllm_ascend/ops/fused_moe/token_dispatcher.py

Lines changed: 0 additions & 63 deletions
@@ -422,69 +422,6 @@ def token_combine(self,
         return final_hidden_states
 
 
-# mypy: disable-error-code="override"
-class TokenDispatcherWithMoge(MoETokenDispatcher):
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.apply_router_weight_on_input = False
-        self.local_num_experts = self.num_experts // self.ep_size
-        self.local_num_group = self.top_k // self.ep_size
-        self.bsz = None
-
-    def token_dispatch(self,
-                       hidden_states: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
-                       global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
-                       mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False,
-                       dynamic_eplb: bool = False,
-                       pertoken_scale: Optional[torch.Tensor] = None):
-        self.bsz, _ = hidden_states.shape
-        flatten_topk_ids = topk_ids.view(-1)
-        self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
-        self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
-        sorted_hidden_states = hidden_states.index_select(
-            0, self.sorted_topk_ids // self.local_num_group)
-
-        experts_id = torch.arange(0,
-                                  self.local_num_experts,
-                                  dtype=topk_ids.dtype,
-                                  device=topk_ids.device)
-        num_tokens_per_expert = (
-            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
-                torch.float32).sum(0)
-        topk_scales = topk_weights.view(-1).index_select(
-            0, self.sorted_topk_ids).unsqueeze(-1)
-        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
-        group_list_type = 0
-        return {
-            "group_list_type": group_list_type,
-            "hidden_states": sorted_hidden_states,
-            "group_list": group_list,
-            "topk_scales": topk_scales
-        }
-
-    def token_combine(self,
-                      hidden_states: torch.Tensor,
-                      context_metadata: dict,
-                      bias: torch.Tensor = None):
-        unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
-            torch.int32)
-        unsorted_hidden_states = hidden_states.index_select(
-            0, unsorted_topk_ids)
-        final_hidden_states = unsorted_hidden_states.reshape(
-            self.bsz, self.top_k // self.ep_size, -1).sum(1)
-        return final_hidden_states
-
-
 class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
     """
     The implementation of the AlltoAll-based token dispatcher, which handles token
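
The removed TokenDispatcherWithMoge implemented a sort-based dispatch: flatten the top-k expert assignments, argsort them so rows routed to the same expert become contiguous, prepare per-expert group offsets for the grouped expert computation, then invert the permutation and sum each token's top-k results. A standalone sketch of that round trip on toy tensors, assuming a single expert-parallel rank so the local group size equals top_k (the real class also handled EP sharding and quantization metadata):

import torch

# Toy sizes, for illustration only: 4 tokens, hidden dim 8, 2 experts, top_k = 2.
bsz, hidden, num_experts, top_k = 4, 8, 2, 2
hidden_states = torch.randn(bsz, hidden)
topk_ids = torch.randint(0, num_experts, (bsz, top_k))
topk_weights = torch.rand(bsz, top_k)

# Dispatch: sort the flattened expert assignments so slots for the same expert
# become contiguous, and gather the owning token row for each sorted slot.
flat_ids = topk_ids.view(-1)
sorted_idx = torch.argsort(flat_ids.float()).to(torch.int32)
sorted_hidden = hidden_states.index_select(0, sorted_idx // top_k)
topk_scales = topk_weights.view(-1).index_select(0, sorted_idx).unsqueeze(-1)

# Per-expert counts and their cumulative sum ("group_list"), which tells the
# grouped matmul where each expert's rows start and end.
experts_id = torch.arange(num_experts)
tokens_per_expert = (flat_ids.unsqueeze(-1) == experts_id).float().sum(0)
group_list = tokens_per_expert.cumsum(0).to(torch.int64)

# The expert MLPs would run on sorted_hidden here; a scale stands in for them.
expert_out = sorted_hidden * topk_scales

# Combine: undo the sort, then sum the top_k expert outputs for each token.
unsorted_idx = torch.argsort(sorted_idx.float()).to(torch.int32)
combined = expert_out.index_select(0, unsorted_idx).reshape(bsz, top_k, -1).sum(1)
assert combined.shape == (bsz, hidden)

The dict the class returned ("hidden_states", "group_list", "group_list_type", "topk_scales") carried exactly these intermediates to the downstream grouped computation; with the class gone, AllGatherCommImpl always uses TokenDispatcherWithAllGather.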

0 commit comments
