Commit d610b09

swy20190, Lord_of_Ironhill, and whx-sjtu authored and committed
[Kernel] add triton kernels for sampling (vllm-project#4550)
### What this PR does / why we need it?

Replace the PyTorch implementation of sampling with Triton kernels.

### Does this PR introduce _any_ user-facing change?

No

- vLLM version: v0.11.2

---------

Signed-off-by: Lord_of_Ironhill <[email protected]>
Signed-off-by: whx-sjtu <[email protected]>
Co-authored-by: Lord_of_Ironhill <[email protected]>
Co-authored-by: whx-sjtu <[email protected]>
1 parent 8ad1e06 commit d610b09

File tree: 1 file changed, +284 −45 lines

vllm_ascend/sample/rejection_sampler.py

Lines changed: 284 additions & 45 deletions
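For context, these kernels implement the standard speculative-decoding rejection rule: a draft token is accepted when a uniform draw falls at or below the target/draft probability ratio, and on rejection a "recovered" token is sampled from the clamped residual distribution. A minimal PyTorch sketch of the per-token rule (illustrative names only, not this module's API):

```python
import torch

def accept_or_recover(draft_token_id: int, draft_probs: torch.Tensor,
                      target_probs: torch.Tensor, u: float) -> int:
    # Accept the draft token when the uniform draw u falls at or below the
    # target/draft probability ratio (the same test the Triton kernel uses).
    p_draft = draft_probs[draft_token_id]
    p_target = target_probs[draft_token_id]
    if p_draft > 0 and p_target / p_draft >= u:
        return draft_token_id
    # Reject: sample a recovered token from the residual distribution
    # max(target - draft, 0), renormalized.
    residual = torch.clamp(target_probs - draft_probs, min=0)
    return int(torch.multinomial(residual / residual.sum(), 1))
```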
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 import vllm.v1.sample.rejection_sampler as rs
+from vllm.triton_utils import HAS_TRITON, tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import (RejectionSampler,
                                               apply_sampling_constraints,
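`HAS_TRITON`, `tl`, and `triton` come from vLLM's `vllm.triton_utils` wrapper rather than from `triton` directly, so the module still imports on systems without Triton and can fall back to the PyTorch paths at runtime. A rough sketch of the gating idea (an assumption about the wrapper's behavior, not its actual code):

```python
# Assumed shape of the vllm.triton_utils gating; the real wrapper differs.
try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    triton = tl = None
    HAS_TRITON = False
```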
@@ -149,25 +150,36 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
-        if min(num_draft_tokens) == 1 and max(
-                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
-            rejection_greedy_sample_spec_len_1_pytorch(
-                output_token_ids,
-                draft_token_ids,
-                target_argmax,
-                bonus_token_ids,
-            )
-        else:
-            rejection_greedy_sample_pytorch(
+        if HAS_TRITON:
+            rejection_greedy_sample_kernel[(batch_size, )](
                 output_token_ids,
                 cu_num_draft_tokens,
                 draft_token_ids,
                 target_argmax,
                 bonus_token_ids,
-                num_draft_tokens,
-                max_spec_len,
                 is_greedy,
+                max_spec_len,
             )
+        else:
+            if min(num_draft_tokens) == 1 and max(
+                    num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+                rejection_greedy_sample_spec_len_1_pytorch(
+                    output_token_ids,
+                    draft_token_ids,
+                    target_argmax,
+                    bonus_token_ids,
+                )
+            else:
+                rejection_greedy_sample_pytorch(
+                    output_token_ids,
+                    cu_num_draft_tokens,
+                    draft_token_ids,
+                    target_argmax,
+                    bonus_token_ids,
+                    num_draft_tokens,
+                    max_spec_len,
+                    is_greedy,
+                )
     if sampling_metadata.all_greedy:
         return output_token_ids

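The `[(batch_size, )]` subscript is Triton's launch grid: one program instance per request, with `tl.program_id(0)` selecting the request inside the kernel, the same convention all four kernels added below follow. A minimal standalone sketch of the pattern, assuming stock `triton` on a CUDA device rather than the `vllm.triton_utils` wrappers used here:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_one_kernel(src_ptr, dst_ptr):
    req_idx = tl.program_id(0)  # one program instance per grid slot
    tl.store(dst_ptr + req_idx, tl.load(src_ptr + req_idx) + 1)

src = torch.arange(4, device="cuda", dtype=torch.int32)
dst = torch.empty_like(src)
add_one_kernel[(src.shape[0], )](src, dst)  # grid = (batch_size,)
```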
@@ -194,21 +206,37 @@ def rejection_sample(
     )

     # Rejection sampling for random sampling requests.
-    rejection_random_sample_pytorch(
-        output_token_ids,
-        cu_num_draft_tokens,
-        draft_token_ids,
-        draft_probs,
-        target_probs,
-        bonus_token_ids,
-        recovered_token_ids,
-        uniform_probs,
-        is_greedy,
-        max_spec_len,
-        vocab_size,
-        IS_NGRAM=draft_probs is None,
-        # num_warps=1,
-    )
+    if HAS_TRITON:
+        rejection_random_sample_kernel[(batch_size, )](
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            bonus_token_ids,
+            recovered_token_ids,
+            uniform_probs,
+            is_greedy,
+            max_spec_len,
+            vocab_size,
+            NO_DRAFT_PROBS=draft_probs is None,
+        )
+    else:
+        rejection_random_sample_pytorch(
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            bonus_token_ids,
+            recovered_token_ids,
+            uniform_probs,
+            is_greedy,
+            max_spec_len,
+            vocab_size,
+            IS_NGRAM=draft_probs is None,
+            # num_warps=1,
+        )
     return output_token_ids

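Note the constexpr flag is `NO_DRAFT_PROBS` in the Triton kernel versus `IS_NGRAM` in the PyTorch fallback: what it encodes is the absence of a draft distribution (as with ngram proposals), in which case the draft probability is taken to be 1. A small sketch of the acceptance test the kernel applies per draft token:

```python
def accepts(target_prob: float, uniform_prob: float,
            draft_prob: float = 1.0) -> bool:
    # Mirrors the kernel: with NO_DRAFT_PROBS, draft_prob defaults to 1 and
    # the test reduces to target_prob >= uniform_prob.
    return draft_prob > 0 and target_prob / draft_prob >= uniform_prob
```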
@@ -241,14 +269,24 @@ def expand_batch_to_tokens(
     batch_size = x.shape[0]
     assert cu_num_tokens.shape[0] == batch_size
     expanded_x = x.new_empty(num_tokens)
-    expand_pytorch(
-        expanded_x,
-        x,
-        cu_num_tokens,
-        replace_from,
-        replace_to,
-        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
-    )
+    if HAS_TRITON:
+        expand_kernel[(batch_size, )](
+            expanded_x,
+            x,
+            cu_num_tokens,
+            replace_from,
+            replace_to,
+            MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+        )
+    else:
+        expand_pytorch(
+            expanded_x,
+            x,
+            cu_num_tokens,
+            replace_from,
+            replace_to,
+            MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+        )
     return expanded_x

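Semantically, both branches expand each per-request value to that request's token positions, substituting `replace_to` for `replace_from` along the way. A compact PyTorch reference of the same semantics (a hypothetical helper, for illustration only):

```python
import torch

def expand_reference(x: torch.Tensor, cu_num_tokens: torch.Tensor,
                     replace_from: int, replace_to: int) -> torch.Tensor:
    # Per-request token counts, recovered from the cumulative sum.
    counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))
    x = torch.where(x == replace_from, torch.full_like(x, replace_to), x)
    return torch.repeat_interleave(x, counts)
```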
@@ -282,16 +320,29 @@ def sample_recovered_tokens(
         q[i].exponential_(generator=generator)

     recovered_token_ids = torch.empty_like(draft_token_ids)
-    sample_recovered_tokens_pytorch(
-        recovered_token_ids,
-        cu_num_draft_tokens,
-        draft_token_ids,
-        draft_probs,
-        target_probs,
-        q,
-        vocab_size,
-        IS_NGRAM=draft_probs is None,
-    )
+    if HAS_TRITON:
+        sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
+            recovered_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            q,
+            vocab_size,
+            triton.next_power_of_2(vocab_size),
+            NO_DRAFT_PROBS=draft_probs is None,
+        )
+    else:
+        sample_recovered_tokens_pytorch(
+            recovered_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            q,
+            vocab_size,
+            IS_NGRAM=draft_probs is None,
+        )
     return recovered_token_ids

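The Triton path samples the recovered token with the exponential-race (Gumbel-max style) trick: with `q ~ Exp(1)` filled in by the caller, `argmax(prob / q)` draws index `i` with probability `prob[i] / prob.sum()`, so the residual distribution never has to be normalized. A self-contained sketch of the identity:

```python
import torch

def sample_via_exponential_race(prob: torch.Tensor) -> torch.Tensor:
    # With i.i.d. q ~ Exp(1), argmax(prob / q) picks index i with
    # probability prob[i] / prob.sum(); normalization is unnecessary.
    q = torch.empty_like(prob).exponential_(1.0)
    return torch.argmax(prob / q, dim=-1)
```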
@@ -504,4 +555,192 @@ def sample_recovered_tokens_pytorch(
             target_probs[token_idx, draft_token_id] = orig_prob


+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_kernel(
+    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    target_argmax_ptr,  # [num_tokens]
+    bonus_token_ids_ptr,  # [batch_size]
+    is_greedy_ptr,  # [batch_size] or None
+    max_spec_len,
+):
+    req_idx = tl.program_id(0)
+    # Because is_greedy_ptr is not None at the profiling run,
+    # re-compilation may happen during runtime when is_greedy_ptr is None.
+    is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
+                                                           req_idx)
+    if not is_greedy:
+        # Early exit for non-greedy sampling requests
+        return
+
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                               req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    rejected = False
+    for pos in range(num_draft_tokens):
+        if not rejected:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                target_argmax_id,
+            )
+            if draft_token_id != target_argmax_id:
+                # Reject
+                rejected = True
+
+    if not rejected:
+        # If all tokens are accepted, append the bonus token
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+            num_draft_tokens,
+            bonus_token_id,
+        )
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_random_sample_kernel(
+    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    draft_probs_ptr,  # [num_tokens, vocab_size] or None
+    target_probs_ptr,  # [num_tokens, vocab_size]
+    bonus_token_ids_ptr,  # [batch_size]
+    recovered_token_ids_ptr,  # [num_tokens]
+    uniform_probs_ptr,  # [num_tokens]
+    is_greedy_ptr,  # [batch_size]
+    max_spec_len,
+    vocab_size,
+    NO_DRAFT_PROBS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    is_greedy = tl.load(is_greedy_ptr + req_idx)
+    if is_greedy:
+        # Early exit for greedy sampling requests
+        return
+
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                               req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    rejected = False
+    for pos in range(num_draft_tokens):
+        if not rejected:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            if NO_DRAFT_PROBS:
+                draft_prob = 1
+            else:
+                draft_prob = tl.load(draft_probs_ptr +
+                                     (start_idx + pos) * vocab_size +
+                                     draft_token_id)
+            target_prob = tl.load(target_probs_ptr +
+                                  (start_idx + pos) * vocab_size +
+                                  draft_token_id)
+            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
+                # Accept
+                token_id = draft_token_id
+            else:
+                # Reject. Use recovered token
+                rejected = True
+                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
+            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                     token_id)
+
+    if not rejected:
+        # If all tokens are accepted, append the bonus token
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+            num_draft_tokens,
+            bonus_token_id,
+        )
+
+
+@triton.jit(do_not_specialize=["replace_from", "replace_to"])
+def expand_kernel(
+    output_ptr,  # [num_tokens]
+    input_ptr,  # [batch_size]
+    cu_num_tokens_ptr,  # [batch_size]
+    replace_from,
+    replace_to,
+    MAX_NUM_TOKENS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    if req_idx == 0:
+        start_idx = 0
+    else:
+        start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
+    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
+    num_tokens = end_idx - start_idx
+
+    src_val = tl.load(input_ptr + req_idx)
+    src_val = tl.where(src_val == replace_from, replace_to, src_val)
+    offset = tl.arange(0, MAX_NUM_TOKENS)
+    tl.store(output_ptr + start_idx + offset,
+             src_val,
+             mask=offset < num_tokens)
+
+
+@triton.jit
+def sample_recovered_tokens_kernel(
+    output_token_ids_ptr,  # [num_tokens]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    draft_probs_ptr,  # [num_tokens, vocab_size] or None
+    target_probs_ptr,  # [num_tokens, vocab_size]
+    q_ptr,  # [batch_size, vocab_size]
+    vocab_size,
+    PADDED_VOCAB_SIZE: tl.constexpr,
+    NO_DRAFT_PROBS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                               req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    # Early exit for out-of-range positions
+    pos = tl.program_id(1)
+    if pos >= num_draft_tokens:
+        return
+
+    vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
+    if NO_DRAFT_PROBS:
+        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+        prob = tl.load(
+            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=((vocab_offset < vocab_size) &
+                  (vocab_offset != draft_token_id)),
+            other=0,
+        )
+    else:
+        draft_prob = tl.load(
+            draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=vocab_offset < vocab_size,
+            other=0,
+        )
+        target_prob = tl.load(
+            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=vocab_offset < vocab_size,
+            other=0,
+        )
+        prob = tl.maximum(target_prob - draft_prob, 0)
+        # We don't need `prob = prob / tl.sum(prob)` here because
+        # `tl.argmax` will select the maximum value.
+
+    q = tl.load(
+        q_ptr + req_idx * vocab_size + vocab_offset,
+        mask=vocab_offset < vocab_size,
+        other=float("-inf"),
+    )
+    recovered_id = tl.argmax(prob / q, axis=-1)
+    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
+
+
 rs.expand_batch_to_tokens = expand_batch_to_tokens
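The unchanged final line is the integration point: because the upstream module is imported as `rs`, the assignment `rs.expand_batch_to_tokens = expand_batch_to_tokens` patches vLLM's own function with this plugin's Triton-aware implementation at import time.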
