33import numpy as np
44import pytest
55import torch
6- from vllm .config import (CacheConfig , CompilationConfig , ModelConfig ,
7- SchedulerConfig , SpeculativeConfig , VllmConfig )
6+ from vllm .config import (CacheConfig , CompilationConfig , CUDAGraphMode ,
7+ ModelConfig , SchedulerConfig , SpeculativeConfig ,
8+ VllmConfig )
89from vllm .model_executor .layers .attention_layer_base import AttentionLayerBase
910from vllm .model_executor .models .deepseek_v2 import DeepseekV32IndexerCache
1011from vllm .v1 .attention .backends .utils import CommonAttentionMetadata
@@ -134,14 +135,19 @@ def get_layers_side_effect(vllm_config, cache_cls):
134135 with pytest .raises (AssertionError ):
135136 proposer .load_model (mock_model )
136137
138+ @patch ("vllm_ascend.spec_decode.mtp_proposer.get_forward_context" )
137139 @patch ("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context" )
138- def test_dummy_run (self , mock_set_context , vllm_config , runner ):
140+ def test_dummy_run (self , mock_set_context , mock_get_forward_context ,
141+ vllm_config , runner ):
139142 # Setup
140143 proposer = MtpProposer (vllm_config , "npu" , runner )
141144 proposer .model = MagicMock ()
142145 runner ._sync_metadata_across_dp .return_value = (8 , 8 , False )
143146 runner ._select_moe_comm_method .return_value = "alltoall"
144147
148+ mock_get_forward_context = MagicMock ()
149+ mock_get_forward_context .cudagraph_runtime_mode = None
150+ mock_get_forward_context .capturing = True
145151 # Execute
146152 proposer .dummy_run (8 )
147153
@@ -153,6 +159,34 @@ def test_dummy_run(self, mock_set_context, vllm_config, runner):
153159 # Check that model was called correct number of times
154160 assert proposer .model .call_count == vllm_config .speculative_config .num_speculative_tokens
155161
162+ @patch ("vllm_ascend.spec_decode.mtp_proposer.get_forward_context" )
163+ @patch ("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context" )
164+ def test_dummy_run_full_graph (self , mock_set_context ,
165+ mock_get_forward_context , vllm_config ,
166+ runner ):
167+ # Setup
168+ proposer = MtpProposer (vllm_config , "npu" , runner )
169+ proposer .model = MagicMock ()
170+ runner ._sync_metadata_across_dp .return_value = (8 , 8 , False )
171+ runner ._select_moe_comm_method .return_value = "alltoall"
172+ runner .attn_groups = []
173+
174+ mock_get_forward_context = MagicMock ()
175+ mock_get_forward_context .cudagraph_runtime_mode = None
176+ mock_get_forward_context .capturing = True
177+ # Execute
178+ proposer .dummy_run (num_tokens = 8 ,
179+ num_reqs = 5 ,
180+ aclgraph_runtime_mode = CUDAGraphMode .FULL )
181+
182+ # Verify
183+ runner ._sync_metadata_across_dp .assert_called_once ()
184+ runner ._select_moe_comm_method .assert_called_once ()
185+ mock_set_context .assert_called ()
186+
187+ # Check that model was called correct number of times
188+ assert proposer .model .call_count == vllm_config .speculative_config .num_speculative_tokens
189+
156190 def test_generate_token_ids (self ):
157191 mock_deps = MagicMock ()
158192 mock_deps .scheduler_output = MagicMock (spec = SchedulerOutput )
@@ -208,50 +242,29 @@ def test_generate_token_ids(self):
208242 assert torch .equal (draft_token_ids , proposer ._propose .return_value )
209243
210244 def test_prepare_next_token_ids_cpu (self ):
211- mock_requests = {}
212- req1 = MagicMock (spec = CachedRequestState )
213- req1 .num_computed_tokens = 5
214- req1 .get_token_id = MagicMock (return_value = 1000 )
215- mock_requests ["req_001" ] = req1
216- req2 = MagicMock (spec = CachedRequestState )
217- req2 .num_computed_tokens = 8
218- req2 .get_token_id = MagicMock (return_value = 2000 )
219- mock_requests ["req_002" ] = req2
220- req3 = MagicMock (spec = CachedRequestState )
221- req3 .num_computed_tokens = 3
222- req3 .get_token_id = MagicMock (return_value = 3000 )
223- mock_requests ["req_003" ] = req3
224-
225- mock_gpu_input_batch = MagicMock (spec = InputBatch )
226- mock_gpu_input_batch .req_ids = ["req_001" , "req_002" , "req_003" ]
227-
228- mock_num_scheduled_tokens = {"req_001" : 2 , "req_002" : 0 , "req_003" : 1 }
229-
230245 sampled_token_ids = [
231- [ 101 , 102 , 103 ], # req001: return 103
232- [], # req002: return fallback 2000
233- [ 301 ] # req003: return 301
246+ np . array ([ 10 , 20 , 30 ]),
247+ np . array ([ 40 , 50 ]),
248+ np . array ([ 60 ])
234249 ]
235250
251+ mock_requests = {} # dict[str, CachedRequestState]
252+ mock_gpu_batch = MagicMock ()
253+ mock_gpu_batch .req_ids = ["req1" , "req2" , "req3" ]
254+ mock_num_scheduled = {"req1" : 0 , "req2" : 0 , "req3" : 0 }
255+
236256 proposer = MagicMock (spec = MtpProposer )
237257 proposer .input_ids = MagicMock (device = torch .device ("cpu" ))
238258 proposer .prepare_next_token_ids_cpu = MtpProposer .prepare_next_token_ids_cpu .__get__ (
239259 proposer )
240-
241- next_token_ids = proposer .prepare_next_token_ids_cpu (
260+ result = proposer .prepare_next_token_ids_cpu (
242261 sampled_token_ids = sampled_token_ids ,
243262 requests = mock_requests ,
244- gpu_input_batch = mock_gpu_input_batch ,
245- num_scheduled_tokens = mock_num_scheduled_tokens )
246-
247- expected = torch .tensor ([103 , 2000 , 301 ],
248- dtype = torch .int32 ,
249- device = "cpu" )
250- assert torch .equal (next_token_ids , expected )
263+ gpu_input_batch = mock_gpu_batch ,
264+ num_scheduled_tokens = mock_num_scheduled )
251265
252- mock_requests ["req_002" ].get_token_id .assert_called_once_with (8 )
253- mock_requests ["req_001" ].get_token_id .assert_not_called ()
254- mock_requests ["req_003" ].get_token_id .assert_not_called ()
266+ assert torch .all (
267+ result == torch .tensor ([30 , 50 , 60 ], dtype = torch .int32 ))
255268
256269 def test_prepare_next_token_ids_padded (self ):
257270 mock_common_attn_metadata = MagicMock (spec = CommonAttentionMetadata )
@@ -266,7 +279,7 @@ def test_prepare_next_token_ids_padded(self):
266279 dtype = torch .int32 ,
267280 device = "cpu" )
268281
269- mock_requests = {}
282+ mock_requests = {} # dict[str, CachedRequestState]
270283 req0 = MagicMock (spec = CachedRequestState )
271284 req0 .get_token_id = MagicMock (return_value = 1000 )
272285 mock_requests ["req_0" ] = req0
0 commit comments