@@ -496,7 +496,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
496496 dtype = torch .int32 ,
497497 device = self .device )
498498 self .num_actual_tokens_pcp_padded = 0
499- if self .speculative_config and self .pcp_size > 1 :
499+ if self .speculative_config and self .pcp_size * self . dcp_size > 1 :
500500 self .input_ids_pcp_full = torch .zeros (self .max_num_tokens ,
501501 dtype = torch .int32 ,
502502 device = self .device )
@@ -1738,7 +1738,7 @@ def _prepare_inputs(
17381738 self .num_accepted_tokens .np [num_reqs :].fill (1 )
17391739 self .num_accepted_tokens .copy_to_gpu ()
17401740
1741- if self .speculative_config and self .pcp_size > 1 :
1741+ if self .speculative_config and self .pcp_size * self . dcp_size > 1 :
17421742 self ._generate_pcp_mtp_input (
17431743 num_reqs , scheduler_output .total_num_scheduled_tokens ,
17441744 scheduler_output .num_scheduled_tokens )
@@ -1820,28 +1820,29 @@ def _prepare_inputs(
18201820 prefill_context_parallel_metadata = long_seq_metadata ,
18211821 )
18221822
1823- if self .speculative_config and self .pcp_size > 1 :
1823+ if self .speculative_config and self .pcp_size * self . dcp_size > 1 :
18241824 # For pcp + spec decode, we flatten block_table
18251825 # to avoid irregular spec_attn_mask shape, e.g.,
18261826 # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
18271827 # ori block_table: # [d0, d1, p0, p1, p2]
18281828 # (num_reqs_d + num_reqs_p, max_num_blocks),
18291829 # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
18301830 # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
1831- ori_query_lens = self .query_start_loc_pcp_full_cpu [1 :num_reqs + 1 ] - \
1832- self .query_start_loc_pcp_full_cpu [:num_reqs ]
1831+ ori_query_lens = self .query_start_loc_pcp_full [1 :num_reqs + 1 ] - \
1832+ self .query_start_loc_pcp_full [:num_reqs ]
18331833 num_prefill_reqs = (ori_query_lens
18341834 > self .decode_threshold ).sum ().item ()
18351835 num_decode_reqs = num_reqs - num_prefill_reqs
1836- num_decode_reqs_flatten = num_decode_reqs * self .decode_threshold
1836+ num_decode_reqs_flatten = \
1837+ ori_query_lens [:num_decode_reqs ].sum ().item ()
18371838 blk_table_tensor [
18381839 num_decode_reqs_flatten :num_decode_reqs_flatten +
18391840 num_prefill_reqs ].copy_ (
18401841 blk_table_tensor [num_decode_reqs :num_decode_reqs +
18411842 num_prefill_reqs ].clone ())
18421843 blk_table_tensor [:num_decode_reqs_flatten ].copy_ (
18431844 blk_table_tensor [:num_decode_reqs ].repeat_interleave (
1844- self . decode_threshold , dim = 0 ))
1845+ ori_query_lens [: num_decode_reqs ] , dim = 0 ))
18451846 common_attn_metadata .block_table_tensor = \
18461847 blk_table_tensor [:num_decode_reqs_flatten + num_prefill_reqs ]
18471848
@@ -2788,7 +2789,7 @@ def _build_dummy_attn_metadata(
27882789 sin = self .sin ,
27892790 prefill_context_parallel_metadata = long_seq_metadata ,
27902791 )
2791- if self .pcp_size > 1 :
2792+ if self .pcp_size * self . dcp_size > 1 :
27922793 common_attn_metadata .block_table_tensor = \
27932794 block_table_tensor [:num_reqs * self .decode_threshold ]
27942795 attn_state = AscendAttentionState .DecodeOnly
@@ -4250,8 +4251,8 @@ def _get_cp_local_seq_lens(
42504251 def _generate_pcp_metadata (self , total_num_scheduled_tokens ):
42514252 # In dummy run num_reqs == 0, update it from seq_lens
42524253 num_reqs = self .input_batch .num_reqs or self .query_lens .size (0 )
4253- num_decodes = sum (self .input_batch . num_computed_tokens_cpu [: num_reqs ]
4254- >= self . input_batch . num_prompt_tokens [: num_reqs ])
4254+ num_decodes = (self .query_lens <= self . decode_threshold ). sum (). item ()
4255+ num_prefills = num_reqs - num_decodes
42554256 num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self .pcp_size
42564257 self .num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
42574258 long_seq_metadata = None
@@ -4269,16 +4270,41 @@ def _generate_pcp_metadata(self, total_num_scheduled_tokens):
42694270 dtype = torch .int32 ,
42704271 )
42714272 # For pcp + spec decode, we flatten seq_lens
4272- # to avoid irregular spec_attn_mask shape
4273+ # to avoid irregular spec_attn_mask shape.
4274+ # Same as block_table, we flatten decode seq_lens to query_lens,
4275+ # and keep prefill seq_lens unchanged.
42734276 for decode_idx in range (self .decode_threshold ):
42744277 num_computed_tokens_of_pcp_dcp [
42754278 self .decode_threshold - 1 - decode_idx ::self .decode_threshold ] = \
42764279 self ._get_cp_local_seq_lens (
4277- torch .tensor (context_lens ),
4280+ torch .tensor (context_lens ) - decode_idx ,
42784281 self .pcp_size ,
42794282 self .dcp_size ,
42804283 self .parallel_config .cp_kv_cache_interleave_size ,
42814284 )
4285+ if self .decode_threshold > 1 :
4286+ num_computed_tokens_of_pcp_dcp_list = []
4287+ if num_decodes :
4288+ num_decodes_flatten = \
4289+ self .query_lens [:num_decodes ].sum ().item ()
4290+ if self .query_lens [:num_decodes ].min ().item (
4291+ ) == self .decode_threshold :
4292+ decode_flatten_idx = list (range (num_decodes_flatten ))
4293+ else :
4294+ decode_flatten_idx = []
4295+ for req_id in range (num_decodes ):
4296+ offset = (req_id + 1 ) * self .decode_threshold
4297+ decode_flatten_idx += \
4298+ list (range (offset - self .query_lens [req_id ], offset ))
4299+ num_computed_tokens_of_pcp_dcp_list .append (
4300+ num_computed_tokens_of_pcp_dcp [decode_flatten_idx ])
4301+ if num_prefills :
4302+ num_computed_tokens_of_pcp_dcp_list .append (
4303+ num_computed_tokens_of_pcp_dcp [
4304+ (num_decodes + 1 ) * self .decode_threshold -
4305+ 1 ::self .decode_threshold ])
4306+ num_computed_tokens_of_pcp_dcp = torch .cat (
4307+ num_computed_tokens_of_pcp_dcp_list , dim = 0 )
42824308 long_seq_metadata = AscendPrefillContextParallelMetadata (
42834309 num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded ,
42844310 num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp .