 from torch.nn.parameter import UninitializedParameter

 import vllm.envs as envs
-from vllm.config import get_current_vllm_config, ParallelConfig
+from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -322,6 +322,7 @@ def __init__(self, moe: MoEConfig):
         super().__init__()
         self.fused_experts = fused_experts
         self.moe = moe
+
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
         if self.rocm_aiter_moe_enabled:
             from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
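Note: this hunk follows a probe-once, bind-lazily pattern for the optional ROCm AITER MoE kernels: check availability at construction time and only then import the backend. A minimal sketch of that pattern under assumed names (`_optional_backend_available`, `my_optional_backend`, `ToyExpertsMethod` are illustrative, not vLLM's API):

# A minimal sketch (not vLLM's actual API) of the probe-once, bind-lazily
# pattern the hunk above follows for optional MoE kernels.
from typing import Callable


def _optional_backend_available() -> bool:
    # Stand-in for a capability probe such as is_rocm_aiter_moe_enabled().
    return False


def _reference_fused_experts(hidden_states, w1, w2, topk_weights, topk_ids):
    # Fallback used when the optional backend is not present.
    return hidden_states


class ToyExpertsMethod:
    def __init__(self) -> None:
        self.backend_enabled = _optional_backend_available()
        if self.backend_enabled:
            # Only touch the optional dependency when the probe succeeds;
            # "my_optional_backend" is a hypothetical module name.
            from my_optional_backend import fused_experts as backend
            self._experts: Callable = backend
        else:
            self._experts = _reference_fused_experts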
@@ -501,6 +502,8 @@ def forward_cuda(
             indices_type=torch.uint32 if self.moe.use_pplx_kernels else None)

         if self.rocm_aiter_moe_enabled:
+            assert not apply_router_weight_on_input
+            assert expert_map is None
             return self.rocm_aiter_fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -510,8 +513,8 @@ def forward_cuda(
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input)
         else:
-            return fused_experts(
-                a1=x,
+            return self.fused_experts(
+                hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
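The non-AITER branch now dispatches through the callable stored on the instance (`self.fused_experts`), so the call site must use that callable's actual parameter name, `hidden_states`, rather than the stale `a1=` keyword. A hedged toy reproduction of why the old spelling fails (`toy_fused_experts` is illustrative, not vLLM's kernel):

import torch


def toy_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
                      w2: torch.Tensor) -> torch.Tensor:
    # Placeholder for the fused expert computation; the parameter is named
    # hidden_states, matching the call-site fix in the hunk above.
    return hidden_states @ w1 @ w2


x = torch.randn(4, 8)
w1 = torch.randn(8, 8)
w2 = torch.randn(8, 8)

out = toy_fused_experts(hidden_states=x, w1=w1, w2=w2)   # accepted
try:
    toy_fused_experts(a1=x, w1=w1, w2=w2)                 # stale keyword
except TypeError as exc:
    print(f"old keyword rejected: {exc}")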
@@ -1191,8 +1194,7 @@ def select_experts(hidden_states: torch.Tensor,
                        scoring_func: str = "softmax",
                        e_score_correction_bias: Optional[torch.Tensor] = None,
                        indices_type: Optional[torch.dtype] = None):
-        from vllm.model_executor.layers.fused_moe.fused_moe import (
-            fused_topk, grouped_topk)
+        from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

         # DeekSeekv2 uses grouped_top_k
         if use_grouped_topk:
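As a reference for what the retained `fused_topk` import computes, the routing step is a softmax over the router logits, a per-token top-k selection, and an optional renormalization of the selected weights. A plain-PyTorch sketch of that computation, not vLLM's fused kernel and with an assumed signature:

import torch


def reference_topk(hidden_states: torch.Tensor, gating_output: torch.Tensor,
                   topk: int, renormalize: bool):
    # hidden_states is accepted only to mirror the fused-style signature;
    # the routing itself needs just the router logits.
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


tokens, num_experts = 4, 8
logits = torch.randn(tokens, num_experts)
w, ids = reference_topk(torch.randn(tokens, 16), logits, topk=2,
                        renormalize=True)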
@@ -1228,24 +1230,6 @@ def select_experts(hidden_states: torch.Tensor,

         return topk_weights, topk_ids

-    def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_dp_cpu: torch.Tensor):
-        assert (len(x.shape) == 2)
-        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
-                             device=x.device,
-                             dtype=x.dtype)
-
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
-        buffer[start:end, :].copy_(x)
-        for idx in range(get_dp_group().world_size):
-            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-            end = cu_tokens_across_dp_cpu[idx]
-            get_dp_group().broadcast(buffer[start:end, :], idx)
-
-        return buffer
-
     def must_reduce_shared_expert_outputs(self) -> bool:
         """
         The shared_experts are typically computed using the RowParallelLinear
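For the removed `naive_multicast` helper, the only subtle part is the offset arithmetic: `cu_tokens_across_dp_cpu` holds cumulative token counts across data-parallel ranks, so rank `i` owns the slice `[cu[i-1], cu[i])` (with an implicit 0 before rank 0), and broadcasting each owned slice reproduces a naive all-gather. A small worked sketch of just that slicing, with made-up token counts:

import torch

# Example: 3 DP ranks holding 2, 3, and 1 tokens respectively.
cu_tokens_across_dp_cpu = torch.tensor([2, 5, 6])

for idx in range(len(cu_tokens_across_dp_cpu)):
    start = 0 if idx == 0 else int(cu_tokens_across_dp_cpu[idx - 1])
    end = int(cu_tokens_across_dp_cpu[idx])
    # Rank idx fills buffer[start:end] locally and broadcasts that slice.
    print(f"rank {idx}: rows [{start}, {end})")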