Commit 3642b64

bugfix for mtp with multistream_moe (#3419)
### What this PR does / why we need it?
When running inference through the DeepSeek MTP layer with multistream_moe enabled, we need to pass a boolean that gates this feature, and fix the bugs that surface when executing the MTP layer.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: zouyida2052 <[email protected]>
1 parent c2c1db7 commit 3642b64
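
As a usage illustration, here is a minimal offline-inference sketch of the setup this fix targets. Only the `additional_config` contents mirror the e2e test updated below; the model name, `speculative_config` values, tensor-parallel size, and prompt are placeholder assumptions, not taken from this commit.

```python
# Hedged sketch: enabling MTP speculative decoding together with the
# multistream shared-expert overlap on vLLM Ascend. The additional_config keys
# mirror the e2e test in this commit; everything else is a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",   # placeholder model path
    tensor_parallel_size=4,            # example value
    speculative_config={               # assumed MTP spec-decode settings
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
    },
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "use_cached_graph": False,
            "graph_batch_sizes": [1, 2, 4],
        },
        # the feature whose MTP-layer path this commit fixes
        "multistream_overlap_shared_expert": "True",
    },
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```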

File tree

5 files changed: +22 -11 lines changed

tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py

Lines changed: 3 additions & 1 deletion
@@ -41,6 +41,7 @@ def test_mtp_torchair_correctness(
                 "use_cached_graph": False,
                 "graph_batch_sizes": [1, 2, 4],
             },
+            "multistream_overlap_shared_expert": "True"
         }) as ref_llm:
         ref_outputs = ref_llm.generate(example_prompts, sampling_config)
     with VllmRunner(model_name,
@@ -60,7 +61,8 @@ def test_mtp_torchair_correctness(
                 "enabled": True,
                 "use_cached_graph": False,
                 "graph_batch_sizes": [1, 2, 4],
-            }
+            },
+            "multistream_overlap_shared_expert": "True"
         }) as spec_llm:
         spec_outputs = spec_llm.generate(example_prompts, sampling_config)

tests/ut/torchair/models/test_torchair_deepseek_mtp.py

Lines changed: 6 additions & 1 deletion
@@ -17,6 +17,9 @@ def setup_mtp_layer(self, mocker: MockerFixture):
         config = PretrainedConfig(vocab_size=1000,
                                   hidden_size=768,
                                   rms_norm_eps=1e-5)
+        mocker.patch(
+            'vllm_ascend.torchair.models.torchair_deepseek_mtp.get_tensor_model_parallel_world_size',
+            return_value=1)
         mocker.patch(
             "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
             return_value=None)
@@ -56,6 +59,8 @@ def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
         mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
         mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
                                             torch.randn(2, 3, 768))
+        mtp_layer.enorm.return_value = torch.randn(2, 3, 768)
+        mtp_layer.hnorm.return_value = torch.randn(2, 3, 768)

         input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
         positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
@@ -65,7 +70,7 @@ def test_forward(self, mocker: MockerFixture, setup_mtp_layer):

         output = mtp_layer(input_ids, positions, kv_cache, None,
                            previous_hidden_states, inputs_embeds, 0)
-        assert output.shape == (2, 3, 768)
+        assert output.shape == (3, 768)


 class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):

vllm_ascend/attention/utils.py

Lines changed: 0 additions & 2 deletions
@@ -103,8 +103,6 @@ def split_decodes_and_prefills(
         return num_reqs, 0, num_tokens, 0

     first_prefill = is_prefill.int().argmax(dim=-1).item()
-    assert torch.all(query_lens[first_prefill:] > decode_threshold)
-    assert torch.all(query_lens[:first_prefill] <= decode_threshold)
     num_decodes = first_prefill
     num_prefills = num_reqs - num_decodes
     num_decode_tokens = query_start_loc[first_prefill].item()
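
For context, a minimal standalone sketch of the decode/prefill split this hunk touches. The signature and example values are assumptions; only the body follows the lines visible above. Requests are ordered decodes-first, and the first index whose query length exceeds `decode_threshold` marks where prefills begin, which is the invariant the removed asserts used to check.

```python
import torch

# Standalone sketch of the decode/prefill split; the signature is an assumption,
# only the body mirrors the lines visible in this hunk.
def split_decodes_and_prefills(query_lens: torch.Tensor,
                               query_start_loc: torch.Tensor,
                               decode_threshold: int = 1):
    num_reqs = query_lens.numel()
    num_tokens = int(query_lens.sum().item())
    is_prefill = query_lens > decode_threshold
    if not bool(is_prefill.any()):
        # every request is a decode
        return num_reqs, 0, num_tokens, 0
    # requests are assumed to be ordered decodes-first, prefills-last
    first_prefill = is_prefill.int().argmax(dim=-1).item()
    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = query_start_loc[first_prefill].item()
    num_prefill_tokens = num_tokens - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# query_start_loc is the cumulative token offset per request: [0, 1, 2, 6, 11]
print(split_decodes_and_prefills(torch.tensor([1, 1, 4, 5]),
                                 torch.tensor([0, 1, 2, 6, 11])))
# -> (2, 2, 2, 9)
```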

vllm_ascend/torchair/models/torchair_deepseek_mtp.py

Lines changed: 11 additions & 5 deletions
@@ -24,6 +24,7 @@
 from transformers import PretrainedConfig
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -66,6 +67,7 @@ def __init__(
     ) -> None:
         nn.Module.__init__(self)

+        self.tp_size = get_tensor_model_parallel_world_size()
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -100,11 +102,15 @@ def forward(
         hidden_states = self.eh_proj(
             torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

-        hidden_states, residual = self.mtp_block(positions=positions,
-                                                 hidden_states=hidden_states,
-                                                 kv_cache=kv_cache,
-                                                 attn_metadata=attn_metadata,
-                                                 residual=None)
+        replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
+
+        hidden_states, residual = self.mtp_block(
+            positions=positions,
+            hidden_states=hidden_states,
+            residual=None,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+            replace_allreduce=replace_allreduce)
         hidden_states = residual + hidden_states
         return hidden_states

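
The functional core of the fix is the new `replace_allreduce` flag: the multistream shared-expert path may substitute the tensor-parallel all-reduce only when the token count divides evenly across TP ranks. A trivial illustration with made-up numbers and a helper name of my own:

```python
# Illustration only: mirrors the gate added above.
def can_replace_allreduce(num_tokens: int, tp_size: int) -> bool:
    # presumably the overlapped path needs an even token split per rank
    return num_tokens % tp_size == 0

print(can_replace_allreduce(8, 4))  # True  -> multistream path may kick in
print(can_replace_allreduce(7, 4))  # False -> keep the regular all-reduce
```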

vllm_ascend/torchair/models/torchair_deepseek_v2.py

Lines changed: 2 additions & 2 deletions
@@ -975,7 +975,7 @@ def forward(
         # to save npu memory because they're no longer used.
         dispose_tensor(previous_hidden_states)
         dispose_tensor(previous_residual)
-        if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
+        if mla_moe_communication and self.layer_idx > self.first_k_dense_replace and self.layer_idx < self.layers:
             hidden_states = tensor_model_parallel_all_gather(hidden_states,
                                                              dim=0)

@@ -1034,7 +1034,7 @@ def forward(
         # The scaling of DeepseekV2MOE output would be done in the forward
         # of DeepseekV2MOE
         hidden_states *= 1. / self.routed_scaling_factor
-        if mla_moe_communication and self.layer_idx == self.layers - 1:
+        if mla_moe_communication and self.layer_idx >= self.layers - 1:
             hidden_states = tensor_model_parallel_all_gather(hidden_states,
                                                              dim=0)
             residual = tensor_model_parallel_all_gather(residual, dim=0)
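
To see what the two boundary tweaks change, here is a hypothetical walk-through. It assumes four decoder layers, `first_k_dense_replace = 1`, and an MTP layer carrying `layer_idx == num_layers`; that last point is my reading of why the old `==` check missed it, not something stated in this commit.

```python
# Hypothetical walk-through of the two conditions before and after this change.
num_layers, first_k_dense_replace = 4, 1

for layer_idx in range(num_layers + 1):  # the last index stands in for the MTP layer
    per_layer_gather = (layer_idx > first_k_dense_replace
                        and layer_idx < num_layers)    # new upper bound
    final_gather_old = layer_idx == num_layers - 1
    final_gather_new = layer_idx >= num_layers - 1     # now also covers the MTP layer
    print(layer_idx, per_layer_gather, final_gather_old, final_gather_new)
```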
