Commit bac4043

fix
Signed-off-by: MengqingCao <[email protected]>
1 parent 7702e6b commit bac4043

File tree

3 files changed, +9 -11 lines changed

vllm_ascend/patch/platform/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -19,8 +19,8 @@
 import vllm_ascend.patch.platform.patch_config  # noqa
 import vllm_ascend.patch.platform.patch_distributed  # noqa
 import vllm_ascend.patch.platform.patch_mamba_config  # noqa
-import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 import vllm_ascend.patch.platform.patch_mtp_predictor  # noqa
+import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 
 if os.getenv("DYNAMIC_EPLB", "false") == "true" or os.getenv(
         "EXPERT_MAP_RECORD", "false") == "true":

vllm_ascend/patch/platform/patch_mtp_predictor.py

Lines changed: 6 additions & 9 deletions

@@ -2,12 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import torch
-
 import vllm
-
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
-from vllm.compilation.decorators import support_torch_compile
 
 
 def forward(
@@ -27,15 +25,15 @@ def forward(
     previous_hidden_states = self.hnorm(previous_hidden_states)
 
     hidden_states = self.eh_proj(
-        torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
-    )
+        torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
 
-    hidden_states, residual = self.mtp_block(
-        positions=positions, hidden_states=hidden_states, residual=None
-    )
+    hidden_states, residual = self.mtp_block(positions=positions,
+                                             hidden_states=hidden_states,
+                                             residual=None)
     hidden_states = residual + hidden_states
     return hidden_states
 
+
 # Patch this only for aclgraph support, as this is not support in vLLM 0.11.0
 @support_torch_compile
 class AscendDeepSeekMTP(DeepSeekMTP):
@@ -45,4 +43,3 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
 
 vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer.forward = forward
-vllm.model_executor.models.deepseek_mtp.DeepSeekMTP = AscendDeepSeekMTP
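
With the removed last line, the patch module still replaces DeepSeekMultiTokenPredictorLayer.forward at import time but no longer rebinds vllm.model_executor.models.deepseek_mtp.DeepSeekMTP; callers that want the @support_torch_compile-wrapped class now import AscendDeepSeekMTP explicitly (see mtp_proposer.py below). A rough, self-contained sketch of the decorate-a-subclass pattern; compile_friendly is a placeholder for a decorator such as support_torch_compile, whose real behaviour is not shown in this commit.

from typing import TypeVar

T = TypeVar("T", bound=type)


def compile_friendly(cls: T) -> T:
    # Placeholder for a class decorator like vLLM's support_torch_compile:
    # it marks the class so a runner knows it may be captured/compiled.
    cls._compile_enabled = True
    return cls


class UpstreamMTP:
    """Stand-in for the upstream DeepSeekMTP model class."""

    def forward(self, x: int) -> int:
        return x * 2


@compile_friendly
class AscendUpstreamMTP(UpstreamMTP):
    """Same behaviour, but carries the compile marker for the Ascend backend."""


assert AscendUpstreamMTP._compile_enabled
assert AscendUpstreamMTP().forward(3) == 6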

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 2 additions & 1 deletion

@@ -11,14 +11,15 @@
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import (
     process_weights_after_loading, set_default_torch_dtype)
-from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.patch.platform.patch_mtp_predictor import \
+    AscendDeepSeekMTP as DeepSeekMTP
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
     TorchairDeepSeekMTP
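
The aliased import added above is what lets the rest of mtp_proposer.py stay untouched: every existing reference to the name DeepSeekMTP now resolves to the Ascend subclass rather than the upstream class. A self-contained illustration of that aliasing trick; the class names are stand-ins, not the real vLLM classes.

class GenericMTP:
    """Stand-in for the upstream model class originally imported here."""

    backend = "generic"


class AscendMTP(GenericMTP):
    """Stand-in for the Ascend-specific subclass defined in the patch module."""

    backend = "ascend"


# Equivalent of "from ... import AscendDeepSeekMTP as DeepSeekMTP": the local
# name that unchanged code refers to now points at the platform subclass.
DeepSeekMTP = AscendMTP

assert DeepSeekMTP().backend == "ascend"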
