
Commit 866347a

zzhx1 and wangxiyuan authored
DeepSeek MTP model uses the lm_head and embedding from the main model (#2790)
### What this PR does / why we need it?

The DeepSeek technical report notes that the MTP layer's embedding and lm_head are shared with the main model, but the current implementation independently loads complete copies of both. In the DeepSeek-R1 model, each of these weights is 129280 × 7168 in fp16, i.e. about 1.72 GB. This PR makes the MTP layer reuse the main model's lm_head and embedding, saving about 3.45 GB of GPU memory in the pure DP scenario.

The loading process still first creates temporary space for the MTP layer's embedding and lm_head, then calls torch.equal to check whether each matrix is identical to the main model's. If it is, the main model's module is reused and the temporary tensor is released.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: zzhx1 <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
1 parent 9fbcfa3 commit 866347a
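
To make the compare-then-rebind mechanism concrete, here is a minimal standalone sketch of the pattern the description outlines. It is an illustration only: the function name, module handles, and shrunken dimensions are hypothetical, not the actual vllm-ascend code paths.

```python
import torch
import torch.nn as nn


def share_if_identical(draft: nn.Embedding, main: nn.Embedding) -> nn.Embedding:
    """Return `main` if `draft` holds a bitwise-identical copy of its weight.

    Rebinding the caller's reference to `main` drops the duplicate tensor,
    letting the allocator reclaim its memory; that is the entire saving.
    """
    if torch.equal(draft.weight, main.weight):
        return main
    return draft


# Shrunken dimensions for the example; the real DeepSeek-R1 matrices are
# 129280 x 7168 in fp16 (~1.72 GB each).
main_embed = nn.Embedding(1000, 64, dtype=torch.float16)
draft_embed = nn.Embedding(1000, 64, dtype=torch.float16)
draft_embed.weight.data.copy_(main_embed.weight.data)  # simulate identical checkpoints

draft_embed = share_if_identical(draft_embed, main_embed)
assert draft_embed.weight is main_embed.weight  # now one tensor, not two
```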

File tree

1 file changed: +10 -0

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -200,6 +200,16 @@ def load_model(self, model) -> None:
         process_weights_after_loading(self.model, draft_model_config,
                                       target_device)
 
+        # Check if the MTP model can use the main model's embedding and lm_head
+        main_model = model
+        if torch.equal(self.model.model.embed_tokens.weight,
+                       main_model.model.embed_tokens.weight):
+            self.model.model.embed_tokens = main_model.model.embed_tokens
+        for _, layer_module in self.model.model.layers.items():
+            if torch.equal(layer_module.shared_head.head.weight,
+                           main_model.lm_head.weight):
+                layer_module.shared_head.head = main_model.lm_head
+
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
         ):
             self.update_stream: torch.npu.Stream = torch.npu.Stream()
```
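
One way to sanity-check that load_model actually established the sharing is to assert that the draft modules alias the main model's objects. A hedged sketch using the attribute paths from the diff above; the `proposer` handle is an assumption for whatever object owns load_model:

```python
# `proposer` is assumed to be the object whose load_model(model) just ran;
# `model` is the main model passed in. Attribute paths follow the diff.
draft = proposer.model

# The embedding should now be the very same module object.
assert draft.model.embed_tokens is model.model.embed_tokens

# Each MTP layer's shared head should alias the main model's lm_head.
for _, layer in draft.model.layers.items():
    assert layer.shared_head.head is model.lm_head
```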
