
Commit 49415eb

Remove PanguProMoEV1 related code
Signed-off-by: weichen <[email protected]>
1 parent: 4ed2951

3 files changed (+5 lines, -78 lines)
vllm_ascend/ascend_forward_context.py

Lines changed: 0 additions & 4 deletions
@@ -271,7 +271,6 @@ def select_moe_comm_method(num_tokens: int,
     quant_type = getattr(
         vllm_config.model_config.hf_config, 'moe_quantize',
         getattr(vllm_config.model_config.hf_config, 'quantize', None))
-    model_type = vllm_config.model_config.hf_config.model_type

     if not vllm_config.parallel_config.enable_expert_parallel:
         moe_comm_type = MoECommType.ALLGATHER
@@ -300,7 +299,4 @@ def select_moe_comm_method(num_tokens: int,
         raise ValueError(f"Unsupported soc_version: {soc_version}")
     moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
                      == MoECommType.FUSED_ALLTOALL else moe_comm_type)
-    # PanguProMoE only supports allgather
-    if model_type == "PanguProMoE":
-        moe_comm_type = MoECommType.ALLGATHER
     return moe_comm_type
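
With this change, `select_moe_comm_method` no longer reads `model_type` at all; the result now comes only from the expert-parallel setting and the quant_type/num_tokens/soc_version branches. The following is a minimal, self-contained sketch of the remaining fallback shape (the `pick_comm_type` helper and `prefer_fused_alltoall` flag are hypothetical stand-ins, not the real vllm_ascend function, which inspects more state):

from enum import Enum, auto

class MoECommType(Enum):  # names mirror those used in the diff above
    ALLGATHER = auto()
    MC2 = auto()
    ALLTOALL = auto()
    FUSED_ALLTOALL = auto()

def pick_comm_type(enable_expert_parallel: bool,
                   prefer_fused_alltoall: bool) -> MoECommType:
    # Illustrative only: the real select_moe_comm_method also branches on
    # quant_type, num_tokens and soc_version before settling on a type.
    if not enable_expert_parallel:
        return MoECommType.ALLGATHER
    moe_comm_type = (MoECommType.FUSED_ALLTOALL
                     if prefer_fused_alltoall else MoECommType.MC2)
    # FUSED_ALLTOALL still falls back to plain ALLTOALL; the per-model
    # override that forced ALLGATHER for PanguProMoE is gone.
    return (MoECommType.ALLTOALL
            if moe_comm_type == MoECommType.FUSED_ALLTOALL else moe_comm_type)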

vllm_ascend/ops/fused_moe/moe_comm_method.py

Lines changed: 5 additions & 11 deletions
@@ -30,7 +30,7 @@
     PrepareAndFinalizeWithMC2, QuantType)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
-    TokenDispatcherWithMC2, TokenDispatcherWithMoge)
+    TokenDispatcherWithMC2)

 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}

@@ -198,16 +198,10 @@ class AllGatherCommImpl(MoECommMethod):
     """

     def _get_token_dispatcher(self):
-        if self.model_type == "PanguProMoE":
-            return TokenDispatcherWithMoge(
-                top_k=self.moe_config.experts_per_token,
-                num_experts=self.moe_config.num_experts,
-                num_local_experts=self.moe_config.num_local_experts)
-        else:
-            return TokenDispatcherWithAllGather(
-                top_k=self.moe_config.experts_per_token,
-                num_experts=self.moe_config.num_experts,
-                num_local_experts=self.moe_config.num_local_experts)
+        return TokenDispatcherWithAllGather(
+            top_k=self.moe_config.experts_per_token,
+            num_experts=self.moe_config.num_experts,
+            num_local_experts=self.moe_config.num_local_experts)

     def _get_prepare_finalize(self):
         return PrepareAndFinalizeWithAllGather(self.moe_config)
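
`AllGatherCommImpl._get_token_dispatcher` is now a single code path. A reduced sketch of the resulting shape, with stub stand-ins for the config and dispatcher (the real classes live in vllm_ascend and carry more state; `MoEConfig` here is a hypothetical dataclass, not the library's type):

from dataclasses import dataclass

@dataclass
class MoEConfig:  # stand-in exposing only the fields read in the diff
    experts_per_token: int
    num_experts: int
    num_local_experts: int

class TokenDispatcherWithAllGather:  # stub with the same constructor keywords
    def __init__(self, top_k: int, num_experts: int, num_local_experts: int):
        self.top_k = top_k
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts

class AllGatherCommImpl:
    def __init__(self, moe_config: MoEConfig):
        self.moe_config = moe_config

    def _get_token_dispatcher(self):
        # Single code path: no model_type branch, no TokenDispatcherWithMoge.
        return TokenDispatcherWithAllGather(
            top_k=self.moe_config.experts_per_token,
            num_experts=self.moe_config.num_experts,
            num_local_experts=self.moe_config.num_local_experts)

# Example: build a dispatcher for a 64-expert model with top-8 routing.
impl = AllGatherCommImpl(MoEConfig(experts_per_token=8, num_experts=64,
                                   num_local_experts=8))
dispatcher = impl._get_token_dispatcher()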

vllm_ascend/ops/fused_moe/token_dispatcher.py

Lines changed: 0 additions & 63 deletions
@@ -422,69 +422,6 @@ def token_combine(self,
         return final_hidden_states


-# mypy: disable-error-code="override"
-class TokenDispatcherWithMoge(MoETokenDispatcher):
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.apply_router_weight_on_input = False
-        self.local_num_experts = self.num_experts // self.ep_size
-        self.local_num_group = self.top_k // self.ep_size
-        self.bsz = None
-
-    def token_dispatch(self,
-                       hidden_states: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
-                       global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
-                       mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False,
-                       dynamic_eplb: bool = False,
-                       pertoken_scale: Optional[torch.Tensor] = None):
-        self.bsz, _ = hidden_states.shape
-        flatten_topk_ids = topk_ids.view(-1)
-        self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
-        self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
-        sorted_hidden_states = hidden_states.index_select(
-            0, self.sorted_topk_ids // self.local_num_group)
-
-        experts_id = torch.arange(0,
-                                  self.local_num_experts,
-                                  dtype=topk_ids.dtype,
-                                  device=topk_ids.device)
-        num_tokens_per_expert = (
-            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
-                torch.float32).sum(0)
-        topk_scales = topk_weights.view(-1).index_select(
-            0, self.sorted_topk_ids).unsqueeze(-1)
-        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
-        group_list_type = 0
-        return {
-            "group_list_type": group_list_type,
-            "hidden_states": sorted_hidden_states,
-            "group_list": group_list,
-            "topk_scales": topk_scales
-        }
-
-    def token_combine(self,
-                      hidden_states: torch.Tensor,
-                      context_metadata: dict,
-                      bias: torch.Tensor = None):
-        unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
-            torch.int32)
-        unsorted_hidden_states = hidden_states.index_select(
-            0, unsorted_topk_ids)
-        final_hidden_states = unsorted_hidden_states.reshape(
-            self.bsz, self.top_k // self.ep_size, -1).sum(1)
-        return final_hidden_states
-
-
 class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
     """
     The implementation of the AlltoAll-based token dispatcher, which handles token
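
For reference, the deleted `TokenDispatcherWithMoge` used an argsort-based dispatch/combine: flatten the top-k expert ids, sort token copies into expert-major order, build a cumulative `group_list` of per-expert counts, then invert the permutation and sum each token's top_k expert outputs on the way back. The sketch below restates that pattern as standalone functions (hypothetical `moge_style_dispatch`/`moge_style_combine` names, assuming a single EP rank so `local_num_group == top_k`; not part of the vllm_ascend API):

import torch

def moge_style_dispatch(hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
                        num_local_experts: int,
                        top_k: int):
    bsz = hidden_states.shape[0]
    flat_ids = topk_ids.view(-1)                                  # (bsz * top_k,)
    sorted_idx = torch.argsort(flat_ids.float()).to(torch.int32)  # expert-major order
    # Each token owns top_k consecutive slots in the flattened view,
    # so slot // top_k recovers the owning token row.
    sorted_hidden = hidden_states.index_select(0, sorted_idx // top_k)
    counts = (flat_ids.unsqueeze(-1) == torch.arange(num_local_experts)).sum(0)
    group_list = counts.cumsum(0).to(torch.int64)                 # per-expert prefix sums
    topk_scales = topk_weights.view(-1).index_select(0, sorted_idx).unsqueeze(-1)
    return sorted_hidden, group_list, topk_scales, sorted_idx

def moge_style_combine(expert_out: torch.Tensor,
                       sorted_idx: torch.Tensor,
                       bsz: int,
                       top_k: int) -> torch.Tensor:
    # Undo the dispatch permutation, then reduce over the top_k copies per token.
    unsort = torch.argsort(sorted_idx.float()).to(torch.int32)
    return expert_out.index_select(0, unsort).reshape(bsz, top_k, -1).sum(1)

# Round trip with identity "experts": combining sorted_hidden * topk_scales gives,
# for every token, the weight-scaled sum of its top_k routed copies.
x = torch.randn(4, 16)
w = torch.rand(4, 2)
ids = torch.randint(0, 3, (4, 2))
sorted_h, groups, scales, perm = moge_style_dispatch(x, w, ids,
                                                     num_local_experts=3, top_k=2)
y = moge_style_combine(sorted_h * scales, perm, bsz=4, top_k=2)
assert y.shape == (4, 16)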
