 import torch
 import torch.nn as nn
 import vllm.v1.sample.rejection_sampler as rs
+from vllm.triton_utils import HAS_TRITON, tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import (RejectionSampler,
                                               apply_sampling_constraints,
@@ -149,25 +150,36 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
-        if min(num_draft_tokens) == 1 and max(
-                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
-            rejection_greedy_sample_spec_len_1_pytorch(
-                output_token_ids,
-                draft_token_ids,
-                target_argmax,
-                bonus_token_ids,
-            )
-        else:
-            rejection_greedy_sample_pytorch(
+        if HAS_TRITON:
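+            # One Triton program per request: each program verifies its own
+            # draft tokens sequentially and rejects at the first mismatch.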
+            rejection_greedy_sample_kernel[(batch_size, )](
                 output_token_ids,
                 cu_num_draft_tokens,
                 draft_token_ids,
                 target_argmax,
                 bonus_token_ids,
-                num_draft_tokens,
-                max_spec_len,
                 is_greedy,
+                max_spec_len,
             )
+        else:
+            if min(num_draft_tokens) == 1 and max(
+                    num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+                rejection_greedy_sample_spec_len_1_pytorch(
+                    output_token_ids,
+                    draft_token_ids,
+                    target_argmax,
+                    bonus_token_ids,
+                )
+            else:
+                rejection_greedy_sample_pytorch(
+                    output_token_ids,
+                    cu_num_draft_tokens,
+                    draft_token_ids,
+                    target_argmax,
+                    bonus_token_ids,
+                    num_draft_tokens,
+                    max_spec_len,
+                    is_greedy,
+                )
     if sampling_metadata.all_greedy:
         return output_token_ids

@@ -194,21 +206,37 @@ def rejection_sample(
     )

     # Rejection sampling for random sampling requests.
-    rejection_random_sample_pytorch(
-        output_token_ids,
-        cu_num_draft_tokens,
-        draft_token_ids,
-        draft_probs,
-        target_probs,
-        bonus_token_ids,
-        recovered_token_ids,
-        uniform_probs,
-        is_greedy,
-        max_spec_len,
-        vocab_size,
-        IS_NGRAM=draft_probs is None,
-        # num_warps=1,
-    )
+    if HAS_TRITON:
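+        # One program per request; greedy requests return early inside the
+        # kernel, so a single launch covers the whole batch.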
+        rejection_random_sample_kernel[(batch_size, )](
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            bonus_token_ids,
+            recovered_token_ids,
+            uniform_probs,
+            is_greedy,
+            max_spec_len,
+            vocab_size,
+            NO_DRAFT_PROBS=draft_probs is None,
+        )
+    else:
+        rejection_random_sample_pytorch(
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            bonus_token_ids,
+            recovered_token_ids,
+            uniform_probs,
+            is_greedy,
+            max_spec_len,
+            vocab_size,
+            IS_NGRAM=draft_probs is None,
+            # num_warps=1,
+        )
     return output_token_ids


@@ -241,14 +269,24 @@ def expand_batch_to_tokens(
     batch_size = x.shape[0]
     assert cu_num_tokens.shape[0] == batch_size
     expanded_x = x.new_empty(num_tokens)
-    expand_pytorch(
-        expanded_x,
-        x,
-        cu_num_tokens,
-        replace_from,
-        replace_to,
-        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
-    )
+    if HAS_TRITON:
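+        # One program per request; each program fills its slice of the
+        # expanded output.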
+        expand_kernel[(batch_size, )](
+            expanded_x,
+            x,
+            cu_num_tokens,
+            replace_from,
+            replace_to,
+            MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+        )
+    else:
+        expand_pytorch(
+            expanded_x,
+            x,
+            cu_num_tokens,
+            replace_from,
+            replace_to,
+            MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+        )
     return expanded_x


@@ -282,16 +320,29 @@ def sample_recovered_tokens(
         q[i].exponential_(generator=generator)

     recovered_token_ids = torch.empty_like(draft_token_ids)
-    sample_recovered_tokens_pytorch(
-        recovered_token_ids,
-        cu_num_draft_tokens,
-        draft_token_ids,
-        draft_probs,
-        target_probs,
-        q,
-        vocab_size,
-        IS_NGRAM=draft_probs is None,
-    )
+    if HAS_TRITON:
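+        # 2D grid: axis 0 is the request, axis 1 the draft-token position.
+        # tl.arange requires a power-of-2 extent, hence the padded vocab
+        # size passed alongside the real one.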
+        sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
+            recovered_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            q,
+            vocab_size,
+            triton.next_power_of_2(vocab_size),
+            NO_DRAFT_PROBS=draft_probs is None,
+        )
+    else:
+        sample_recovered_tokens_pytorch(
+            recovered_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            q,
+            vocab_size,
+            IS_NGRAM=draft_probs is None,
+        )
     return recovered_token_ids


@@ -504,4 +555,192 @@ def sample_recovered_tokens_pytorch(
             target_probs[token_idx, draft_token_id] = orig_prob


+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_kernel(
+    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    target_argmax_ptr,  # [num_tokens]
+    bonus_token_ids_ptr,  # [batch_size]
+    is_greedy_ptr,  # [batch_size] or None
+    max_spec_len,
+):
+    req_idx = tl.program_id(0)
+    # Because is_greedy_ptr is not None at the profiling run,
+    # re-compilation may happen during runtime when is_greedy_ptr is None.
+    is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
+                                                           req_idx)
+    if not is_greedy:
+        # Early exit for non-greedy sampling requests.
+        return
+
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                               req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    rejected = False
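+    # Greedy verification: a draft token is accepted iff it equals the
+    # target argmax at its position; the first mismatch rejects the rest.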
+    for pos in range(num_draft_tokens):
+        if not rejected:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                target_argmax_id,
+            )
+            if draft_token_id != target_argmax_id:
+                # Reject.
+                rejected = True
+
+    if not rejected:
+        # If all tokens are accepted, append the bonus token.
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+            num_draft_tokens,
+            bonus_token_id,
+        )
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_random_sample_kernel(
+    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    draft_probs_ptr,  # [num_tokens, vocab_size] or None
+    target_probs_ptr,  # [num_tokens, vocab_size]
+    bonus_token_ids_ptr,  # [batch_size]
+    recovered_token_ids_ptr,  # [num_tokens]
+    uniform_probs_ptr,  # [num_tokens]
+    is_greedy_ptr,  # [batch_size]
+    max_spec_len,
+    vocab_size,
+    NO_DRAFT_PROBS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    is_greedy = tl.load(is_greedy_ptr + req_idx)
+    if is_greedy:
+        # Early exit for greedy sampling requests.
+        return
+
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                               req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    rejected = False
+    for pos in range(num_draft_tokens):
+        if not rejected:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            if NO_DRAFT_PROBS:
+                draft_prob = 1
+            else:
+                draft_prob = tl.load(draft_probs_ptr +
+                                     (start_idx + pos) * vocab_size +
+                                     draft_token_id)
+            target_prob = tl.load(target_probs_ptr +
+                                  (start_idx + pos) * vocab_size +
+                                  draft_token_id)
+            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
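+            # Standard rejection-sampling accept rule: accept the draft
+            # token with probability min(1, target_prob / draft_prob).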
+            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
+                # Accept.
+                token_id = draft_token_id
+            else:
+                # Reject. Use the recovered token.
+                rejected = True
+                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
+            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                     pos, token_id)
+
+    if not rejected:
+        # If all tokens are accepted, append the bonus token.
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+            num_draft_tokens,
+            bonus_token_id,
+        )
+
+
+@triton.jit(do_not_specialize=["replace_from", "replace_to"])
+def expand_kernel(
+    output_ptr,  # [num_tokens]
+    input_ptr,  # [batch_size]
+    cu_num_tokens_ptr,  # [batch_size]
+    replace_from,
+    replace_to,
+    MAX_NUM_TOKENS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    if req_idx == 0:
+        start_idx = 0
+    else:
+        start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
+    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
+    num_tokens = end_idx - start_idx
+
+    src_val = tl.load(input_ptr + req_idx)
+    src_val = tl.where(src_val == replace_from, replace_to, src_val)
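+    # Broadcast the (possibly remapped) per-request value to all of its
+    # token positions with a single masked store.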
+    offset = tl.arange(0, MAX_NUM_TOKENS)
+    tl.store(output_ptr + start_idx + offset,
+             src_val,
+             mask=offset < num_tokens)
+
+
+@triton.jit
+def sample_recovered_tokens_kernel(
+    output_token_ids_ptr,  # [num_tokens]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    draft_probs_ptr,  # [num_tokens, vocab_size] or None
+    target_probs_ptr,  # [num_tokens, vocab_size]
+    q_ptr,  # [batch_size, vocab_size]
+    vocab_size,
+    PADDED_VOCAB_SIZE: tl.constexpr,
+    NO_DRAFT_PROBS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                               req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    # Early exit for out-of-range positions.
+    pos = tl.program_id(1)
+    if pos >= num_draft_tokens:
+        return
+
+    vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
+    if NO_DRAFT_PROBS:
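+        # Without draft probabilities (e.g. n-gram proposals), the draft
+        # distribution is treated as one-hot, so the residual keeps the
+        # target probabilities and zeroes out the draft token itself.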
+        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+        prob = tl.load(
+            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=((vocab_offset < vocab_size) &
+                  (vocab_offset != draft_token_id)),
+            other=0,
+        )
+    else:
+        draft_prob = tl.load(
+            draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=vocab_offset < vocab_size,
+            other=0,
+        )
+        target_prob = tl.load(
+            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=vocab_offset < vocab_size,
+            other=0,
+        )
+        prob = tl.maximum(target_prob - draft_prob, 0)
+    # We don't need `prob = prob / tl.sum(prob)` here because rescaling by a
+    # positive constant does not change the argmax taken below.
+
+    q = tl.load(
+        q_ptr + req_idx * vocab_size + vocab_offset,
+        mask=vocab_offset < vocab_size,
+        other=float("-inf"),
+    )
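+    # Exponential-race (Gumbel-max style) sampling: with q ~ Exp(1) i.i.d.,
+    # argmax(prob / q) is an exact draw from Categorical(prob).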
+    recovered_id = tl.argmax(prob / q, axis=-1)
+    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
+
+
 rs.expand_batch_to_tokens = expand_batch_to_tokens