@@ -172,10 +172,10 @@ def __init__(self, *args, **kwargs):
         self.moe_config.dp_group = get_dp_group()
         self.moe_config.ep_group = get_ep_group()
         self.moe_config.mc2_group = get_mc2_group()
-        ascend_config = get_ascend_config()
-        self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
-        self.expert_map_path = ascend_config.expert_map_path
-        self.global_redundant_expert_num = ascend_config.init_redundancy_expert
+        self.ascend_config = get_ascend_config()
+        self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
+        self.expert_map_path = self.ascend_config.expert_map_path
+        self.global_redundant_expert_num = self.ascend_config.init_redundancy_expert
         self.global_num_experts = num_experts + self.global_redundant_expert_num
         if self.custom_routing_function is None and self.e_score_correction_bias is not None:
             vllm_config = get_current_vllm_config()
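
Note: this hunk promotes the local ascend_config variable to an instance attribute, so later code in the same class (for example the enable_shared_expert_dp assignment in the hunk around line 257) can reuse self.ascend_config instead of calling get_ascend_config() again. A minimal sketch of the pattern, with a hypothetical class and method name, assuming get_ascend_config is imported as in the surrounding file:

    class HypotheticalMoELayer:
        def __init__(self):
            # Fetch the platform config once and cache it on the instance.
            self.ascend_config = get_ascend_config()

        def read_shared_expert_dp(self):
            # Later methods reuse the cached config instead of another global lookup.
            return self.ascend_config.enable_shared_expert_dp
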
@@ -195,8 +195,8 @@ def __init__(self, *args, **kwargs):
             self.expert_load_balancer = ExpertLoadBalancer(
                 self.expert_map_path, num_experts)
             self.expert_load_balancer.check_expert_map_tensor()
-            self.global_redundant_expert_num = (
-                self.expert_load_balancer.get_global_redundant_expert_num())
+            # self.global_redundant_expert_num = (
+            #     self.expert_load_balancer.get_global_redundant_expert_num())
             self.global_num_experts = num_experts + self.global_redundant_expert_num
             try:
                 self.local_num_experts, self.expert_map = (
@@ -254,7 +254,7 @@ def __init__(self, *args, **kwargs):
             moe_quant_params["intermediate_size_full"] = intermediate_size
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
-        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
+        self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
 
         setup_moe_comm_method(self.moe_config)
         self.quant_type = self._get_quant_type()
@@ -460,8 +460,8 @@ def __init__(
         self._shared_experts = shared_experts
         self.use_overlapped = use_overlapped
         self.shared_expert_stream = None
-        ascend_config = get_ascend_config()
-        self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
+        self.ascend_config = get_ascend_config()
+        self.multistream_overlap_shared_expert = self.ascend_config.multistream_overlap_shared_expert
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -489,11 +489,19 @@ def forward(
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        shared_out, fused_out = AscendFusedMoE.forward(
-            self,
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-        )
+        if self._shared_experts is None:
+            fused_out = AscendFusedMoE.forward(
+                self,
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+            shared_out = None
+        else:
+            shared_out, fused_out = AscendFusedMoE.forward(
+                self,
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
         return shared_out, fused_out
 
     def forward_impl(self, hidden_states: torch.Tensor,
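
Note: with this branch, forward() tolerates a missing shared-experts module; when self._shared_experts is None it still returns a (shared_out, fused_out) tuple, just with shared_out set to None. A hedged caller-side sketch, where layer is an assumed instance of this class and the additive combination is only one common convention for shared plus routed experts:

    shared_out, fused_out = layer(hidden_states, router_logits)
    if shared_out is None:
        # No shared experts attached; use the routed-expert output directly.
        output = fused_out
    else:
        # Combine shared-expert and routed-expert contributions (model-specific).
        output = fused_out + shared_out
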
@@ -507,7 +515,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
             # Use a separate stream to run shared experts.
             # Note that currently we only support calculations in separate streams with aclgraph.
             # Communication operations in another stream might cause unknown errors.
-            shared_out = self._shared_experts(hidden_states)
+            if self._shared_experts is None:
+                shared_out = None
+            else:
+                shared_out = self._shared_experts(hidden_states)
 
         fused_output = AscendFusedMoE.forward_impl(
             self,
@@ -521,7 +532,10 @@ def forward_impl(self, hidden_states: torch.Tensor,
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
-                and not shared_expert_dp_enabled():
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                and not shared_expert_dp_enabled() and shared_out is not None:
             shared_out = tensor_model_parallel_all_reduce(shared_out)
-        return shared_out, fused_output
+        if shared_out is None:
+            return fused_output
+        else:
+            return shared_out, fused_output
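
Note: after this hunk, forward_impl() no longer has a single return shape: it returns a bare fused_output tensor when shared_out is None, and a (shared_out, fused_output) tuple otherwise. A hedged sketch of how a caller might normalize the two shapes (layer and the variable names are illustrative):

    result = layer.forward_impl(hidden_states, router_logits)
    if isinstance(result, tuple):
        shared_out, fused_output = result
    else:
        shared_out, fused_output = None, result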