@@ -151,15 +151,36 @@ def rejection_sample(
151151 # Rejection sampling for greedy sampling requests.
152152 target_argmax = target_probs .argmax (dim = - 1 )
153153 if HAS_TRITON :
154- rejection_greedy_sample_kernel [(batch_size , )](
155- output_token_ids ,
156- cu_num_draft_tokens ,
157- draft_token_ids ,
158- target_argmax ,
159- bonus_token_ids ,
160- is_greedy ,
161- max_spec_len ,
162- )
154+ vec_len = batch_size
155+ n = cu_num_draft_tokens .numel ()
156+ BLOCK_SIZE = 2
157+ grid = triton .cdiv (n , BLOCK_SIZE )
158+ if n >= 40 :
159+ grid = 40
160+ BLOCK_SIZE = triton .next_power_of_2 (n // grid )
161+
162+ if min (num_draft_tokens ) == 1 and max (
163+ num_draft_tokens ) == 1 and sampling_metadata .all_greedy :
164+ rejection_greedy_sample_spec_len_1_triton [(grid , )](
165+ output_token_ids ,
166+ draft_token_ids ,
167+ target_argmax ,
168+ bonus_token_ids ,
169+ vec_len ,
170+ BLOCK_SIZE = BLOCK_SIZE ,
171+ )
172+ else :
173+ rejection_greedy_sample_triton [(grid , )](
174+ output_token_ids ,
175+ cu_num_draft_tokens ,
176+ draft_token_ids ,
177+ target_argmax ,
178+ bonus_token_ids ,
179+ is_greedy ,
180+ vec_len ,
181+ max_spec_len ,
182+ BLOCK_SIZE = BLOCK_SIZE ,
183+ )
163184 else :
164185 if min (num_draft_tokens ) == 1 and max (
165186 num_draft_tokens ) == 1 and sampling_metadata .all_greedy :
@@ -270,13 +291,23 @@ def expand_batch_to_tokens(
270291 assert cu_num_tokens .shape [0 ] == batch_size
271292 expanded_x = x .new_empty (num_tokens )
272293 if HAS_TRITON :
273- expand_kernel [(batch_size , )](
294+ vec_len = batch_size
295+ n = cu_num_tokens .numel ()
296+ BLOCK_SIZE = 2
297+ grid = triton .cdiv (n , BLOCK_SIZE )
298+ if n >= 40 :
299+ grid = 40
300+ BLOCK_SIZE = triton .next_power_of_2 (n // grid )
301+
302+ expand_kernel [(grid , )](
274303 expanded_x ,
275304 x ,
276305 cu_num_tokens ,
277306 replace_from ,
278307 replace_to ,
308+ vec_len ,
279309 MAX_NUM_TOKENS = MAX_SPEC_LEN , # To avoid recompilation.
310+ BLOCK_SIZE = BLOCK_SIZE ,
280311 )
281312 else :
282313 expand_pytorch (
@@ -559,50 +590,115 @@ def sample_recovered_tokens_pytorch(
559590
560591
@triton.jit
def bonus_renew_1(
    bonus_token_ids_ptr,  # [batch_size]
    position,
    output_token_ids_ptr,  # [batch_size, 2]
):
    """Store the bonus token for request `position` (spec_len == 1 path).

    The output buffer has a row stride of 2 (one draft slot plus one
    bonus slot per request), so the bonus token lands in column 1.
    """
    # Fix: the original decorator passed do_not_specialize=["max_spec_len"],
    # but this helper has no such parameter; the option was a no-op copied
    # from a sibling kernel and is dropped here.
    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
    tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)
600+
601+
602+ @triton .jit (do_not_specialize = ["max_spec_len" ])
603+ def rejection_greedy_sample_spec_len_1_triton (
604+ output_token_ids_ptr , # [batch_size, 2]
605+ draft_token_ids_ptr , # [num_tokens]
606+ target_argmax_ptr , # [num_tokens]
607+ bonus_token_ids_ptr ,
608+ vec_len ,
609+ BLOCK_SIZE : tl .constexpr ,
610+ ):
611+ block_idx = tl .program_id (0 )
612+ offset = block_idx * BLOCK_SIZE + tl .arange (0 , BLOCK_SIZE )
613+ mask = offset < vec_len
614+
615+ draft_token_id = tl .load (draft_token_ids_ptr + offset , mask )
616+ target_argmax_id = tl .load (target_argmax_ptr + offset , mask )
617+ tl .store (output_token_ids_ptr + offset * 2 , target_argmax_id , mask )
618+
619+ for pos in tl .arange (0 , BLOCK_SIZE ):
620+ draft_token_id1 = tl .get_element (draft_token_id , (pos , ))
621+ target_argmax1 = tl .get_element (target_argmax_id , (pos , ))
622+ position = block_idx * BLOCK_SIZE + pos
623+ if draft_token_id1 == target_argmax1 :
624+ bonus_renew_1 (
625+ bonus_token_ids_ptr ,
626+ position ,
627+ output_token_ids_ptr ,
628+ )
629+
630+
@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew(
    bonus_token_ids_ptr,  # [batch_size]
    position,
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    max_spec_len,
    num_tokens1,
):
    """Append the bonus token for request `position` right after its
    `num_tokens1` accepted draft tokens."""
    # Each output row has stride (max_spec_len + 1).
    row_base = output_token_ids_ptr + position * (max_spec_len + 1)
    tl.store(row_base + num_tokens1, tl.load(bonus_token_ids_ptr + position))
643+
644+
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_triton(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    vec_len,
    max_spec_len,
    BLOCK_SIZE: tl.constexpr,
):
    """Greedy rejection sampling over a block of requests per program.

    For each greedy request, the accepted prefix of draft tokens is
    replaced by the target argmax tokens; on the first mismatch the rest
    are rejected. If every draft token is accepted, the bonus token is
    appended after them.
    """
    block_idx = tl.program_id(0)
    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < vec_len

    # is_greedy_ptr is None only during the profiling run; handling both
    # shapes here avoids a runtime recompilation.
    if is_greedy_ptr is None:
        is_greedy_mask = mask
    else:
        is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
        is_greedy_mask = mask & (is_greedy != 0)

    # Fix 1: tl.where does NOT suppress the load, so without the extra
    # (offset > 0) mask term lane 0 reads cu_num_draft_tokens_ptr - 1,
    # one element before the buffer.
    # Fix 2: other=0 forces masked-off lanes (non-greedy or out of range)
    # to num_draft_tokens == 0, so the scalar loop skips them — matching
    # the early return of the old per-request kernel. Without `other`
    # those lanes are undefined and could drive out-of-bounds writes.
    start_idx = tl.where(
        offset == 0, 0,
        tl.load(cu_num_draft_tokens_ptr + offset - 1,
                is_greedy_mask & (offset > 0),
                other=0))
    end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask,
                      other=0)
    num_draft_tokens = end_idx - start_idx

    # NOTE(review): a greedy request with zero draft tokens gets no bonus
    # token here, unlike the per-request kernel this replaced — confirm
    # callers guarantee num_draft_tokens >= 1 on this path.
    for pos in tl.range(0, BLOCK_SIZE):
        num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
        if num_tokens1 != 0:
            rejected = False
            start_idx1 = tl.get_element(start_idx, (pos, ))
            position = block_idx * BLOCK_SIZE + pos
            for i in range(num_tokens1):
                if not rejected:
                    draft_token_id = tl.load(draft_token_ids_ptr +
                                             start_idx1 + i)
                    target_argmax_id = tl.load(target_argmax_ptr +
                                               start_idx1 + i)
                    # Accepted positions emit the target argmax token.
                    tl.store(
                        output_token_ids_ptr + position * (max_spec_len + 1) +
                        i,
                        target_argmax_id,
                    )
                    if draft_token_id != target_argmax_id:
                        # Reject: every later draft token is dropped.
                        rejected = True

            if not rejected:
                # All draft tokens accepted: append the bonus token.
                bonus_renew(
                    bonus_token_ids_ptr,
                    position,
                    output_token_ids_ptr,
                    max_spec_len,
                    num_tokens1,
                )
606702
607703
608704@triton .jit (do_not_specialize = ["max_spec_len" ])
def expand_kernel(
    output_ptr,  # [num_tokens]
    input_ptr,  # [batch_size]
    cu_num_tokens_ptr,  # [batch_size]
    replace_from,
    replace_to,
    vec_len,
    MAX_NUM_TOKENS: tl.constexpr,  # To avoid recompilation.
    BLOCK_SIZE: tl.constexpr,
):
    """Expand one value per request into its run of output tokens.

    Each lane handles one request: its input value (with `replace_from`
    swapped for `replace_to`) is broadcast into
    output[start_idx : end_idx], where the bounds come from the
    cumulative token counts.
    """
    req_idx = tl.program_id(0)
    offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    len_mask = offset < vec_len

    # Fix 1: the (offset > 0) mask term stops lane 0 from loading
    # cu_num_tokens_ptr - 1 — tl.where selects after the load happens.
    # Fix 2: other=0 pins out-of-range tail lanes to num_tokens == 0 so
    # the scalar loop below stores nothing for them; without `other` the
    # masked loads are undefined and could trigger out-of-bounds stores.
    start_idx = tl.where(
        offset == 0, 0,
        tl.load(cu_num_tokens_ptr + offset - 1,
                len_mask & (offset > 0),
                other=0))
    end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask, other=0)
    num_tokens = end_idx - start_idx

    src_val = tl.load(input_ptr + offset, len_mask, other=0)
    src_val = tl.where(src_val == replace_from, replace_to, src_val)

    # Scalar loop: each lane's run has a data-dependent start and length.
    for i in tl.range(0, BLOCK_SIZE):
        num_tokens1 = tl.get_element(num_tokens, (i, ))
        start_idx1 = tl.get_element(start_idx, (i, ))
        src_val1 = tl.get_element(src_val, (i, ))
        offset1 = tl.arange(0, MAX_NUM_TOKENS)
        tl.store(output_ptr + start_idx1 + offset1,
                 src_val1,
                 mask=offset1 < num_tokens1)
692796
693797@triton .jit
0 commit comments