 
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
-from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
-                                                AscendMetadata)
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 
@@ -68,9 +67,6 @@ def __init__(self,
                                           self.hidden_size),
                                          dtype=self.vllm_config.model_config.dtype,
                                          device=device)
-        self.max_num_tokens = (
-            vllm_config.scheduler_config.max_num_batched_tokens)
-        self.token_arange_np = np.arange(self.max_num_tokens)
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
         self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -189,8 +185,10 @@ def generate_token_ids(self,
             dtype=torch.int32,
             device=self.device,
         )
-        cu_num_tokens, token_indices = \
-            self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens)
+        num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
+        cu_num_tokens, token_indices = self._prepare_inputs(
+            eagle_attn_metadata.query_start_loc, num_rejected_tokens,
+            num_tokens)
         target_token_ids = self.runner.input_ids[token_indices]
         target_positions = positions[token_indices]
         if self.name == SpecDcodeType.EAGLE3:
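
A minimal sketch of the bookkeeping this hunk introduces, using toy values
(three requests with query lengths [2, 4, 3]; the names mirror the variables
in the hunk, but the numbers are made up for illustration):

```python
import torch

num_scheduled_tokens = 9                                   # 2 + 4 + 3
num_rejected_tokens = torch.tensor([0, 1, 1], dtype=torch.int32)

# As in the hunk: the drafter only sees tokens that survived rejection
# sampling, so num_tokens sizes the token_indices buffer.
num_tokens = num_scheduled_tokens - int(num_rejected_tokens.sum())
assert num_tokens == 7
```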
@@ -590,88 +588,60 @@ def _propose(
 
     def _prepare_inputs(
         self,
-        eagle_attn_metadata: AscendMetadata,
+        # [batch_size + 1]
+        cu_target_query_lens: torch.Tensor,
         # [batch_size]
         num_rejected_tokens: torch.Tensor,
+        num_tokens: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        This function is used to prepare the inputs for the spec decode.
-        It updates to the common_attn_metadata to account for the rejected
-        tokens (and newly sampled tokens). It also returns the token indices
-        of the tokens that should be fed to the speculator.
-        """
-        # E.g.
-        # common_attn_metadata.query_start_loc{_cpu}:
-        #   [0, q1, q1 + q2, q1 + q2 + q3]
-        # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
-        # num_rejected_tokens: [n1, n2, n3]
-        # This function computes the intermediate values:
-        # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
-        # And returns:
-        # common_attn_metadata.query_start_loc{_cpu}:
-        #   [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
-        # common_attn_metadata.seq_lens{_cpu}:
-        #   [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
-        # token_indices: [0, 1, ..., q1 - n1 - 1,
-        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
-        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
-        num_rejected_tokens_cpu = num_rejected_tokens.to("cpu")
-        cu_target_query_lens = eagle_attn_metadata.query_start_loc
-        device = eagle_attn_metadata.query_start_loc.device
-        query_start_loc_cpu = cu_target_query_lens.to("cpu")
-
-        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
-        new_query_len_per_req = (query_start_loc_cpu[1:] -
-                                 query_start_loc_cpu[:-1])
-        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
-        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens_cpu
-        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
-
-        # [q1 - n1, q2 - n2, q3 - n3] ->
-        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
-        new_query_start_loc_cpu = torch.zeros(
-            query_start_loc_cpu.shape,
-            dtype=torch.int32,
-            pin_memory=is_pin_memory_available())
-        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
-        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
-
-        total_num_tokens = new_query_start_loc_np[-1]
-        # Example assuming num_tokens_per_req_np = [2, 4, 3]
-        # this implies that `new_query_start_locs` is:
-        # [0, 2, 6, 9] ->
-        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
-        #  _r1_ ____r2____ ___r3__
-        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
-                                                  new_num_tokens_per_req_np)
-        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
-        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
-        #  _r1_ ____r2____ ___r3__
-        token_offests = self.token_arange_np[:total_num_tokens] \
-            - new_query_start_locs_expanded
-
-        # Expand starting positions to match token pattern
-        # [0, q1, q1 + q2] ->
-        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
-        #  _r1_ _____r2_______ ___________r3____________
-        old_query_start_locs_expanded = np.repeat(
-            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
-        # Final token indices are:
-        # [0, 1,                                   // req 1
-        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,         // req 2
-        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
-        token_indices_np = token_offests + old_query_start_locs_expanded
-        token_indices = torch.from_numpy(token_indices_np).to(
-            device, non_blocking=True)
-
-        # need use npu
+        # cu_target_query_lens: [0, a, a + b, a + b + c]
+        # num_rejected_tokens: [n1, n2, n3]
+        # num_tokens_per_req: [a - n1, b - n2, c - n3]
+        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
+        # token_indices: [0, 1, ..., a - n1 - 1,
+        #                 a, a + 1, ..., a + b - n2 - 1,
+        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
+
+        # [0, a, a + b, a + b + c] -> [a, b, c]
         query_len_per_req = (cu_target_query_lens[1:] -
                              cu_target_query_lens[:-1])
+        # [a, b, c] -> [a - n1, b - n2, c - n3]
         num_tokens_per_req = query_len_per_req - num_rejected_tokens
 
         # [a - n1, b - n2, c - n3] ->
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
         cu_num_tokens = torch.zeros_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-
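+        # Destination buffer for the indices of every accepted token.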
+        token_indices = torch.empty(
+            num_tokens,
+            dtype=torch.int32,
+            device=cu_target_query_lens.device,
+        )
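+        # The helper below fills one block of offsets per request, so it
+        # assumes no single request keeps more than BLOCK_SIZE tokens.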
+        BLOCK_SIZE = 1024
+        self._prepare_eagle_input_sequential(
+            token_indices,
+            cu_target_query_lens,
+            cu_num_tokens,
+            block_size=BLOCK_SIZE,
+        )
         return cu_num_tokens, token_indices
+
+    def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor,
+                                        cu_query_lens: torch.Tensor,
+                                        cu_num_tokens: torch.Tensor,
+                                        block_size: int):
+        device = cu_query_lens.device
+        dtype = out_tensor.dtype
+
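+        # Candidate offsets [0, block_size); each request consumes a masked
+        # prefix of this range.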
+        offsets = torch.arange(block_size, device=device, dtype=dtype)
+        start_pos = cu_num_tokens[:-1]
+        end_pos = cu_num_tokens[1:]
+        num_tokens = end_pos - start_pos
+
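+        # Row i holds request i's destination slots in out_tensor
+        # (global_indices) and the matching positions in the target query
+        # layout (values).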
+        global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1))
+        values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1))
+
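+        # Keep only the first num_tokens[i] offsets of each row.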
+        mask = (offsets.view(1, -1) < num_tokens.view(-1, 1))
+
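+        # Flatten the surviving entries and write them with one
+        # advanced-indexing assignment.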
+        global_indices_flat = global_indices[mask]
+        values_flat = values[mask]
+        out_tensor[global_indices_flat] = values_flat
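
For reference, a self-contained sketch of the masked block fill that the new
`_prepare_eagle_input_sequential` performs, run on CPU with toy inputs (all
values are illustrative; `block_size` stands in for BLOCK_SIZE):

```python
import torch

# Three requests: query lens [2, 4, 3], rejected tokens [0, 1, 1].
cu_query_lens = torch.tensor([0, 2, 6, 9])     # [0, a, a + b, a + b + c]
cu_num_tokens = torch.tensor([0, 2, 5, 7])     # cumsum of accepted counts
block_size = 4
num_tokens = int(cu_num_tokens[-1])

out = torch.empty(num_tokens, dtype=torch.int64)
offsets = torch.arange(block_size)
per_req = cu_num_tokens[1:] - cu_num_tokens[:-1]              # [2, 3, 2]
dest = cu_num_tokens[:-1].view(-1, 1) + offsets.view(1, -1)   # slots in out
src = cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1)    # target positions
mask = offsets.view(1, -1) < per_req.view(-1, 1)              # per-row prefix

out[dest[mask]] = src[mask]
print(out)  # tensor([0, 1, 2, 3, 4, 6, 7])
```

Request 2 keeps indices [2, 3, 4] (dropping its rejected token at 5) and
request 3 starts at a + b = 6, matching the token_indices pattern documented
in the comments of `_prepare_inputs`.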