Commit 3e0187a

Refactor DeepseekV2 to improve weight loading
Signed-off-by: Mercykid-bash <[email protected]>
1 parent 7555a6c commit 3e0187a

File tree

1 file changed

vllm_ascend/patch/worker/patch_deepseekv3.py

Lines changed: 113 additions & 41 deletions
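A recurring change in the hunks below replaces direct config attribute reads (config.routed_scaling_factor, config.topk_method, config.n_group, config.topk_group, config.scoring_func) with getattr(...) fallbacks, presumably so that DeepSeek-style checkpoints whose config omits these optional MoE fields still construct. A minimal sketch of the pattern, using a hypothetical trimmed-down config object rather than the real DeepseekV2 config class:

from types import SimpleNamespace

# Hypothetical config that ships only some of the optional MoE fields.
config = SimpleNamespace(n_routed_experts=64, n_shared_experts=2)

# Fall back to a safe default when an optional field is absent,
# mirroring the getattr() calls introduced in the diff.
routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
scoring_func = getattr(config, "scoring_func", "softmax")
uses_noaux_tc = getattr(config, "topk_method", None) == "noaux_tc"

print(routed_scaling_factor, scoring_func, uses_noaux_tc)  # 1.0 softmax False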
@@ -16,8 +16,8 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from collections.abc import Callable, Iterable
+# from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
-# from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, get_spec_layer_idx_from_weight_name, DeepseekV2MLP, DeepseekV2MoE
 from vllm.model_executor.models.utils import is_pp_missing_parameter
 from vllm.config import ParallelConfig
@@ -37,14 +37,14 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ):
-        nn.Module.__init__(self)
+        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
 
-        self.routed_scaling_factor = config.routed_scaling_factor
+        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
 
         self.ep_group = get_ep_group().device_group
-        self.ep_rank = self.ep_group.rank()
+        self.ep_rank = get_ep_group().rank_in_group
         self.ep_size = self.ep_group.size()
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts
@@ -64,7 +64,7 @@ def __init__(
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
-        if config.topk_method == "noaux_tc":
+        if getattr(config, "topk_method", None) == "noaux_tc":
             self.gate.e_score_correction_bias = nn.Parameter(
                 torch.empty(config.n_routed_experts, dtype=torch.float32)
             )
@@ -115,10 +115,10 @@ def __init__(
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
+            num_expert_group=getattr(config, "n_group", 1),
+            topk_group=getattr(config, "topk_group", 1),
             prefix=f"{prefix}.experts",
-            scoring_func=config.scoring_func,
+            scoring_func=getattr(config, "scoring_func", "softmax"),
             # we do scaling outside, set factor to 1.0 to avoid double mul
             # aiter applies routed_scaling_factor internally
             routed_scaling_factor=1.0
@@ -130,34 +130,42 @@ def __init__(
             is_sequence_parallel=self.is_sequence_parallel,
         )
 
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
+
         # Chunk the hidden states so they aren't replicated across TP ranks.
         # This avoids duplicate computation in self.experts.
         # TODO: We can replace the all_reduce at the end of attn with a
         # reduce_scatter instead of chunking here.
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)
 
-        router_logits, _ = self.gate(hidden_states)
-        fused_moe_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
-        ascend_config = get_ascend_config()
+        if self.experts.is_internal_router:
+            # In this case, the gate/router runs inside the FusedMoE class
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
+
         shared_output, final_hidden_states = fused_moe_out
         if self.shared_experts is None:
             assert shared_output is None
-
+
         # Fix FP16 overflow
         # See DeepseekV2DecoderLayer for more details.
         if hidden_states.dtype != torch.float16:
-            final_hidden_states *= self.routed_scaling_factor
+            if not self.is_rocm_aiter_moe_enabled:
+                final_hidden_states *= self.routed_scaling_factor
         elif self.shared_experts is not None:
             assert shared_output is not None
             shared_output *= 1.0 / self.routed_scaling_factor
-
+
         if self.shared_experts is not None:
             assert shared_output is not None
             final_hidden_states += shared_output
@@ -171,20 +179,35 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                 final_hidden_states
             )
+
         return final_hidden_states.view(num_tokens, hidden_dim)
 
 class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         ascend_config = get_ascend_config()
         stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
+        ]
+        mla_params_mapping = [
             ("fused_qkv_a_proj", "q_a_proj", 0),
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
+        mha_params_mapping = [
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        if self.use_mha:
+            stacked_params_mapping.extend(mha_params_mapping)
+        else:
+            stacked_params_mapping.extend(mla_params_mapping)
 
         mix_placement = getattr(ascend_config, "mix_placement", False)
 
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
 
         expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
@@ -207,77 +230,119 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:
-                continue
+                continue  # skip spec decode layers for main model
 
-            is_fuse_shared_experts_layer = (
-                mix_placement
-                and ("mlp.shared_experts" in name)
+            is_fusion_moe_shared_experts_layer = (
+                mix_placement and ("mlp.shared_experts" in name)
             )
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
-                if is_fuse_shared_experts_layer:
+                if is_fusion_moe_shared_experts_layer:
                     continue
                 name_mapped = name.replace(weight_name, param_name)
 
-                if (param_name == "fused_qkv_a_proj") and name_mapped not in params_dict:
+                # QKV fusion is optional; fall back to normal
+                # weight loading if it's not enabled.
+                # If the fusion option is used, update the name.
+                if (
+                    param_name == "fused_qkv_a_proj"
+                ) and name_mapped not in params_dict:
                     continue
                 else:
                     name = name_mapped
+                # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+
                 if is_pp_missing_parameter(name, self):
                     continue
-                if name not in params_dict.keys():
-                    continue
 
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 is_expert_weight = False
+
+                # Special handling: when AITER fusion_shared_experts is enabled,
+                # checkpoints may provide a single widened shared_experts tensor
+                # without explicit expert indices
+                # (e.g. ...mlp.shared_experts.gate_proj.weight).
+                # For models with multiple shared experts, split that tensor
+                # evenly into per-shared-expert slices and load them into
+                # appended expert slots mlp.experts.{n_routed_experts + j}.*
+                # accordingly.
                 num_chunks = 1
-                if is_fuse_shared_experts_layer:
+                if is_fusion_moe_shared_experts_layer:
                     num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+                    # Determine split axis based on op type
+                    # gate/up: ColumnParallel → split along dim 0
+                    # down: RowParallel → split along dim 1
                     split_dim = 1 if "down_proj.weight" in name else 0
                     total = loaded_weight.shape[split_dim]
                     assert total % num_chunks == 0, (
-                        f"Shared expert weight dim {total} not divisible by num_chunks {num_chunks}"
+                        f"Shared expert weight dim {total} "
+                        f"not divisible by num_chunks {num_chunks}"
                     )
                     chunk_size = total // num_chunks
 
                 for j in range(num_chunks):
                     chunk_name = name
                     weight_to_load = loaded_weight
 
-                    if is_fuse_shared_experts_layer:
+                    if is_fusion_moe_shared_experts_layer:
                         if split_dim == 0:
-                            weight_to_load = loaded_weight[j * chunk_size : (j + 1) * chunk_size, :]
+                            weight_to_load = loaded_weight[
+                                j * chunk_size : (j + 1) * chunk_size, :
+                            ]
                         else:
-                            weight_to_load = loaded_weight[:, j * chunk_size : (j + 1) * chunk_size]
+                            weight_to_load = loaded_weight[
+                                :, j * chunk_size : (j + 1) * chunk_size
+                            ]
+                        # Synthesize an expert-style name so expert mapping
+                        # can route it
                         chunk_name = name.replace(
                             "mlp.shared_experts",
                             f"mlp.experts.{self.config.n_routed_experts + j}",
                         )
 
+                    # Use expert_params_mapping to locate the destination
+                    # param and delegate to its expert-aware weight_loader
+                    # with expert_id.
                     for mapping in expert_params_mapping:
                         param_name, weight_name, expert_id, shard_id = mapping
                         if weight_name not in chunk_name:
                             continue
 
+                        # Anyway, this is an expert weight and should not be
+                        # attempted to load as other weights later
                        is_expert_weight = True
+
+                        # Do not modify `name` since the loop may continue here
+                        # Instead, create a new variable
                         name_mapped = chunk_name.replace(weight_name, param_name)
 
                         if is_pp_missing_parameter(name_mapped, self):
                             continue
-                        if name_mapped not in params_dict.keys():
-                            continue
+
                         param = params_dict[name_mapped]
-                        weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
+                        weight_loader = typing.cast(
+                            Callable[..., bool], param.weight_loader
+                        )
                         success = weight_loader(
                             param,
                             weight_to_load,
@@ -287,32 +352,39 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                             return_success=True,
                         )
                         if success:
-                            if not is_fuse_shared_experts_layer:
+                            if not is_fusion_moe_shared_experts_layer:
                                 name = name_mapped
                             else:
-                                loaded_params.add(name_mapped)
+                                loaded_params.add(name_mapped)
                             break
                     else:
                         if is_expert_weight:
+                            # We've checked that this is an expert weight
+                            # However it's not mapped locally to this rank
+                            # So we simply skip it
                             continue
+
+                        # Skip loading extra bias for GPTQ models.
                         if name.endswith(".bias") and name not in params_dict:
                             continue
+
+                        # Remapping the name of FP8 kv-scale.
                         name = maybe_remap_kv_scale_name(name, params_dict)
                         if name is None:
                             continue
+
                         if is_pp_missing_parameter(name, self):
                             continue
-                        if name not in params_dict.keys():
-                            continue
 
                         param = params_dict[name]
-                        weight_loader = getattr(param, "weight_loader",
-                                                default_weight_loader)
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
                         weight_loader(param, loaded_weight)
-                        if not is_fuse_shared_experts_layer:
+                        if not is_fusion_moe_shared_experts_layer:
                             loaded_params.add(name)
-        return loaded_params
 
+        return loaded_params
 
 vllm.model_executor.models.deepseek_v2.DeepseekV2MoE = AscendDeepseekV2MoE
 DeepseekV2ForCausalLM.load_weights = CustomDeepseekV2ForCausalLM.load_weights
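The comment block added in load_weights above describes the mix_placement path: a single widened mlp.shared_experts.* tensor is split evenly into per-shared-expert slices and loaded into appended expert slots mlp.experts.{n_routed_experts + j}.*. A standalone sketch of just that slicing and renaming, with made-up shapes and a hypothetical weight name (2 shared experts, 64 routed experts):

import torch

n_routed_experts, n_shared = 64, 2
name = "model.layers.3.mlp.shared_experts.gate_proj.weight"  # hypothetical checkpoint key
loaded_weight = torch.randn(n_shared * 1408, 2048)  # widened: shared experts stacked on dim 0

# gate/up projections are column-parallel, so split along dim 0;
# down_proj is row-parallel and would split along dim 1.
split_dim = 1 if "down_proj.weight" in name else 0
chunk_size = loaded_weight.shape[split_dim] // n_shared

for j in range(n_shared):
    chunk = loaded_weight.narrow(split_dim, j * chunk_size, chunk_size)
    # Rename so the expert-aware loader routes the slice to an appended expert slot.
    chunk_name = name.replace("mlp.shared_experts", f"mlp.experts.{n_routed_experts + j}")
    print(chunk_name, tuple(chunk.shape))

The last two context lines of the diff show how the patch takes effect: at import time the module monkey-patches DeepseekV2MoE and DeepseekV2ForCausalLM.load_weights on the upstream vLLM classes, so no change to vLLM itself is required.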
