 
 class MtpProposer(Proposer):
 
+    # TODO: Find out why ModuleRunner does not need this explicit typing?
+    model: Union[nn.Module, ACLGraphWrapper]
+
     def __init__(
         self,
         vllm_config: VllmConfig,
@@ -145,7 +148,8 @@ def dummy_run(self,
         if skip_attn:
             attn_metadata = None
         elif is_running_torchair:
-            common_attn_metadata = TorchairCommonAttentionMetadata(
+            common_attn_metadata: TorchairCommonAttentionMetadata = \
+                TorchairCommonAttentionMetadata(
                 num_reqs=num_reqs,
                 num_actual_tokens=1,
                 actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
@@ -156,24 +160,18 @@ def dummy_run(self,
             attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
                 common_attn_metadata)
         elif aclgraph_runtime_mode == CUDAGraphMode.FULL:
-            # assert with_prefill is False, \
-            #     "Full decode graph only supports uniform batch now."
-            max_seq_lens = self.runner.model_config.max_model_len
             if len(self.runner.attn_groups) > 0:
                 num_computed_tokens_cpu = (
                     self.runner.input_batch.
                     num_computed_tokens_cpu_tensor[:num_reqs])
-                query_start_loc = torch.tensor(
-                    [0] + self.runner.actual_seq_lengths_q[:num_reqs],
-                    device=self.runner.device,
-                    dtype=torch.int32)
-                common_attn_metadata = AscendCommonAttentionMetadata(
+                common_attn_metadata: AscendCommonAttentionMetadata = \
+                    AscendCommonAttentionMetadata(
                     query_start_loc=torch.tensor(
                         [0] + self.runner.actual_seq_lengths_q[:num_reqs],
                         device=self.device,
                         dtype=torch.int32),
-                    query_start_loc_cpu=self.runner.query_start_loc_cpu[
-                        :num_reqs + 1],
+                    query_start_loc_cpu=self.runner.
+                    query_start_loc_cpu[:num_reqs + 1],
                     seq_lens_cpu=self.runner.seq_lens_cpu,
                     seq_lens=self.runner.seq_lens_cpu[:num_reqs],
                     num_reqs=num_reqs,
@@ -183,7 +181,8 @@ def dummy_run(self,
                     actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
                     block_table_tensor=self.runner.input_batch.block_table[0].
                     get_device_tensor()[:num_reqs],
-                    slot_mapping=self.runner.input_batch.block_table[0].slot_mapping,
+                    slot_mapping=self.runner.input_batch.block_table[0].
+                    slot_mapping,
                     positions=self.runner.positions,
                     attn_mask=self.runner.attn_mask,
                     spec_attn_mask=self.runner.spec_attn_mask,
@@ -466,7 +465,6 @@ def _propose(
 
         seq_lens = target_positions[last_token_indices] + 1
         seq_lens = seq_lens.int()
-        seq_lens_len = seq_lens.shape[0]
 
         if not self.torchair_graph_enabled:
             # torch mode needs to update num_tokens_across_dp
@@ -481,10 +479,10 @@ def _propose(
 
         if scheduler_output:
             uniform_decode = (max_query_len in list(
-                range(1, self.num_speculative_tokens + 2))) and (
-                    scheduler_output.total_num_scheduled_tokens ==
-                    self.runner.input_batch.num_reqs *
-                    (self.num_speculative_tokens + 1))
+                range(1, self.num_speculative_tokens +
+                      2))) and (scheduler_output.total_num_scheduled_tokens
+                                == self.runner.input_batch.num_reqs *
+                                (self.num_speculative_tokens + 1))
             batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                                uniform_decode=uniform_decode)
         else:
@@ -500,9 +498,12 @@ def _propose(
         # Currently, if not torchair, runner.graph_pad_size will always be -1.
         graph_pad_size = self.runner.graph_pad_size
 
-        runner_slot_mapping = self.runner.input_batch.block_table[0].slot_mapping
-        runner_slot_mapping[:target_slot_mapping.shape[0]].copy_(target_slot_mapping)
-        runner_slot_mapping[target_slot_mapping.shape[0]:num_input_tokens].fill_(0)
+        runner_slot_mapping = self.runner.input_batch.block_table[
+            0].slot_mapping
+        runner_slot_mapping[:target_slot_mapping.shape[0]].copy_(
+            target_slot_mapping)
+        runner_slot_mapping[target_slot_mapping.
+                            shape[0]:num_input_tokens].fill_(0)
 
         # NOTE: Currently, just positions, slot_mapping, block_table and
         # seq_lens will be sent into MLAMetadata.