Skip to content

Commit fd00381

Browse files
update v0.11.2
Signed-off-by: chenmenglong <[email protected]>
1 parent e8160da commit fd00381

File tree

1 file changed

+51
-38
lines changed

1 file changed

+51
-38
lines changed

tests/ut/spec_decode/test_mtp_proposer.py

Lines changed: 51 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,9 @@
33
import numpy as np
44
import pytest
55
import torch
6-
from vllm.config import (CacheConfig, CompilationConfig, ModelConfig,
7-
SchedulerConfig, SpeculativeConfig, VllmConfig)
6+
from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode,
7+
ModelConfig, SchedulerConfig, SpeculativeConfig,
8+
VllmConfig)
89
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
910
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
1011
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -134,14 +135,19 @@ def get_layers_side_effect(vllm_config, cache_cls):
134135
with pytest.raises(AssertionError):
135136
proposer.load_model(mock_model)
136137

138+
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
137139
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
138-
def test_dummy_run(self, mock_set_context, vllm_config, runner):
140+
def test_dummy_run(self, mock_set_context, mock_get_forward_context,
141+
vllm_config, runner):
139142
# Setup
140143
proposer = MtpProposer(vllm_config, "npu", runner)
141144
proposer.model = MagicMock()
142145
runner._sync_metadata_across_dp.return_value = (8, 8, False)
143146
runner._select_moe_comm_method.return_value = "alltoall"
144147

148+
mock_get_forward_context = MagicMock()
149+
mock_get_forward_context.cudagraph_runtime_mode = None
150+
mock_get_forward_context.capturing = True
145151
# Execute
146152
proposer.dummy_run(8)
147153

@@ -153,6 +159,34 @@ def test_dummy_run(self, mock_set_context, vllm_config, runner):
153159
# Check that model was called correct number of times
154160
assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
155161

162+
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
163+
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
164+
def test_dummy_run_full_graph(self, mock_set_context,
165+
mock_get_forward_context, vllm_config,
166+
runner):
167+
# Setup
168+
proposer = MtpProposer(vllm_config, "npu", runner)
169+
proposer.model = MagicMock()
170+
runner._sync_metadata_across_dp.return_value = (8, 8, False)
171+
runner._select_moe_comm_method.return_value = "alltoall"
172+
runner.attn_groups = []
173+
174+
mock_get_forward_context = MagicMock()
175+
mock_get_forward_context.cudagraph_runtime_mode = None
176+
mock_get_forward_context.capturing = True
177+
# Execute
178+
proposer.dummy_run(num_tokens=8,
179+
num_reqs=5,
180+
aclgraph_runtime_mode=CUDAGraphMode.FULL)
181+
182+
# Verify
183+
runner._sync_metadata_across_dp.assert_called_once()
184+
runner._select_moe_comm_method.assert_called_once()
185+
mock_set_context.assert_called()
186+
187+
# Check that model was called correct number of times
188+
assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
189+
156190
def test_generate_token_ids(self):
157191
mock_deps = MagicMock()
158192
mock_deps.scheduler_output = MagicMock(spec=SchedulerOutput)
@@ -208,50 +242,29 @@ def test_generate_token_ids(self):
208242
assert torch.equal(draft_token_ids, proposer._propose.return_value)
209243

210244
def test_prepare_next_token_ids_cpu(self):
211-
mock_requests = {}
212-
req1 = MagicMock(spec=CachedRequestState)
213-
req1.num_computed_tokens = 5
214-
req1.get_token_id = MagicMock(return_value=1000)
215-
mock_requests["req_001"] = req1
216-
req2 = MagicMock(spec=CachedRequestState)
217-
req2.num_computed_tokens = 8
218-
req2.get_token_id = MagicMock(return_value=2000)
219-
mock_requests["req_002"] = req2
220-
req3 = MagicMock(spec=CachedRequestState)
221-
req3.num_computed_tokens = 3
222-
req3.get_token_id = MagicMock(return_value=3000)
223-
mock_requests["req_003"] = req3
224-
225-
mock_gpu_input_batch = MagicMock(spec=InputBatch)
226-
mock_gpu_input_batch.req_ids = ["req_001", "req_002", "req_003"]
227-
228-
mock_num_scheduled_tokens = {"req_001": 2, "req_002": 0, "req_003": 1}
229-
230245
sampled_token_ids = [
231-
[101, 102, 103], # req001: return 103
232-
[], # req002: return fallback 2000
233-
[301] # req003: return 301
246+
np.array([10, 20, 30]),
247+
np.array([40, 50]),
248+
np.array([60])
234249
]
235250

251+
mock_requests = {} # dict[str, CachedRequestState]
252+
mock_gpu_batch = MagicMock()
253+
mock_gpu_batch.req_ids = ["req1", "req2", "req3"]
254+
mock_num_scheduled = {"req1": 0, "req2": 0, "req3": 0}
255+
236256
proposer = MagicMock(spec=MtpProposer)
237257
proposer.input_ids = MagicMock(device=torch.device("cpu"))
238258
proposer.prepare_next_token_ids_cpu = MtpProposer.prepare_next_token_ids_cpu.__get__(
239259
proposer)
240-
241-
next_token_ids = proposer.prepare_next_token_ids_cpu(
260+
result = proposer.prepare_next_token_ids_cpu(
242261
sampled_token_ids=sampled_token_ids,
243262
requests=mock_requests,
244-
gpu_input_batch=mock_gpu_input_batch,
245-
num_scheduled_tokens=mock_num_scheduled_tokens)
246-
247-
expected = torch.tensor([103, 2000, 301],
248-
dtype=torch.int32,
249-
device="cpu")
250-
assert torch.equal(next_token_ids, expected)
263+
gpu_input_batch=mock_gpu_batch,
264+
num_scheduled_tokens=mock_num_scheduled)
251265

252-
mock_requests["req_002"].get_token_id.assert_called_once_with(8)
253-
mock_requests["req_001"].get_token_id.assert_not_called()
254-
mock_requests["req_003"].get_token_id.assert_not_called()
266+
assert torch.all(
267+
result == torch.tensor([30, 50, 60], dtype=torch.int32))
255268

256269
def test_prepare_next_token_ids_padded(self):
257270
mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata)
@@ -266,7 +279,7 @@ def test_prepare_next_token_ids_padded(self):
266279
dtype=torch.int32,
267280
device="cpu")
268281

269-
mock_requests = {}
282+
mock_requests = {} # dict[str, CachedRequestState]
270283
req0 = MagicMock(spec=CachedRequestState)
271284
req0.get_token_id = MagicMock(return_value=1000)
272285
mock_requests["req_0"] = req0

0 commit comments

Comments
 (0)