@@ -75,24 +75,6 @@ def __init__(
         self.use_sparse = hasattr(vllm_config.model_config.hf_config,
                                   "index_topk")
 
-        self.query_start_loc = torch.zeros(
-            self.runner.max_num_reqs * (self.num_speculative_tokens + 1) + 1,
-            dtype=torch.int32,
-            device=self.device)
-        self.query_start_loc_cpu = torch.zeros(
-            self.runner.max_num_reqs * (self.num_speculative_tokens + 1) + 1,
-            dtype=torch.int32,
-            device="cpu",
-            pin_memory=True)
-        self.slot_mapping = torch.zeros(self.runner.max_num_tokens,
-                                        dtype=torch.int32,
-                                        device=self.device)
-        self.seq_lens_cpu = torch.zeros(self.runner.max_num_reqs *
-                                        (self.num_speculative_tokens + 1),
-                                        dtype=torch.int32,
-                                        device="cpu",
-                                        pin_memory=True)
-
     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
 
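For context, the attributes deleted above followed a common preallocation pattern: a pinned CPU staging tensor paired with a device tensor, both sized for the worst case, so later steps could copy slices in place and launch host-to-device transfers asynchronously. A minimal sketch of that pattern (sizes, names, and the device string are illustrative, not the proposer's real attributes):

import torch

max_num_reqs, num_speculative_tokens = 8, 1
size = max_num_reqs * (num_speculative_tokens + 1) + 1

# Pinned CPU buffer: needed for a truly asynchronous (non_blocking) H2D copy.
query_start_loc_cpu = torch.zeros(size, dtype=torch.int32, pin_memory=True)
# Matching device buffer; "cuda" stands in for the Ascend NPU device here.
query_start_loc = torch.zeros(size, dtype=torch.int32, device="cuda")

# Per-step usage: fill a prefix, copy it across, zero the stale tail.
query_start_loc_cpu[:4] = torch.tensor([0, 2, 4, 6], dtype=torch.int32)
query_start_loc[:4].copy_(query_start_loc_cpu[:4], non_blocking=True)
query_start_loc[4:].fill_(0)

The PR drops these proposer-private copies in favor of buffers the runner already owns, which the later hunks wire in.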
@@ -177,8 +159,6 @@ def dummy_run(self,
             # assert with_prefill is False, \
             # "Full decode graph only supports uniform batch now."
             max_seq_lens = self.runner.model_config.max_model_len
-            self.seq_lens_cpu[:num_reqs] = max_seq_lens
-            self.seq_lens_cpu[num_reqs:] = 0
             if len(self.runner.attn_groups) > 0:
                 num_computed_tokens_cpu = (
                     self.runner.input_batch.
@@ -187,22 +167,24 @@ def dummy_run(self,
                     [0] + self.runner.actual_seq_lengths_q[:num_reqs],
                     device=self.runner.device,
                     dtype=torch.int32)
-                self.query_start_loc[:num_reqs + 1].copy_(query_start_loc)
                 common_attn_metadata = AscendCommonAttentionMetadata(
-                    query_start_loc=self.query_start_loc[:num_reqs + 1],
+                    query_start_loc=torch.tensor(
+                        [0] + self.runner.actual_seq_lengths_q[:num_reqs],
+                        device=self.device,
+                        dtype=torch.int32),
                     query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                                  1],
-                    seq_lens_cpu=self.seq_lens_cpu,
-                    seq_lens=self.seq_lens_cpu[:num_reqs],
+                    seq_lens_cpu=self.runner.seq_lens_cpu,
+                    seq_lens=self.runner.seq_lens_cpu[:num_reqs],
                     num_reqs=num_reqs,
                     num_actual_tokens=num_tokens,
                     max_query_len=self.num_speculative_tokens + 1,
                     num_computed_tokens_cpu=num_computed_tokens_cpu,
                     actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
                     block_table_tensor=self.runner.input_batch.block_table[0].
                     get_device_tensor()[:num_reqs],
-                    slot_mapping=self.slot_mapping,
-                    positions=self.positions,
+                    slot_mapping=self.runner.input_batch.block_table[0].slot_mapping,
+                    positions=self.runner.positions,
                     attn_mask=self.runner.attn_mask,
                     spec_attn_mask=self.runner.spec_attn_mask,
                     attn_state=self.runner.attn_state,
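The replacements above point the dummy-run metadata at tensors the runner already owns (self.runner.seq_lens_cpu, block_table[0].slot_mapping, self.runner.positions) instead of proposer-private copies. A toy illustration of why that is enough (ToyMetadata and the buffer names are made up for this sketch): the metadata only stores references, so in-place updates by the runner are visible to attention without an extra copy.

import torch
from dataclasses import dataclass

@dataclass
class ToyMetadata:
    positions: torch.Tensor
    slot_mapping: torch.Tensor

# Buffers owned by the "runner".
runner_positions = torch.zeros(16, dtype=torch.int64)
runner_slot_mapping = torch.zeros(16, dtype=torch.int32)

# The metadata holds references, not copies.
meta = ToyMetadata(positions=runner_positions, slot_mapping=runner_slot_mapping)

# An in-place update by the runner is immediately visible through the metadata.
runner_positions[:4] = torch.arange(4)
assert torch.equal(meta.positions[:4], torch.arange(4))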
@@ -319,22 +301,6 @@ def generate_token_ids(self,
             target_hidden_states = hidden_states[:num_scheduled_tokens]
             target_slot_mapping = attn_metadata.slot_mapping
             cu_num_tokens = attn_metadata.query_start_loc
-
-            query_start_loc_num = len(cu_num_tokens)
-            self.query_start_loc[:query_start_loc_num].copy_(
-                cu_num_tokens[:query_start_loc_num])
-            self.query_start_loc[query_start_loc_num:].fill_(0)
-            self.query_start_loc_cpu[:query_start_loc_num].copy_(
-                self.query_start_loc[:query_start_loc_num], non_blocking=True)
-            self.query_start_loc_cpu[query_start_loc_num:].fill_(0)
-
-            target_slot_mapping_len = target_slot_mapping.shape[0]
-            self.slot_mapping[:target_slot_mapping_len].copy_(
-                target_slot_mapping)
-            self.slot_mapping[target_slot_mapping_len:].fill_(0)
-            target_positions_len = target_positions.shape[0]
-            self.positions[:target_positions_len].copy_(target_positions)
-            self.positions[target_positions_len:].fill_(0)
         else:
             # TODO(woosuk): Refactor this.
             num_draft_tokens = spec_decode_metadata.num_draft_tokens
@@ -432,20 +398,6 @@ def _prepare_inputs(
         target_hidden_states = hidden_states[token_indices]
         target_slot_mapping = slot_mapping[token_indices]
 
-        batch_size = num_rejected_tokens.shape[0]
-        self.query_start_loc[:batch_size + 1].copy_(cu_num_tokens[:batch_size +
-                                                                  1])
-        self.query_start_loc[batch_size + 1:].fill_(0)
-        self.query_start_loc_cpu[:batch_size + 1].copy_(
-            self.query_start_loc[:batch_size + 1], non_blocking=True)
-        self.query_start_loc_cpu[batch_size + 1:].fill_(0)
-        target_positions_len = target_positions.shape[0]
-        self.positions[:target_positions_len].copy_(target_positions)
-        self.positions[target_positions_len:].fill_(0)
-        target_slot_mapping_len = target_slot_mapping.shape[0]
-        self.slot_mapping[:target_slot_mapping_len].copy_(target_slot_mapping)
-        self.slot_mapping[target_slot_mapping_len:].fill_(0)
-
         return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
 
     def _propose(
@@ -517,8 +469,6 @@ def _propose(
         seq_lens = target_positions[last_token_indices] + 1
         seq_lens = seq_lens.int()
         seq_lens_len = seq_lens.shape[0]
-        self.seq_lens_cpu[:seq_lens_len].copy_(seq_lens, non_blocking=True)
-        self.seq_lens_cpu[seq_lens_len:].fill_(0)
 
         if not self.torchair_graph_enabled:
             # torch mode need to update num_tokens_across_dp
@@ -552,18 +502,27 @@ def _propose(
         # Currently, if not torchair, runner.graph_pad_size will always be -1.
         graph_pad_size = self.runner.graph_pad_size
 
+        runner_slot_mapping = self.runner.input_batch.block_table[0].slot_mapping
+        runner_slot_mapping[:target_slot_mapping.shape[0]].copy_(target_slot_mapping)
+        runner_slot_mapping[target_slot_mapping.shape[0]:num_input_tokens].fill_(0)
+
+        # NOTE: Currently only positions, slot_mapping, block_table and
+        # seq_lens are passed into the MLA metadata, and of those only
+        # block_table and slot_mapping are actually consumed, so only their
+        # buffer addresses are kept fixed here. If attention ever needs the
+        # other params, their addresses should be fixed in the same way.
         common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:batch_size + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:batch_size + 1],
-            seq_lens_cpu=self.seq_lens_cpu[:seq_lens_len],
+            query_start_loc=cu_num_tokens[:batch_size + 1],
+            query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
+            seq_lens_cpu=seq_lens.cpu(),
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,
             max_query_len=max_query_len,
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].
             get_device_tensor(),
-            slot_mapping=self.slot_mapping[:target_slot_mapping.shape[0]],
-            positions=self.positions[:target_positions.shape[0]],
+            slot_mapping=runner_slot_mapping,
+            positions=target_positions,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
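The NOTE in this hunk is about keeping the storage behind slot_mapping (and block_table) at a fixed address: a captured graph holds a pointer to one persistent buffer, so per-step updates must be in-place (copy_/fill_) rather than rebinding a name to a freshly allocated tensor. A small standalone sketch of that distinction (the buffer here is a stand-in, not the real block_table[0].slot_mapping):

import torch

persistent = torch.zeros(8, dtype=torch.int32)   # stand-in for the runner's slot_mapping
original = persistent

new_values = torch.tensor([5, 6, 7], dtype=torch.int32)
persistent[:3].copy_(new_values)                 # in-place write: storage address unchanged
persistent[3:].fill_(0)                          # zero the padded tail, as the new code does
assert persistent.data_ptr() == original.data_ptr()

persistent = torch.tensor([5, 6, 7], dtype=torch.int32)  # rebinding allocates new storage
assert persistent.data_ptr() != original.data_ptr()      # a captured graph would not see this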
@@ -585,6 +544,7 @@ def _propose(
         attn_metadata = self.runner.attn_metadata_builder.build(
             0, common_attn_metadata, self.runner.get_model())
 
+        self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
         self.hidden_states[num_tokens:].fill_(0)
 
@@ -734,7 +694,10 @@ def _propose(
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:hidden_states.shape[0]] = hidden_states
             attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
-
+            if not self.torchair_graph_enabled:
+                self.positions[batch_size:num_input_tokens] = 0
+                self.input_ids[batch_size:num_input_tokens] = 0
+                self.hidden_states[batch_size:num_input_tokens].fill_(0)
             if attn_metadata_i.prefill is not None:
                 attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
                 attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
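The added "if not self.torchair_graph_enabled:" block zeroes the padded region of the persistent positions/input_ids/hidden_states buffers, so values left over from an earlier, larger step cannot leak into slots that only exist because the batch was padded up to num_input_tokens. A toy version of that reset (the sizes are made up):

import torch

num_input_tokens = 8                 # padded length expected by the fixed-shape forward pass
batch_size = 5                       # tokens that are actually valid this step

input_ids = torch.arange(8)          # persistent buffer still holding values from a prior step
input_ids[:batch_size] = torch.tensor([11, 12, 13, 14, 15])
input_ids[batch_size:num_input_tokens] = 0   # clear the stale tail, as in the added lines

print(input_ids)                     # tensor([11, 12, 13, 14, 15,  0,  0,  0])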