@@ -412,33 +412,21 @@ def _propose(
         block_table = block_table.cpu()
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
+        last_token_indices = cu_num_tokens[1:] - 1
         target_positions = target_positions.cpu()
         if self.name == SpecDcodeType.EAGLE3:
             assert isinstance(self.model, Eagle3LlamaForCausalLM)
             target_hidden_states = self.model.combine_hidden_states(
                 target_hidden_states)
             assert target_hidden_states.shape[-1] == self.hidden_size
 
-        first_token_indices = cu_num_tokens[:-1]
-        last_token_indices = cu_num_tokens[1:] - 1
-
         # Shift the input ids by one token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
         self.input_ids[:num_tokens - 1] = target_token_ids[1:]
         # Replace the last token with the next token.
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
         self.input_ids[last_token_indices] = next_token_ids
-        if self.runner.attn_state == AscendAttentionState.PrefillNoCache:
-            prefill_seq_lens = (target_positions[last_token_indices] + 1).int()
-            decode_seq_lens = prefill_seq_lens
-        elif self.runner.attn_state == AscendAttentionState.ChunkedPrefill:
-            prefill_seq_lens = (target_positions[last_token_indices] + 1).int()
-            decode_seq_lens = prefill_seq_lens
-        elif self.runner.attn_state == AscendAttentionState.DecodeOnly:
-            prefill_seq_lens = (target_positions[first_token_indices]).int()
-            decode_seq_lens = (target_positions[last_token_indices] + 1).int()
-        else:
-            raise NotImplementedError("This attention state is not implemented!")
+        seq_lens = (target_positions[last_token_indices] + 1).int()
 
         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
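
A minimal sketch of the indexing this hunk unifies, using made-up values for three requests with 1, 2, and 3 scheduled tokens (only `cu_num_tokens`, the shift, and the `seq_lens` formula come from the diff; all concrete numbers are illustrative). The single `seq_lens` is what feeds `seq_lens_cpu` in the next hunk, replacing the old per-attention-state branches:

```python
import torch

# Illustrative inputs: requests [a1], [b1, b2], [c1, c2, c3].
cu_num_tokens = torch.tensor([0, 1, 3, 6])
target_token_ids = torch.tensor([11, 21, 22, 31, 32, 33])  # a1 b1 b2 c1 c2 c3
next_token_ids = torch.tensor([12, 23, 34])                # a2 b3 c4
target_positions = torch.tensor([0, 0, 1, 0, 1, 2])

last_token_indices = cu_num_tokens[1:] - 1                 # tensor([0, 2, 5])

# Shift by one, then overwrite each request's stale last token with the
# freshly sampled one: [a1,b1,b2,c1,c2,c3] -> [a2,b2,b3,c2,c3,c4].
input_ids = target_token_ids.clone()
input_ids[:-1] = target_token_ids[1:]
input_ids[last_token_indices] = next_token_ids
assert input_ids.tolist() == [12, 22, 23, 32, 33, 34]

# One formula for every attention state: last position per request, plus one.
seq_lens = (target_positions[last_token_indices] + 1).int()
assert seq_lens.tolist() == [1, 2, 3]
```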
@@ -447,7 +435,7 @@ def _propose(
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=cu_num_tokens.to(device),
             query_start_loc_cpu=cu_num_tokens,
-            seq_lens_cpu=prefill_seq_lens.cpu(),
+            seq_lens_cpu=seq_lens.cpu(),
             max_query_len=max_query_len,
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,
@@ -517,11 +505,15 @@ def _propose(
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
+        attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
+            1:].tolist()
+        attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
+        attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
         query_lens.fill_(1)
         attn_metadata.query_lens = query_lens
 
-        attn_metadata.actual_seq_lengths_q = [1 for _ in attn_metadata.actual_seq_lengths_q]
-        attn_metadata.seq_lens_list = decode_seq_lens.tolist()
+        attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
+        attn_metadata.seq_lens_list = seq_lens.tolist()
         attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         for now_speculative in range(
                 self.vllm_config.speculative_config.num_speculative_tokens -
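
A sketch of the per-step metadata this hunk prepares before the speculative loop, for an illustrative `batch_size` of 3; a plain dict stands in for the real `attn_metadata` object, and only the values and formulas mirror the diff:

```python
batch_size = 3
arange = list(range(batch_size + 1))

attn_metadata = {
    "num_actual_tokens": batch_size,
    "max_query_len": 1,
    "query_start_loc": arange[:batch_size + 1],  # [0, 1, 2, 3]
    # Cumulative end offsets, i.e. query_start_loc without the leading zero:
    "query_start_loc_list": arange[1:],          # [1, 2, 3]
    # Each request contributes exactly one draft query token, and the whole
    # batch is booked on the prefill side of the decode/prefill split:
    "num_decodes": 0,
    "num_prefills": batch_size,
    "num_decode_tokens": 0,
    "num_prefill_tokens": batch_size,
}
attn_metadata["num_actual_tokens_pcp_padded"] = (
    attn_metadata["num_decode_tokens"] +
    attn_metadata["num_prefill_tokens"])         # 3

# With one token per request, the cumulative query lengths are simply
# 1, 2, ..., batch_size -- the same values as query_start_loc_list:
attn_metadata["actual_seq_lengths_q"] = [1 + i for i in range(batch_size)]
```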
@@ -548,7 +540,9 @@ def _propose(
             # TODO: Increment the sequence lengths.
 
             attn_metadata.seq_lens += 1
-            attn_metadata.seq_lens_list = [_ + 1 for _ in attn_metadata.seq_lens_list]
+            attn_metadata.seq_lens_list = [
+                _ + 1 for _ in attn_metadata.seq_lens_list
+            ]
             # TODO: Consider max model length.
             # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
             #     self.max_model_len)
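
The increment in isolation: after each extra draft step every request has one more token in its KV cache, so both the tensor and the list view of the sequence lengths advance together. A toy version with made-up lengths:

```python
import torch

seq_lens = torch.tensor([1, 2, 3], dtype=torch.int32)
seq_lens_list = seq_lens.tolist()

for _ in range(2):  # e.g. num_speculative_tokens - 1 extra draft steps
    seq_lens += 1                              # in-place tensor update
    seq_lens_list = [s + 1 for s in seq_lens_list]  # keep the list in step

assert seq_lens.tolist() == seq_lens_list == [3, 4, 5]
```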