@@ -118,18 +118,15 @@ def get_dispatch_mc2_kwargs(
         mc2_mask: torch.Tensor,
         global_redundant_expert_num: int = 0,
     ):
-        if self.with_quant:
-            quant_mode = 2
-            moe_expert_num = len(expert_map)
-        else:
-            quant_mode = 0
-            moe_expert_num = len(expert_map)
+        quant_mode = 2 if self.with_quant else 0
+        self.physics_num_experts = len(expert_map) + global_redundant_expert_num
+
         kwargs_mc2 = {
             "x": hidden_states,
             "expert_ids": topk_ids,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": moe_expert_num,
+            "moe_expert_num": self.physics_num_experts,
             "global_bs": 0,
             "expert_token_nums_type": 0,
         }
@@ -247,15 +244,14 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
         expand_scales = context_metadata["expand_scales"]

         assert expert_map is not None
-        moe_expert_num = len(expert_map)

         kwargs_mc2 = {
             "expand_x": hidden_states,
             "expert_ids": topk_ids,
             "expert_scales": topk_weights.to(torch.float32),
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": moe_expert_num,
+            "moe_expert_num": self.physics_num_experts,
             "global_bs": 0,
         }

@@ -360,7 +356,7 @@ def token_dispatch(self,
             hidden_states = hidden_states * \
                 topk_weights.to(hidden_states.dtype)
         if expert_map is not None:
-            global_num_experts = len(expert_map)
+            global_num_experts = len(expert_map) + global_redundant_expert_num
             mask = (expert_map[topk_ids] != -1)
             topk_weights = topk_weights * mask
             first_expert_idx = get_ep_group(