@@ -1394,25 +1394,29 @@ def _prepare_inputs(
             req_indices, positions_np)
         self.input_batch.block_table.commit_slot_mapping(
             total_num_scheduled_tokens)
+
+        total_num_pcp_pads = 0
         if self.pcp_size > 1:
             if not self.vllm_config.model_config.use_mla:
                 self.generate_kv_idx(scheduler_output)
             tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
                 tokens)
             num_scheduled_tokens = np.array(tokens, dtype=np.int32)
             total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
+            total_num_pcp_pads = torch.sum(self.num_pcp_pads).item()
         else:
             position_pcp, pcp_unpad_mask = None, None
         self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
 
-        total_num_pcp_pads = sum(self.num_pcp_pads)
-        max_num_scheduled_tokens = max(tokens)
-        num_valid_tokens = np.array([
-            num_tokens -
-            len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
-            for num_tokens, i in zip(tokens, req_ids)
-        ],
-                                    dtype=np.int32)
+        if not scheduler_output.scheduled_spec_decode_tokens:
+            num_valid_tokens = np.array(tokens, dtype=np.int32)
+        else:
+            num_valid_tokens = np.array([
+                num_tokens -
+                len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
+                for num_tokens, i in zip(tokens, req_ids)
+            ],
+                                        dtype=np.int32)
 
         if (self.use_aclgraph and total_num_scheduled_tokens
                 <= self.aclgraph_batch_sizes[-1]):
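Two things stand out in the first change: `total_num_pcp_pads` now gets a default of 0 before the `if self.pcp_size > 1:` branch, so the name is defined even when prefill context parallelism is off, and the reduction switches from Python's built-in `sum()` to `torch.sum(...).item()`. A minimal sketch of the difference between the two reductions, with made-up pad counts standing in for `self.num_pcp_pads`:

import torch

# Made-up per-request pad counts standing in for self.num_pcp_pads.
num_pcp_pads = torch.tensor([2, 0, 1], dtype=torch.int32)

# Built-in sum() folds the tensor elementwise and returns a 0-d tensor,
# while torch.sum(...).item() returns a plain Python int.
total_as_tensor = sum(num_pcp_pads)            # tensor(3, dtype=torch.int32)
total_as_int = torch.sum(num_pcp_pads).item()  # 3

print(type(total_as_tensor).__name__, type(total_as_int).__name__)  # Tensor int

Presumably the `.item()` call keeps the downstream size arithmetic on plain Python ints rather than 0-d tensors, though the diff itself does not state the motivation.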
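The second change short-circuits the `num_valid_tokens` computation: when no speculative-decode tokens are scheduled, every scheduled token is valid and the per-request list comprehension is skipped entirely. A standalone sketch of that fast/slow path split, where `tokens`, `req_ids`, and `scheduled_spec_decode_tokens` are hypothetical stand-ins for the runner and scheduler state:

import numpy as np

# Hypothetical stand-ins for the scheduler/runner state in the diff.
tokens = [8, 8, 4]                 # tokens scheduled per request this step
req_ids = ["r0", "r1", "r2"]
scheduled_spec_decode_tokens = {"r1": [101, 102]}  # draft tokens for r1 only

if not scheduled_spec_decode_tokens:
    # Fast path: no speculative tokens anywhere, so every token is valid.
    num_valid_tokens = np.array(tokens, dtype=np.int32)
else:
    # Slow path: subtract each request's speculative draft tokens.
    num_valid_tokens = np.array([
        n - len(scheduled_spec_decode_tokens.get(i, []))
        for n, i in zip(tokens, req_ids)
    ], dtype=np.int32)

print(num_valid_tokens)  # [8 6 4]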