Skip to content

Commit 7702e6b

Browse files
committed
[0.11.0][MTP][Aclgraph] Fix the support aclgraph with MTP
Signed-off-by: MengqingCao <[email protected]>
1 parent c506ba6 commit 7702e6b

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

vllm_ascend/patch/platform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import vllm_ascend.patch.platform.patch_distributed # noqa
2121
import vllm_ascend.patch.platform.patch_mamba_config # noqa
2222
import vllm_ascend.patch.platform.patch_sched_yield # noqa
23+
import vllm_ascend.patch.platform.patch_mtp_predictor # noqa
2324

2425
if os.getenv("DYNAMIC_EPLB", "false") == "true" or os.getenv(
2526
"EXPERT_MAP_RECORD", "false") == "true":
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import torch
5+
6+
import vllm
7+
8+
from vllm.config import VllmConfig
9+
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
10+
from vllm.compilation.decorators import support_torch_compile
11+
12+
13+
def forward(
14+
self,
15+
input_ids: torch.Tensor,
16+
positions: torch.Tensor,
17+
previous_hidden_states: torch.Tensor,
18+
inputs_embeds: torch.Tensor | None = None,
19+
spec_step_index: int = 0,
20+
) -> torch.Tensor:
21+
assert inputs_embeds is not None
22+
# masking inputs at position 0, as not needed by MTP
23+
# Patch this for aclgraph support, as the original operation introduced d2h sync,
24+
# which breaks aclgraph
25+
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
26+
inputs_embeds = self.enorm(inputs_embeds)
27+
previous_hidden_states = self.hnorm(previous_hidden_states)
28+
29+
hidden_states = self.eh_proj(
30+
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
31+
)
32+
33+
hidden_states, residual = self.mtp_block(
34+
positions=positions, hidden_states=hidden_states, residual=None
35+
)
36+
hidden_states = residual + hidden_states
37+
return hidden_states
38+
39+
# Patch this only for aclgraph support, as this is not supported in vLLM 0.11.0
@support_torch_compile
class AscendDeepSeekMTP(DeepSeekMTP):
    """DeepSeekMTP wrapped with ``@support_torch_compile``.

    Behavior is inherited unchanged from vLLM's ``DeepSeekMTP``; the only
    difference is the decorator, which registers the model for torch.compile
    so it can be captured for aclgraph on Ascend.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Pure delegation — no Ascend-specific construction logic.
        super().__init__(vllm_config=vllm_config, prefix=prefix)
45+
46+
47+
# Monkey-patch vLLM at import time: replace the predictor layer's forward
# with the aclgraph-safe version defined above, and swap the DeepSeekMTP
# class for the torch.compile-enabled subclass.
# NOTE(review): attribute patching only affects lookups through the module;
# presumably vLLM resolves the model class via this module path — confirm
# no caller bound the original DeepSeekMTP name before this patch ran.
vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer.forward = forward
vllm.model_executor.models.deepseek_mtp.DeepSeekMTP = AscendDeepSeekMTP

0 commit comments

Comments
 (0)