Skip to content

Commit f7cf78d

Browse files
authored
[0.11.0][MTP][Aclgraph] Fix the support aclgraph with MTP (vllm-project#3912)
### What this PR does / why we need it? Fix 2 breaks of aclgraph with MTP: 1. deepseekmtp in vllm 0.11.0 does not support aclgraph and lack the `support_torch_compile` decorator 2. There is a d2h synchornization in the original forward of mtp predictor. The fix pr in vllm vllm-project/vllm#27643 As we'll fix it in vllm main, this fix pr is only needed in branch v0.11.0-dev The profling shows that MTP replays in aclgraph now: <img width="1612" height="1866" alt="a7d7f04155df4ed454b7eb20a92b2e2a" src="https://github.com/user-attachments/assets/eaa4b9ff-aeb0-416d-964f-5a06e497f155" /> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Signed-off-by: MengqingCao <[email protected]>
1 parent 4169433 commit f7cf78d

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

vllm_ascend/patch/worker/patch_deepseek_mtp.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,54 @@
1+
from typing import Optional
2+
13
import torch
24
import torch.nn as nn
5+
import vllm
36
from transformers import PretrainedConfig
7+
from vllm.compilation.decorators import support_torch_compile
48
from vllm.config import VllmConfig
59
from vllm.model_executor.layers.layernorm import RMSNorm
610
from vllm.model_executor.layers.quantization import QuantizationConfig
711
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
8-
from vllm.model_executor.models.deepseek_mtp import \
9-
DeepSeekMultiTokenPredictorLayer
12+
from vllm.model_executor.models.deepseek_mtp import (
13+
DeepSeekMTP, DeepSeekMultiTokenPredictorLayer)
1014
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
1115
from vllm.model_executor.models.utils import maybe_prefix
1216

1317

18+
def forward(
19+
self,
20+
input_ids: torch.Tensor,
21+
positions: torch.Tensor,
22+
previous_hidden_states: torch.Tensor,
23+
inputs_embeds: Optional[torch.Tensor] = None,
24+
spec_step_index: int = 0,
25+
) -> torch.Tensor:
26+
assert inputs_embeds is not None
27+
# masking inputs at position 0, as not needed by MTP
28+
# Patch this for aclgraph support, as the original operation introduced d2h sync,
29+
# which breaks aclgraph
30+
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
31+
inputs_embeds = self.enorm(inputs_embeds)
32+
previous_hidden_states = self.hnorm(previous_hidden_states)
33+
34+
hidden_states = self.eh_proj(
35+
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
36+
37+
hidden_states, residual = self.mtp_block(positions=positions,
38+
hidden_states=hidden_states,
39+
residual=None)
40+
hidden_states = residual + hidden_states
41+
return hidden_states
42+
43+
44+
# Patch this only for aclgraph support, as this is not support in vLLM 0.11.0
45+
@support_torch_compile
46+
class AscendDeepSeekMTP(DeepSeekMTP):
47+
48+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
49+
super().__init__(vllm_config=vllm_config, prefix=prefix)
50+
51+
1452
class SharedHead(nn.Module):
1553

1654
def __init__(
@@ -53,3 +91,4 @@ def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None:
5391

5492

5593
DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
94+
vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer.forward = forward

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
from vllm.model_executor.model_loader import get_model_loader
1212
from vllm.model_executor.model_loader.utils import (
1313
process_weights_after_loading, set_default_torch_dtype)
14-
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
1514
from vllm.v1.core.sched.output import SchedulerOutput
1615
from vllm.v1.sample.metadata import SamplingMetadata
1716
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
1817

1918
from vllm_ascend.ascend_config import get_ascend_config
2019
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
2120
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
21+
from vllm_ascend.patch.worker.patch_deepseek_mtp import \
22+
AscendDeepSeekMTP as DeepSeekMTP
2223
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
2324
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
2425
TorchairDeepSeekMTP

0 commit comments

Comments
 (0)