1717
1818from vllm_ascend .ascend_config import get_ascend_config
1919from vllm_ascend .ascend_forward_context import set_ascend_forward_context
20+ from vllm_ascend .attention .attention_v1 import AscendAttentionState
2021from vllm_ascend .attention .utils import AscendCommonAttentionMetadata
2122from vllm_ascend .patch .worker .patch_deepseek_mtp import \
2223 AscendDeepSeekMTP as DeepSeekMTP
24+ from vllm_ascend .compilation .acl_graph import (ACLGraphWrapper ,
25+ set_mtp_graph_params ,
26+ update_mla_attn_params )
2327from vllm_ascend .spec_decode .interface import Proposer , SpecDcodeType
2428from vllm_ascend .torchair .models .torchair_deepseek_mtp import \
2529 TorchairDeepSeekMTP
@@ -71,6 +75,23 @@ def __init__(
7175 self .use_sparse = hasattr (vllm_config .model_config .hf_config ,
7276 "index_topk" )
7377
78+ self .actual_seq_lengths_q = list (
79+ range (1 , self .runner .max_num_tokens + 1 , 1 ))
80+ self .query_start_loc = torch .zeros (self .runner .max_num_reqs + 1 ,
81+ dtype = torch .int32 ,
82+ device = self .device )
83+ self .query_start_loc_cpu = torch .zeros (self .runner .max_num_reqs + 1 ,
84+ dtype = torch .int32 ,
85+ device = "cpu" ,
86+ pin_memory = True )
87+ self .slot_mapping = torch .zeros (self .runner .max_num_tokens ,
88+ dtype = torch .int32 ,
89+ device = self .device )
90+ self .seq_lens_cpu = torch .zeros (self .runner .max_num_reqs ,
91+ dtype = torch .int32 ,
92+ device = "cpu" ,
93+ pin_memory = True )
94+
7495 def load_model (self , model ) -> None :
7596 loader = get_model_loader (self .vllm_config .load_config )
7697
@@ -106,6 +127,15 @@ def load_model(self, model) -> None:
106127 process_weights_after_loading (self .model , draft_model_config ,
107128 target_device )
108129
130+ if self .vllm_config .compilation_config .cudagraph_mode .has_full_cudagraphs (
131+ ):
132+ self .update_stream : torch .npu .Stream = torch .npu .Stream ()
133+ set_mtp_graph_params (
134+ self .vllm_config .compilation_config .cudagraph_capture_sizes )
135+ self .model = ACLGraphWrapper (self .model ,
136+ self .vllm_config ,
137+ runtime_mode = CUDAGraphMode .FULL )
138+
109139 @torch .inference_mode ()
110140 def dummy_run (self ,
111141 num_tokens : int ,
@@ -131,7 +161,7 @@ def dummy_run(self,
131161 skip_attn = False
132162 if skip_attn :
133163 attn_metadata = None
134- else :
164+ elif is_running_torchair :
135165 common_attn_metadata = TorchairCommonAttentionMetadata (
136166 num_reqs = num_reqs ,
137167 num_actual_tokens = 1 ,
@@ -142,6 +172,56 @@ def dummy_run(self,
142172 )
143173 attn_metadata = self .runner .attn_metadata_builder .build_torchair_graph_dummy (
144174 common_attn_metadata )
175+ elif aclgraph_runtime_mode == CUDAGraphMode .FULL :
176+ assert with_prefill is False , \
177+ "Full decode graph only supports uniform batch now."
178+ num_reqs = num_tokens
179+ max_seq_lens = self .runner .model_config .max_model_len
180+ self .seq_lens_cpu [:num_reqs ] = max_seq_lens
181+ self .seq_lens_cpu [num_reqs :] = 0
182+ if len (self .runner .attn_groups ) > 0 :
183+ num_computed_tokens_cpu = (
184+ self .runner .input_batch .
185+ num_computed_tokens_cpu_tensor [:num_reqs ])
186+ query_start_loc = torch .tensor (
187+ [0 ] + self .actual_seq_lengths_q [:num_reqs ],
188+ device = self .runner .device ,
189+ dtype = torch .int32 )
190+ self .query_start_loc [:num_reqs + 1 ].copy_ (query_start_loc )
191+ common_attn_metadata = AscendCommonAttentionMetadata (
192+ query_start_loc = self .query_start_loc [:num_reqs + 1 ],
193+ query_start_loc_cpu = self .query_start_loc_cpu [:num_reqs +
194+ 1 ],
195+ seq_lens_cpu = self .seq_lens_cpu ,
196+ seq_lens = self .seq_lens_cpu [:num_reqs ],
197+ num_reqs = num_reqs ,
198+ num_actual_tokens = num_tokens ,
199+ max_query_len = self .num_speculative_tokens + 1 ,
200+ num_computed_tokens_cpu = num_computed_tokens_cpu ,
201+ actual_seq_lengths_q = self .runner .actual_seq_lengths_q ,
202+ block_table_tensor = self .runner .input_batch .block_table [0 ].
203+ get_device_tensor ()[:num_reqs ],
204+ slot_mapping = self .slot_mapping ,
205+ positions = self .positions ,
206+ attn_mask = self .runner .attn_mask ,
207+ spec_attn_mask = self .runner .spec_attn_mask ,
208+ attn_state = self .runner .attn_state ,
209+ decode_token_per_req = self .runner .decode_token_per_req ,
210+ cos = self .runner .cos , # 考虑mrope,是否可以共用?
211+ sin = self .runner .sin ,
212+ )
213+
214+ builder = self .runner .attn_groups [0 ][0 ].get_metadata_builder ()
215+ attn_metadata_mtp = builder .build_for_graph_capture (
216+ common_attn_metadata , AscendAttentionState .SpecDecoding ,
217+ self .runner .get_model ())
218+ attn_metadata = {}
219+ for layer_name in self .attn_layer_name :
220+ attn_metadata [layer_name ] = attn_metadata_mtp
221+ else :
222+ attn_metadata = None
223+ else :
224+ attn_metadata = None
145225
146226 input_ids = self .input_ids [:num_tokens ]
147227 positions = self .positions [:num_tokens ]
@@ -158,7 +238,8 @@ def dummy_run(self,
158238 in_profile_run = self .runner .in_profile_run ,
159239 num_actual_tokens = 0 ,
160240 aclgraph_runtime_mode = aclgraph_runtime_mode ,
161- batch_descriptor = batch_descriptor ):
241+ batch_descriptor = batch_descriptor ,
242+ is_mtp_model = True ):
162243 if is_running_torchair :
163244 assert attn_metadata is not None
164245 torch ._dynamo .mark_static (input_ids )
@@ -188,6 +269,14 @@ def dummy_run(self,
188269 self .model (input_ids = input_ids ,
189270 positions = positions ,
190271 hidden_states = previous_hidden_states )
272+ forward_context = get_forward_context ()
273+ if forward_context .cudagraph_runtime_mode == CUDAGraphMode .FULL and \
274+ not forward_context .capturing :
275+ if self .vllm_config .model_config .use_mla :
276+ update_mla_attn_params (
277+ self .update_stream , forward_context ,
278+ positions .shape [0 ],
279+ self .vllm_config .speculative_config )
191280 if with_prefill :
192281 break
193282
@@ -260,7 +349,8 @@ def generate_token_ids(self,
260349 cu_num_tokens = cu_num_tokens ,
261350 block_table = attn_metadata .block_tables ,
262351 sampling_metadata = sampling_metadata ,
263- token_indices = accepted_token_indices )
352+ token_indices = accepted_token_indices ,
353+ scheduler_output = scheduler_output )
264354 spec_token_ids = draft_token_ids .tolist ()
265355 return spec_token_ids
266356
@@ -322,6 +412,17 @@ def _prepare_inputs(
322412 target_positions = positions [token_indices ]
323413 target_hidden_states = hidden_states [token_indices ]
324414 target_slot_mapping = slot_mapping [token_indices ]
415+
416+ batch_size = num_rejected_tokens .shape [0 ]
417+ self .query_start_loc [:batch_size + 1 ].copy_ (cu_num_tokens [:batch_size +
418+ 1 ])
419+ self .query_start_loc_cpu [:batch_size + 1 ].copy_ (
420+ self .query_start_loc [:batch_size + 1 ], non_blocking = True )
421+ target_positions_len = target_positions .shape [0 ]
422+ self .positions [:target_positions_len ].copy_ (target_positions )
423+ target_slot_mapping_len = target_slot_mapping .shape [0 ]
424+ self .slot_mapping [:target_slot_mapping_len ].copy_ (target_slot_mapping )
425+
325426 return cu_num_tokens , token_indices , target_token_ids , target_positions , target_hidden_states , target_slot_mapping
326427
327428 def _propose (
@@ -341,7 +442,8 @@ def _propose(
341442 # [batch_size, max_num_blocks_per_req]
342443 block_table : torch .Tensor ,
343444 sampling_metadata : SamplingMetadata ,
344- token_indices = None ) -> torch .Tensor :
445+ token_indices = None ,
446+ scheduler_output : SchedulerOutput = None ) -> torch .Tensor :
345447 num_tokens = target_token_ids .shape [0 ]
346448 batch_size = next_token_ids .shape [0 ]
347449 last_token_indices = cu_num_tokens [1 :] - 1
@@ -385,18 +487,20 @@ def _propose(
385487
386488 seq_lens = target_positions [last_token_indices ] + 1
387489 seq_lens = seq_lens .int ()
490+ seq_lens_len = seq_lens .shape [0 ]
491+ self .seq_lens_cpu [:seq_lens_len ].copy_ (seq_lens , non_blocking = True )
388492 common_attn_metadata = AscendCommonAttentionMetadata (
389- query_start_loc = cu_num_tokens [:batch_size + 1 ],
390- query_start_loc_cpu = cu_num_tokens [:batch_size + 1 ]. cpu () ,
391- seq_lens_cpu = seq_lens . cpu () ,
493+ query_start_loc = self . query_start_loc [:batch_size + 1 ],
494+ query_start_loc_cpu = self . query_start_loc_cpu [:batch_size + 1 ],
495+ seq_lens_cpu = self . seq_lens_cpu [: seq_lens_len ] ,
392496 num_reqs = batch_size ,
393497 num_actual_tokens = num_tokens ,
394498 max_query_len = max_query_len ,
395499 actual_seq_lengths_q = self .runner .actual_seq_lengths_q ,
396500 block_table_tensor = self .runner .input_batch .block_table [0 ].
397501 get_device_tensor (),
398- slot_mapping = target_slot_mapping ,
399- positions = target_positions ,
502+ slot_mapping = self . slot_mapping [: target_slot_mapping . shape [ 0 ]] ,
503+ positions = self . positions [: target_positions . shape [ 0 ]] ,
400504 attn_mask = self .runner .attn_mask ,
401505 spec_attn_mask = self .runner .spec_attn_mask ,
402506 attn_state = self .runner .attn_state ,
@@ -434,8 +538,18 @@ def _propose(
434538
435539 moe_comm_type = self .runner ._select_moe_comm_method (
436540 num_input_tokens , with_prefill )
437- batch_descriptor = BatchDescriptor (num_tokens = num_input_tokens ,
438- uniform_decode = False )
541+
542+ if scheduler_output :
543+ uniform_decode = (max_query_len in list (
544+ range (1 , self .num_speculative_tokens + 2 ))) and (
545+ scheduler_output .total_num_scheduled_tokens //
546+ (self .num_speculative_tokens + 2 - max_query_len )
547+ == self .runner .input_batch .num_reqs * max_query_len )
548+ batch_descriptor = BatchDescriptor (num_tokens = num_input_tokens ,
549+ uniform_decode = uniform_decode )
550+ else :
551+ batch_descriptor = BatchDescriptor (num_tokens = num_input_tokens ,
552+ uniform_decode = False )
439553 aclgraph_runtime_mode , batch_descriptor = \
440554 self .runner .aclgraph_dispatcher .dispatch (batch_descriptor )
441555
@@ -451,7 +565,8 @@ def _propose(
451565 aclgraph_runtime_mode = aclgraph_runtime_mode ,
452566 batch_descriptor = batch_descriptor ,
453567 in_profile_run = self .runner .in_profile_run ,
454- num_actual_tokens = num_tokens ):
568+ num_actual_tokens = num_tokens ,
569+ is_mtp_model = True ):
455570 with ProfileExecuteDuration ().capture_async ('mtp_forward' ):
456571 model_kwargs = {}
457572 model_kwargs ["attn_metadata" ] = attn_metadata
@@ -475,6 +590,13 @@ def _propose(
475590 positions = self .positions [:num_input_tokens ],
476591 hidden_states = self .hidden_states [:num_input_tokens ]
477592 )
593+ forward_context = get_forward_context ()
594+ if forward_context .cudagraph_runtime_mode == CUDAGraphMode .FULL :
595+ if self .vllm_config .model_config .use_mla :
596+ update_mla_attn_params (
597+ self .update_stream , forward_context ,
598+ num_input_tokens ,
599+ self .vllm_config .speculative_config )
478600
479601 num_indices = last_token_indices .shape [0 ]
480602 if lmhead_tp_enable ():
0 commit comments