 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from collections.abc import Callable, Iterable
+# from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
-# from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, get_spec_layer_idx_from_weight_name, DeepseekV2MLP, DeepseekV2MoE
 from vllm.model_executor.models.utils import is_pp_missing_parameter
 from vllm.config import ParallelConfig
@@ -37,14 +37,14 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ):
-        nn.Module.__init__(self)
+        super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()

-        self.routed_scaling_factor = config.routed_scaling_factor
+        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

         self.ep_group = get_ep_group().device_group
-        self.ep_rank = self.ep_group.rank()
+        self.ep_rank = get_ep_group().rank_in_group
         self.ep_size = self.ep_group.size()
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts
@@ -64,7 +64,7 @@ def __init__(
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
-        if config.topk_method == "noaux_tc":
+        if getattr(config, "topk_method", None) == "noaux_tc":
             self.gate.e_score_correction_bias = nn.Parameter(
                 torch.empty(config.n_routed_experts, dtype=torch.float32)
             )
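
The e_score_correction_bias registered above only exists for checkpoints whose topk_method is "noaux_tc", which is why the guard now tolerates configs that omit the field. As a rough, hedged sketch of how such a correction bias is typically consumed in aux-loss-free balancing (selection uses the biased scores, gating weights keep the raw scores), with made-up tensor names and shapes rather than the fused grouped-top-k kernel this module actually calls:

import torch

def noaux_tc_topk_sketch(scores: torch.Tensor, bias: torch.Tensor, top_k: int):
    # scores: (num_tokens, n_routed_experts), e.g. sigmoid of the router logits
    # bias:   (n_routed_experts,) learned e_score_correction_bias
    # Experts are selected with the bias-corrected scores...
    _, topk_ids = torch.topk(scores + bias, k=top_k, dim=-1)
    # ...but the gating weights are gathered from the original scores.
    topk_weights = torch.gather(scores, dim=-1, index=topk_ids)
    return topk_weights, topk_ids

weights, ids = noaux_tc_topk_sketch(torch.rand(4, 16), torch.zeros(16), top_k=2)
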
@@ -115,10 +115,10 @@ def __init__(
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
+            num_expert_group=getattr(config, "n_group", 1),
+            topk_group=getattr(config, "topk_group", 1),
             prefix=f"{prefix}.experts",
-            scoring_func=config.scoring_func,
+            scoring_func=getattr(config, "scoring_func", "softmax"),
             # we do scaling outside, set factor to 1.0 to avoid double mul
             # aiter applies routed_scaling_factor internally
             routed_scaling_factor=1.0
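
Switching from direct attribute access to getattr with defaults means the same MoE block can be built from configs that omit the grouped-top-k fields altogether. A minimal sketch of the fallback behaviour, using a throwaway SimpleNamespace in place of the real HF config object:

from types import SimpleNamespace

# A stripped-down config missing the optional grouped-top-k fields.
cfg = SimpleNamespace(n_routed_experts=64, n_shared_experts=2)

num_expert_group = getattr(cfg, "n_group", 1)                        # -> 1
topk_group = getattr(cfg, "topk_group", 1)                           # -> 1
scoring_func = getattr(cfg, "scoring_func", "softmax")               # -> "softmax"
routed_scaling_factor = getattr(cfg, "routed_scaling_factor", 1.0)   # -> 1.0
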
@@ -130,34 +130,42 @@ def __init__(
             is_sequence_parallel=self.is_sequence_parallel,
         )

-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
+
         # Chunk the hidden states so they aren't replicated across TP ranks.
         # This avoids duplicate computation in self.experts.
         # TODO: We can replace the all_reduce at the end of attn with a
         # reduce_scatter instead of chunking here.
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)

-        router_logits, _ = self.gate(hidden_states)
-        fused_moe_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
-        ascend_config = get_ascend_config()
+        if self.experts.is_internal_router:
+            # In this case, the gate/router runs inside the FusedMoE class
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+
         shared_output, final_hidden_states = fused_moe_out
         if self.shared_experts is None:
             assert shared_output is None
-
+
         # Fix FP16 overflow
         # See DeepseekV2DecoderLayer for more details.
         if hidden_states.dtype != torch.float16:
-            final_hidden_states *= self.routed_scaling_factor
+            if not self.is_rocm_aiter_moe_enabled:
+                final_hidden_states *= self.routed_scaling_factor
         elif self.shared_experts is not None:
             assert shared_output is not None
             shared_output *= 1.0 / self.routed_scaling_factor
-
+
         if self.shared_experts is not None:
             assert shared_output is not None
             final_hidden_states += shared_output
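
Two things change in the scaling logic above: the external multiply is skipped when the ROCm AITER path is active, since AITER already applies routed_scaling_factor internally (per the constructor comment), and the FP16 branch keeps its existing trick of dividing the shared output instead of multiplying the routed output, which preserves the same routed-to-shared ratio up to a global factor handled around this block (see the DeepseekV2DecoderLayer note). A small numeric check of that equivalence, with made-up tensors:

import torch

s = 2.5                      # stand-in routed_scaling_factor
routed = torch.randn(4, 8)   # stand-in routed-expert output
shared = torch.randn(4, 8)   # stand-in shared-expert output

# Non-FP16 path: scale the routed part here and add the shared part as-is.
out_a = routed * s + shared

# FP16-style path: leave the routed part unscaled, pre-divide the shared part,
# and let a later step apply the factor s to the sum.
out_b = (routed + shared * (1.0 / s)) * s

assert torch.allclose(out_a, out_b, atol=1e-5)
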
@@ -171,20 +179,35 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                 final_hidden_states
             )
+
         return final_hidden_states.view(num_tokens, hidden_dim)

 class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         ascend_config = get_ascend_config()
         stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
+        ]
+        mla_params_mapping = [
             ("fused_qkv_a_proj", "q_a_proj", 0),
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
+        mha_params_mapping = [
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        if self.use_mha:
+            stacked_params_mapping.extend(mha_params_mapping)
+        else:
+            stacked_params_mapping.extend(mla_params_mapping)

         mix_placement = getattr(ascend_config, "mix_placement", False)

+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)

         expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
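
The mapping tables above drive a plain name rewrite plus shard dispatch: a checkpoint name that contains weight_name is rewritten to the fused param_name and handed to that parameter's weight_loader with the matching shard_id, exactly as the load loop further down does. A toy walk-through of the MHA case, with a hypothetical checkpoint key:

# Hypothetical checkpoint entry for a plain multi-head-attention layer.
ckpt_name = "model.layers.0.self_attn.q_proj.weight"

mha_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

for param_name, weight_name, shard_id in mha_params_mapping:
    if weight_name not in ckpt_name:
        continue
    fused_name = ckpt_name.replace(weight_name, param_name)
    # fused_name == "model.layers.0.self_attn.qkv_proj.weight"; the load loop
    # would call params_dict[fused_name].weight_loader(param, w, shard_id="q").
    break
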
@@ -207,77 +230,119 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:
-                continue
+                continue  # skip spec decode layers for main model

-            is_fuse_shared_experts_layer = (
-                mix_placement
-                and ("mlp.shared_experts" in name)
+            is_fusion_moe_shared_experts_layer = (
+                mix_placement and ("mlp.shared_experts" in name)
             )

             for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
-                if is_fuse_shared_experts_layer:
+                if is_fusion_moe_shared_experts_layer:
                     continue
                 name_mapped = name.replace(weight_name, param_name)

-                if (param_name == "fused_qkv_a_proj") and name_mapped not in params_dict:
+                # QKV fusion is optional; fall back to normal weight
+                # loading if it's not enabled. If the fused projection
+                # is present, update the name to use it.
+                if (
+                    param_name == "fused_qkv_a_proj"
+                ) and name_mapped not in params_dict:
                     continue
                 else:
                     name = name_mapped
+                # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+
                 if is_pp_missing_parameter(name, self):
                     continue
-                if name not in params_dict.keys():
-                    continue

                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 is_expert_weight = False
+
+                # Special handling: when AITER fusion_shared_experts is enabled,
+                # checkpoints may provide a single widened shared_experts tensor
+                # without explicit expert indices
+                # (e.g. ...mlp.shared_experts.gate_proj.weight).
+                # For models with multiple shared experts, split that tensor
+                # evenly into per-shared-expert slices and load them into
+                # appended expert slots mlp.experts.{n_routed_experts + j}.*
+                # accordingly.
                 num_chunks = 1
-                if is_fuse_shared_experts_layer:
+                if is_fusion_moe_shared_experts_layer:
                     num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+                    # Determine split axis based on op type
+                    # gate/up: ColumnParallel → split along dim 0
+                    # down: RowParallel → split along dim 1
                     split_dim = 1 if "down_proj.weight" in name else 0
                     total = loaded_weight.shape[split_dim]
                     assert total % num_chunks == 0, (
-                        f"Shared expert weight dim {total} not divisible by num_chunks {num_chunks}"
+                        f"Shared expert weight dim {total} "
+                        f"not divisible by num_chunks {num_chunks}"
                     )
                     chunk_size = total // num_chunks

                 for j in range(num_chunks):
                     chunk_name = name
                     weight_to_load = loaded_weight

-                    if is_fuse_shared_experts_layer:
+                    if is_fusion_moe_shared_experts_layer:
                         if split_dim == 0:
-                            weight_to_load = loaded_weight[j * chunk_size : (j + 1) * chunk_size, :]
+                            weight_to_load = loaded_weight[
+                                j * chunk_size : (j + 1) * chunk_size, :
+                            ]
                         else:
-                            weight_to_load = loaded_weight[:, j * chunk_size : (j + 1) * chunk_size]
+                            weight_to_load = loaded_weight[
+                                :, j * chunk_size : (j + 1) * chunk_size
+                            ]
+                        # Synthesize an expert-style name so expert mapping
+                        # can route it
                         chunk_name = name.replace(
                             "mlp.shared_experts",
                             f"mlp.experts.{self.config.n_routed_experts + j}",
                         )

+                    # Use expert_params_mapping to locate the destination
+                    # param and delegate to its expert-aware weight_loader
+                    # with expert_id.
                     for mapping in expert_params_mapping:
                         param_name, weight_name, expert_id, shard_id = mapping
                         if weight_name not in chunk_name:
                             continue

+                        # This is an expert weight and should not be loaded
+                        # as a regular weight later.
                         is_expert_weight = True
+
+                        # Do not modify `name` since the loop may continue here
+                        # Instead, create a new variable
                         name_mapped = chunk_name.replace(weight_name, param_name)

                         if is_pp_missing_parameter(name_mapped, self):
                             continue
-                        if name_mapped not in params_dict.keys():
-                            continue
+
                         param = params_dict[name_mapped]
-                        weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
+                        weight_loader = typing.cast(
+                            Callable[..., bool], param.weight_loader
+                        )
                         success = weight_loader(
                             param,
                             weight_to_load,
@@ -287,32 +352,39 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                             return_success=True,
                         )
                         if success:
-                            if not is_fuse_shared_experts_layer:
+                            if not is_fusion_moe_shared_experts_layer:
                                 name = name_mapped
                             else:
-                                loaded_params.add(name_mapped)
+                                loaded_params.add(name_mapped)
                             break
                     else:
                         if is_expert_weight:
+                            # We've checked that this is an expert weight
+                            # However it's not mapped locally to this rank
+                            # So we simply skip it
                             continue
+
+                        # Skip loading extra bias for GPTQ models.
                         if name.endswith(".bias") and name not in params_dict:
                             continue
+
+                        # Remapping the name of FP8 kv-scale.
                         name = maybe_remap_kv_scale_name(name, params_dict)
                         if name is None:
                             continue
+
                         if is_pp_missing_parameter(name, self):
                             continue
-                        if name not in params_dict.keys():
-                            continue

                         param = params_dict[name]
-                        weight_loader = getattr(param, "weight_loader",
-                                                default_weight_loader)
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
                         weight_loader(param, loaded_weight)
-                if not is_fuse_shared_experts_layer:
+                if not is_fusion_moe_shared_experts_layer:
                     loaded_params.add(name)
-        return loaded_params

+        return loaded_params

 vllm.model_executor.models.deepseek_v2.DeepseekV2MoE = AscendDeepseekV2MoE
 DeepseekV2ForCausalLM.load_weights = CustomDeepseekV2ForCausalLM.load_weights
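
For the mix_placement path, load_weights above slices a single widened shared_experts tensor into per-expert chunks and renames each chunk into an appended expert slot so the expert mapping can route it. A standalone sketch of just the slicing and renaming, with small made-up shapes (2 shared experts, intermediate size 4, hidden size 8) in place of real checkpoint tensors:

import torch

n_routed_experts, n_shared_experts = 64, 2
hidden, inter = 8, 4

# Widened shared-experts gate_proj: ColumnParallel, experts stacked on dim 0.
name = "model.layers.3.mlp.shared_experts.gate_proj.weight"
loaded_weight = torch.randn(n_shared_experts * inter, hidden)

split_dim = 1 if "down_proj.weight" in name else 0   # gate/up -> 0, down -> 1
chunk_size = loaded_weight.shape[split_dim] // n_shared_experts

for j in range(n_shared_experts):
    sl = slice(j * chunk_size, (j + 1) * chunk_size)
    chunk = loaded_weight[sl, :] if split_dim == 0 else loaded_weight[:, sl]
    chunk_name = name.replace(
        "mlp.shared_experts", f"mlp.experts.{n_routed_experts + j}"
    )
    # chunk_name is now "...mlp.experts.64.gate_proj.weight" (then .65, ...),
    # so expert_params_mapping treats it like any other routed-expert weight.
    assert chunk.shape == (inter, hidden)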