Commit 3d00808

Remove PanguProMoEV1 related code
Signed-off-by: weichen <[email protected]>
1 parent 18d2395 commit 3d00808

3 files changed: +5 -79 lines changed


vllm_ascend/ops/fused_moe/moe_comm_method.py

Lines changed: 5 additions & 11 deletions

@@ -30,7 +30,7 @@
     PrepareAndFinalizeWithMC2, QuantType)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
-    TokenDispatcherWithMC2, TokenDispatcherWithMoge)
+    TokenDispatcherWithMC2)
 
 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
 
@@ -198,16 +198,10 @@ class AllGatherCommImpl(MoECommMethod):
     """
 
     def _get_token_dispatcher(self):
-        if self.model_type == "PanguProMoE":
-            return TokenDispatcherWithMoge(
-                top_k=self.moe_config.experts_per_token,
-                num_experts=self.moe_config.num_experts,
-                num_local_experts=self.moe_config.num_local_experts)
-        else:
-            return TokenDispatcherWithAllGather(
-                top_k=self.moe_config.experts_per_token,
-                num_experts=self.moe_config.num_experts,
-                num_local_experts=self.moe_config.num_local_experts)
+        return TokenDispatcherWithAllGather(
+            top_k=self.moe_config.experts_per_token,
+            num_experts=self.moe_config.num_experts,
+            num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
         return PrepareAndFinalizeWithAllGather(self.moe_config)

vllm_ascend/ops/fused_moe/token_dispatcher.py

Lines changed: 0 additions & 63 deletions

@@ -422,69 +422,6 @@ def token_combine(self,
         return final_hidden_states
 
 
-# mypy: disable-error-code="override"
-class TokenDispatcherWithMoge(MoETokenDispatcher):
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.apply_router_weight_on_input = False
-        self.local_num_experts = self.num_experts // self.ep_size
-        self.local_num_group = self.top_k // self.ep_size
-        self.bsz = None
-
-    def token_dispatch(self,
-                       hidden_states: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
-                       global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
-                       mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False,
-                       dynamic_eplb: bool = False,
-                       pertoken_scale: Optional[torch.Tensor] = None):
-        self.bsz, _ = hidden_states.shape
-        flatten_topk_ids = topk_ids.view(-1)
-        self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
-        self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
-        sorted_hidden_states = hidden_states.index_select(
-            0, self.sorted_topk_ids // self.local_num_group)
-
-        experts_id = torch.arange(0,
-                                  self.local_num_experts,
-                                  dtype=topk_ids.dtype,
-                                  device=topk_ids.device)
-        num_tokens_per_expert = (
-            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
-                torch.float32).sum(0)
-        topk_scales = topk_weights.view(-1).index_select(
-            0, self.sorted_topk_ids).unsqueeze(-1)
-        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
-        group_list_type = 0
-        return {
-            "group_list_type": group_list_type,
-            "hidden_states": sorted_hidden_states,
-            "group_list": group_list,
-            "topk_scales": topk_scales
-        }
-
-    def token_combine(self,
-                      hidden_states: torch.Tensor,
-                      context_metadata: dict,
-                      bias: torch.Tensor = None):
-        unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
-            torch.int32)
-        unsorted_hidden_states = hidden_states.index_select(
-            0, unsorted_topk_ids)
-        final_hidden_states = unsorted_hidden_states.reshape(
-            self.bsz, self.top_k // self.ep_size, -1).sum(1)
-        return final_hidden_states
-
-
 class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
     """
     The implementation of the AlltoAll-based token dispatcher, which handles token
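
For reference, the deleted TokenDispatcherWithMoge boils down to a sort-by-expert / grouped-compute / unsort round trip. Below is a minimal, self-contained PyTorch sketch of that pattern (single rank, so ep_size is 1 and local_num_group equals top_k, with a stand-in for the grouped expert computation); the variable names mirror the removed code, but nothing here is part of the vllm-ascend API.

import torch

num_tokens, hidden_size, num_experts, top_k = 4, 8, 4, 2
hidden_states = torch.randn(num_tokens, hidden_size)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
topk_weights = torch.rand(num_tokens, top_k)

# Dispatch: sort the flattened (token, expert) pairs by expert id so each
# expert's tokens are contiguous, and build the cumulative group_list.
flatten_topk_ids = topk_ids.view(-1)
sorted_topk_ids = torch.argsort(flatten_topk_ids.float()).to(torch.int32)
sorted_hidden_states = hidden_states.index_select(0, sorted_topk_ids // top_k)
experts_id = torch.arange(num_experts, dtype=topk_ids.dtype)
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).sum(0)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
topk_scales = topk_weights.view(-1).index_select(0, sorted_topk_ids).unsqueeze(-1)

# Stand-in for the grouped expert MLP that would consume sorted_hidden_states
# sliced according to group_list.
expert_output = sorted_hidden_states * topk_scales

# Combine: invert the sort and sum the top_k expert contributions per token.
unsorted_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
final_hidden_states = expert_output.index_select(0, unsorted_ids).reshape(
    num_tokens, top_k, -1).sum(1)
assert final_hidden_states.shape == (num_tokens, hidden_size)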

vllm_ascend/worker/model_runner_v1.py

Lines changed: 0 additions & 5 deletions

@@ -1418,7 +1418,6 @@ def _select_moe_comm_method(self,
         quant_type = getattr(
             self.vllm_config.model_config.hf_config, 'moe_quantize',
             getattr(self.vllm_config.model_config.hf_config, 'quantize', None))
-        model_type = self.vllm_config.model_config.hf_config.model_type
 
         if not self.parallel_config.enable_expert_parallel:
             moe_comm_type = MoECommType.ALLGATHER
@@ -1445,10 +1444,6 @@
         else:
             raise ValueError(f"Unsupported soc_version: {soc_version}")
 
-        # PanguProMoE only supports allgather
-        if model_type == "PanguProMoE":
-            moe_comm_type = MoECommType.ALLGATHER
-
         if is_global_first_rank():
             logger.debug(f"num_tokens: {num_tokens}, "
                          f"moe_comm_type: {moe_comm_type}")
