99 apply_sampling_constraints ,
1010 generate_uniform_probs )
1111from vllm .v1 .spec_decode .metadata import SpecDecodeMetadata
12+ import triton
13+ import triton .language as tl
14+ import triton .runtime .driver as driver
1215
16+ import torch_npu ._inductor
1317PLACEHOLDER_TOKEN_ID = - 1
1418GREEDY_TEMPERATURE = - 1
1519# Maximum number of speculative draft tokens allowed per request in a single
@@ -102,6 +106,189 @@ def forward(
102106 )
103107 return output_token_ids
104108
109+ # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
110+ @triton .jit (do_not_specialize = ["max_spec_len" ])
111+ def rejection_greedy_sample_kernel (
112+ output_token_ids_ptr , # [batch_size, max_spec_len + 1]
113+ cu_num_draft_tokens_ptr , # [batch_size]
114+ draft_token_ids_ptr , # [num_tokens]
115+ target_argmax_ptr , # [num_tokens]
116+ bonus_token_ids_ptr , # [batch_size]
117+ is_greedy_ptr , # [batch_size] or None
118+ max_spec_len ,
119+ ):
120+ req_idx = tl .program_id (0 )
121+ # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
122+ # re-compilation may happen during runtime when is_greedy_ptr is None.
123+ is_greedy = True if is_greedy_ptr is None else tl .load (is_greedy_ptr + req_idx )
124+ if not is_greedy :
125+ # Early exit for non-greedy sampling requests.
126+ return
127+
128+ start_idx = 0 if req_idx == 0 else tl .load (cu_num_draft_tokens_ptr + req_idx - 1 )
129+ end_idx = tl .load (cu_num_draft_tokens_ptr + req_idx )
130+ num_draft_tokens = end_idx - start_idx
131+
132+ rejected = False
133+ for pos in range (num_draft_tokens ):
134+ if not rejected :
135+ draft_token_id = tl .load (draft_token_ids_ptr + start_idx + pos )
136+ target_argmax_id = tl .load (target_argmax_ptr + start_idx + pos )
137+ tl .store (
138+ output_token_ids_ptr + req_idx * (max_spec_len + 1 ) + pos ,
139+ target_argmax_id ,
140+ )
141+ if draft_token_id != target_argmax_id :
142+ # Reject.
143+ rejected = True
144+
145+ if not rejected :
146+ # If all tokens are accepted, append the bonus token.
147+ bonus_token_id = tl .load (bonus_token_ids_ptr + req_idx )
148+ tl .store (
149+ output_token_ids_ptr + req_idx * (max_spec_len + 1 ) + num_draft_tokens ,
150+ bonus_token_id ,
151+ )
152+
153+
154+ # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
155+ @triton .jit (do_not_specialize = ["max_spec_len" ])
156+ def rejection_random_sample_kernel (
157+ output_token_ids_ptr , # [batch_size, max_spec_len + 1]
158+ cu_num_draft_tokens_ptr , # [batch_size]
159+ draft_token_ids_ptr , # [num_tokens]
160+ draft_probs_ptr , # [num_tokens, vocab_size] or None
161+ target_probs_ptr , # [num_tokens, vocab_size]
162+ bonus_token_ids_ptr , # [batch_size]
163+ recovered_token_ids_ptr , # [num_tokens]
164+ uniform_probs_ptr , # [num_tokens]
165+ is_greedy_ptr , # [batch_size]
166+ max_spec_len ,
167+ vocab_size ,
168+ NO_DRAFT_PROBS : tl .constexpr ,
169+ ):
170+ req_idx = tl .program_id (0 )
171+ is_greedy = tl .load (is_greedy_ptr + req_idx )
172+ if is_greedy :
173+ # Early exit for greedy sampling requests.
174+ return
175+
176+ start_idx = 0 if req_idx == 0 else tl .load (cu_num_draft_tokens_ptr + req_idx - 1 )
177+ end_idx = tl .load (cu_num_draft_tokens_ptr + req_idx )
178+ num_draft_tokens = end_idx - start_idx
179+
180+ rejected = False
181+ for pos in range (num_draft_tokens ):
182+ if not rejected :
183+ draft_token_id = tl .load (draft_token_ids_ptr + start_idx + pos )
184+ if NO_DRAFT_PROBS :
185+ draft_prob = 1
186+ else :
187+ draft_prob = tl .load (
188+ draft_probs_ptr + (start_idx + pos ) * vocab_size + draft_token_id
189+ )
190+ target_prob = tl .load (
191+ target_probs_ptr + (start_idx + pos ) * vocab_size + draft_token_id
192+ )
193+ uniform_prob = tl .load (uniform_probs_ptr + start_idx + pos )
194+ # NOTE(woosuk): While the draft probability should never be 0,
195+ # we check it to avoid NaNs. If it happens to be 0, we reject.
196+ if draft_prob > 0 and target_prob / draft_prob >= uniform_prob :
197+ # Accept.
198+ token_id = draft_token_id
199+ else :
200+ # Reject. Use recovered token.
201+ rejected = True
202+ token_id = tl .load (recovered_token_ids_ptr + start_idx + pos )
203+ tl .store (
204+ output_token_ids_ptr + req_idx * (max_spec_len + 1 ) + pos , token_id
205+ )
206+
207+ if not rejected :
208+ # If all tokens are accepted, append the bonus token.
209+ bonus_token_id = tl .load (bonus_token_ids_ptr + req_idx )
210+ tl .store (
211+ output_token_ids_ptr + req_idx * (max_spec_len + 1 ) + num_draft_tokens ,
212+ bonus_token_id ,
213+ )
214+
215+
216+ # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
217+ @triton .jit (do_not_specialize = ["replace_from" , "replace_to" ])
218+ def expand_kernel (
219+ output_ptr , # [num_tokens]
220+ input_ptr , # [batch_size]
221+ cu_num_tokens_ptr , # [batch_size]
222+ replace_from ,
223+ replace_to ,
224+ MAX_NUM_TOKENS : tl .constexpr ,
225+ ):
226+ req_idx = tl .program_id (0 )
227+ if req_idx == 0 : # noqa: SIM108
228+ start_idx = 0
229+ else :
230+ start_idx = tl .load (cu_num_tokens_ptr + req_idx - 1 )
231+ end_idx = tl .load (cu_num_tokens_ptr + req_idx )
232+ num_tokens = end_idx - start_idx
233+
234+ src_val = tl .load (input_ptr + req_idx )
235+ src_val = tl .where (src_val == replace_from , replace_to , src_val )
236+ offset = tl .arange (0 , MAX_NUM_TOKENS )
237+ tl .store (output_ptr + start_idx + offset , src_val , mask = offset < num_tokens )
238+
239+
240+ @triton .jit
241+ def sample_recovered_tokens_kernel (
242+ output_token_ids_ptr , # [num_tokens]
243+ cu_num_draft_tokens_ptr , # [batch_size]
244+ draft_token_ids_ptr , # [num_tokens]
245+ draft_probs_ptr , # [num_tokens, vocab_size] or None
246+ target_probs_ptr , # [num_tokens, vocab_size]
247+ q_ptr , # [batch_size, vocab_size]
248+ vocab_size ,
249+ PADDED_VOCAB_SIZE : tl .constexpr ,
250+ NO_DRAFT_PROBS : tl .constexpr ,
251+ ):
252+ req_idx = tl .program_id (0 )
253+ start_idx = 0 if req_idx == 0 else tl .load (cu_num_draft_tokens_ptr + req_idx - 1 )
254+ end_idx = tl .load (cu_num_draft_tokens_ptr + req_idx )
255+ num_draft_tokens = end_idx - start_idx
256+
257+ # Early exit for out-of-range positions.
258+ pos = tl .program_id (1 )
259+ if pos >= num_draft_tokens :
260+ return
261+
262+ vocab_offset = tl .arange (0 , PADDED_VOCAB_SIZE )
263+ if NO_DRAFT_PROBS :
264+ draft_token_id = tl .load (draft_token_ids_ptr + start_idx + pos )
265+ prob = tl .load (
266+ target_probs_ptr + (start_idx + pos ) * vocab_size + vocab_offset ,
267+ mask = ((vocab_offset < vocab_size ) & (vocab_offset != draft_token_id )),
268+ other = 0 ,
269+ )
270+ else :
271+ draft_prob = tl .load (
272+ draft_probs_ptr + (start_idx + pos ) * vocab_size + vocab_offset ,
273+ mask = vocab_offset < vocab_size ,
274+ other = 0 ,
275+ )
276+ target_prob = tl .load (
277+ target_probs_ptr + (start_idx + pos ) * vocab_size + vocab_offset ,
278+ mask = vocab_offset < vocab_size ,
279+ other = 0 ,
280+ )
281+ prob = tl .maximum (target_prob - draft_prob , 0 )
282+ # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
283+ # `tl.argmax` will select the maximum value.
284+
285+ q = tl .load (
286+ q_ptr + req_idx * vocab_size + vocab_offset ,
287+ mask = vocab_offset < vocab_size ,
288+ other = float ("-inf" ),
289+ )
290+ recovered_id = tl .argmax (prob / q , axis = - 1 )
291+ tl .store (output_token_ids_ptr + start_idx + pos , recovered_id )
105292
106293def rejection_sample (
107294 # [num_tokens]
@@ -149,25 +336,15 @@ def rejection_sample(
149336 if not sampling_metadata .all_random :
150337 # Rejection sampling for greedy sampling requests.
151338 target_argmax = target_probs .argmax (dim = - 1 )
152- if min (num_draft_tokens ) == 1 and max (
153- num_draft_tokens ) == 1 and sampling_metadata .all_greedy :
154- rejection_greedy_sample_spec_len_1_pytorch (
155- output_token_ids ,
156- draft_token_ids ,
157- target_argmax ,
158- bonus_token_ids ,
159- )
160- else :
161- rejection_greedy_sample_pytorch (
162- output_token_ids ,
163- cu_num_draft_tokens ,
164- draft_token_ids ,
165- target_argmax ,
166- bonus_token_ids ,
167- num_draft_tokens ,
168- max_spec_len ,
169- is_greedy ,
170- )
339+ rejection_greedy_sample_kernel [(batch_size ,)](
340+ output_token_ids ,
341+ cu_num_draft_tokens ,
342+ draft_token_ids ,
343+ target_argmax ,
344+ bonus_token_ids ,
345+ is_greedy ,
346+ max_spec_len ,
347+ )
171348 if sampling_metadata .all_greedy :
172349 return output_token_ids
173350
@@ -194,7 +371,8 @@ def rejection_sample(
194371 )
195372
196373 # Rejection sampling for random sampling requests.
197- rejection_random_sample_pytorch (
374+ # Rejection sampling for random sampling requests.
375+ rejection_random_sample_kernel [(batch_size ,)](
198376 output_token_ids ,
199377 cu_num_draft_tokens ,
200378 draft_token_ids ,
@@ -206,8 +384,7 @@ def rejection_sample(
206384 is_greedy ,
207385 max_spec_len ,
208386 vocab_size ,
209- IS_NGRAM = draft_probs is None ,
210- # num_warps=1,
387+ NO_DRAFT_PROBS = draft_probs is None ,
211388 )
212389 return output_token_ids
213390
@@ -241,7 +418,7 @@ def expand_batch_to_tokens(
241418 batch_size = x .shape [0 ]
242419 assert cu_num_tokens .shape [0 ] == batch_size
243420 expanded_x = x .new_empty (num_tokens )
244- expand_pytorch (
421+ expand_kernel [( batch_size ,)] (
245422 expanded_x ,
246423 x ,
247424 cu_num_tokens ,
@@ -282,15 +459,16 @@ def sample_recovered_tokens(
282459 q [i ].exponential_ (generator = generator )
283460
284461 recovered_token_ids = torch .empty_like (draft_token_ids )
285- sample_recovered_tokens_pytorch (
462+ sample_recovered_tokens_kernel [( batch_size , max_spec_len )] (
286463 recovered_token_ids ,
287464 cu_num_draft_tokens ,
288465 draft_token_ids ,
289466 draft_probs ,
290467 target_probs ,
291468 q ,
292469 vocab_size ,
293- IS_NGRAM = draft_probs is None ,
470+ triton .next_power_of_2 (vocab_size ),
471+ NO_DRAFT_PROBS = draft_probs is None ,
294472 )
295473 return recovered_token_ids
296474
0 commit comments