import torch
import torch.nn as nn
+import vllm
from transformers import PretrainedConfig
+from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.models.deepseek_mtp import \
-    DeepSeekMultiTokenPredictorLayer
+from vllm.model_executor.models.deepseek_mtp import (
+    DeepSeekMTP, DeepSeekMultiTokenPredictorLayer)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
from vllm.model_executor.models.utils import maybe_prefix


+def forward(
+    self,
+    input_ids: torch.Tensor,
+    positions: torch.Tensor,
+    previous_hidden_states: torch.Tensor,
+    inputs_embeds: torch.Tensor | None = None,
+    spec_step_index: int = 0,
+) -> torch.Tensor:
+    assert inputs_embeds is not None
+    # Mask inputs at position 0, since MTP does not need them. Patched for
+    # aclgraph support: the original masking op introduced a device-to-host
+    # sync, which breaks aclgraph capture.
+    inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
+    inputs_embeds = self.enorm(inputs_embeds)
+    previous_hidden_states = self.hnorm(previous_hidden_states)
+
+    hidden_states = self.eh_proj(
+        torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+
+    hidden_states, residual = self.mtp_block(positions=positions,
+                                             hidden_states=hidden_states,
+                                             residual=None)
+    hidden_states = residual + hidden_states
+    return hidden_states
+
+
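For context on the masking change above: boolean-index assignment (roughly `inputs_embeds[positions == 0] = 0`) can route through nonzero()/index_put_ and force a device-to-host synchronization on some backends, while `torch.where` is purely element-wise and stays on-device. A minimal sketch of the two variants, assuming the upstream code used the indexed-assignment form; the toy shapes are illustrative only:

import torch

positions = torch.tensor([0, 1, 2, 0])
inputs_embeds = torch.ones(4, 8)

# Sync-prone variant: boolean-mask assignment may materialize indices
# via nonzero(), which can force a device-to-host sync.
masked = inputs_embeds.clone()
masked[positions == 0] = 0

# Graph-friendly variant used by the patch: no data-dependent indexing.
masked_where = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)

assert torch.equal(masked, masked_where)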
+# Patch this only for aclgraph support, as it is not supported in vLLM 0.11.0.
+@support_torch_compile
+class AscendDeepSeekMTP(DeepSeekMTP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+
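How this subclass gets picked up is outside the diff; as a hedged sketch, a vLLM plugin would typically point the model architecture at the Ascend variant through vLLM's `ModelRegistry`. The architecture name and module path below are assumptions, not taken from this patch:

from vllm import ModelRegistry

# Assumed registration: the architecture string and module path are
# illustrative, not confirmed by this diff.
ModelRegistry.register_model(
    "DeepSeekMTPModel",
    "vllm_ascend.models.deepseek_mtp:AscendDeepSeekMTP",
)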
class SharedHead(nn.Module):

    def __init__(
@@ -53,3 +89,4 @@ def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None: |


DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
+vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer.forward = forward
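The two assignments above rebind methods directly on the upstream class object, so every instance (existing or future) resolves to the patched implementations; importing the class by name or through the full module path reaches the same object. A self-contained toy illustrating the pattern, with hypothetical names:

class Layer:
    def forward(self, x):
        return x

def patched_forward(self, x):  # hypothetical replacement
    return x + 1

# Rebinding on the class affects all instances, even pre-existing ones.
layer = Layer()
Layer.forward = patched_forward
assert layer.forward(1) == 2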