Commit 2852995

move patch to patch_deepseek_mtp
Signed-off-by: MengqingCao <[email protected]>
1 parent bac4043

File tree

4 files changed, +40 -49 lines

vllm_ascend/patch/platform/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -19,7 +19,6 @@
 import vllm_ascend.patch.platform.patch_config  # noqa
 import vllm_ascend.patch.platform.patch_distributed  # noqa
 import vllm_ascend.patch.platform.patch_mamba_config  # noqa
-import vllm_ascend.patch.platform.patch_mtp_predictor  # noqa
 import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 
 if os.getenv("DYNAMIC_EPLB", "false") == "true" or os.getenv(

vllm_ascend/patch/platform/patch_mtp_predictor.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

vllm_ascend/patch/worker/patch_deepseek_mtp.py

Lines changed: 39 additions & 2 deletions
@@ -1,16 +1,52 @@
 import torch
 import torch.nn as nn
+import vllm
 from transformers import PretrainedConfig
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.models.deepseek_mtp import \
-    DeepSeekMultiTokenPredictorLayer
+from vllm.model_executor.models.deepseek_mtp import (
+    DeepSeekMTP, DeepSeekMultiTokenPredictorLayer)
 from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
 from vllm.model_executor.models.utils import maybe_prefix
 
 
+def forward(
+    self,
+    input_ids: torch.Tensor,
+    positions: torch.Tensor,
+    previous_hidden_states: torch.Tensor,
+    inputs_embeds: torch.Tensor | None = None,
+    spec_step_index: int = 0,
+) -> torch.Tensor:
+    assert inputs_embeds is not None
+    # masking inputs at position 0, as not needed by MTP
+    # Patched for aclgraph support, as the original operation
+    # introduced a d2h sync, which breaks aclgraph.
+    inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
+    inputs_embeds = self.enorm(inputs_embeds)
+    previous_hidden_states = self.hnorm(previous_hidden_states)
+
+    hidden_states = self.eh_proj(
+        torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+
+    hidden_states, residual = self.mtp_block(positions=positions,
+                                             hidden_states=hidden_states,
+                                             residual=None)
+    hidden_states = residual + hidden_states
+    return hidden_states
+
+
+# Patched only for aclgraph support, as this is not supported in vLLM 0.11.0.
+@support_torch_compile
+class AscendDeepSeekMTP(DeepSeekMTP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+
 class SharedHead(nn.Module):
 
     def __init__(
@@ -53,3 +89,4 @@ def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None:
 
 
 DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
+vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer.forward = forward
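Note on the masking change above: a minimal sketch with toy tensors (not from this repo) showing why the patched torch.where form is capture-friendly. The boolean-index assignment it replaces selects a data-dependent number of rows, which, per the comment in the diff, introduces a d2h sync that breaks aclgraph.

import torch

# Toy shapes: 4 tokens, hidden size 8. positions == 0 marks sequence starts.
inputs_embeds = torch.randn(4, 8)
positions = torch.tensor([0, 1, 2, 0])

# Upstream-style masking: boolean-index assignment. The number of selected
# rows is data-dependent, which introduces a d2h sync during graph capture.
masked = inputs_embeds.clone()
masked[positions == 0] = 0

# Patched masking: torch.where is a fixed-shape elementwise op, so it can be
# captured. The scalar 0 broadcasts against the [num_tokens, hidden] tensor.
patched = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)

assert torch.equal(masked, patched)  # same result, no data-dependent shapes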

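The final added line of this file applies the patch by rebinding the method on the class. A minimal sketch, with hypothetical names, of why a class-level assignment is enough: every instance, including ones constructed before the patch runs, resolves forward through the class.

class PredictorLayer:                      # stand-in for the upstream layer
    def forward(self, x: int) -> int:
        return x + 1

def patched_forward(self, x: int) -> int:  # replacement defined at module level
    return x + 2

layer = PredictorLayer()                   # built before the patch is applied
PredictorLayer.forward = patched_forward   # rebinding on the class
assert layer.forward(1) == 3               # the existing instance is patched too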
vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
-from vllm_ascend.patch.platform.patch_mtp_predictor import \
+from vllm_ascend.patch.worker.patch_deepseek_mtp import \
     AscendDeepSeekMTP as DeepSeekMTP
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
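A hedged sketch (all names hypothetical) of the pattern this import relies on: the upstream model class is not decorated in place; instead a trivial subclass carries the compile decorator, and the proposer imports that subclass under the base name, so its call sites stay unchanged.

def compile_support(cls):              # stand-in for support_torch_compile
    cls.supports_compile = True        # the real decorator wires up compilation
    return cls

class UpstreamMTP:                     # stand-in for vLLM's DeepSeekMTP
    def __init__(self, *, prefix: str = ""):
        self.prefix = prefix

@compile_support
class AscendMTP(UpstreamMTP):          # trivial subclass carrying the decorator
    def __init__(self, *, prefix: str = ""):
        super().__init__(prefix=prefix)

# Call sites alias the subclass under the base name, mirroring
# "import AscendDeepSeekMTP as DeepSeekMTP" above.
MTP = AscendMTP
assert MTP(prefix="mtp").supports_compile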
