Skip to content

Commit b4cbe62

Browse files
committed
triton code remove autotune for rejection_sampler.py
Signed-off-by: yuxingcyx <[email protected]>
1 parent 1c70f5c commit b4cbe62

File tree

1 file changed

+156
-52
lines changed

1 file changed

+156
-52
lines changed

vllm_ascend/sample/rejection_sampler.py

Lines changed: 156 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -151,15 +151,36 @@ def rejection_sample(
151151
# Rejection sampling for greedy sampling requests.
152152
target_argmax = target_probs.argmax(dim=-1)
153153
if HAS_TRITON:
154-
rejection_greedy_sample_kernel[(batch_size, )](
155-
output_token_ids,
156-
cu_num_draft_tokens,
157-
draft_token_ids,
158-
target_argmax,
159-
bonus_token_ids,
160-
is_greedy,
161-
max_spec_len,
162-
)
154+
vec_len = batch_size
155+
n = cu_num_draft_tokens.numel()
156+
BLOCK_SIZE = 2
157+
grid = triton.cdiv(n, BLOCK_SIZE)
158+
if n >= 40:
159+
grid = 40
160+
BLOCK_SIZE = triton.next_power_of_2(n // grid)
161+
162+
if min(num_draft_tokens) == 1 and max(
163+
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
164+
rejection_greedy_sample_spec_len_1_triton[(grid, )](
165+
output_token_ids,
166+
draft_token_ids,
167+
target_argmax,
168+
bonus_token_ids,
169+
vec_len,
170+
BLOCK_SIZE = BLOCK_SIZE,
171+
)
172+
else:
173+
rejection_greedy_sample_triton[(grid, )](
174+
output_token_ids,
175+
cu_num_draft_tokens,
176+
draft_token_ids,
177+
target_argmax,
178+
bonus_token_ids,
179+
is_greedy,
180+
vec_len,
181+
max_spec_len,
182+
BLOCK_SIZE=BLOCK_SIZE,
183+
)
163184
else:
164185
if min(num_draft_tokens) == 1 and max(
165186
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
@@ -270,13 +291,23 @@ def expand_batch_to_tokens(
270291
assert cu_num_tokens.shape[0] == batch_size
271292
expanded_x = x.new_empty(num_tokens)
272293
if HAS_TRITON:
273-
expand_kernel[(batch_size, )](
294+
vec_len = batch_size
295+
n = cu_num_tokens.numel()
296+
BLOCK_SIZE = 2
297+
grid = triton.cdiv(n, BLOCK_SIZE)
298+
if n >= 40:
299+
grid = 40
300+
BLOCK_SIZE = triton.next_power_of_2(n // grid)
301+
302+
expand_kernel[(grid, )](
274303
expanded_x,
275304
x,
276305
cu_num_tokens,
277306
replace_from,
278307
replace_to,
308+
vec_len,
279309
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
310+
BLOCK_SIZE = BLOCK_SIZE,
280311
)
281312
else:
282313
expand_pytorch(
@@ -559,50 +590,115 @@ def sample_recovered_tokens_pytorch(
559590

560591

561592
@triton.jit(do_not_specialize=["max_spec_len"])
562-
def rejection_greedy_sample_kernel(
593+
def bonus_renew_1(
594+
bonus_token_ids_ptr,
595+
position,
596+
output_token_ids_ptr,
597+
):
598+
bonus_token_id = tl.load(bonus_token_ids_ptr + position)
599+
tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)
600+
601+
602+
@triton.jit(do_not_specialize=["max_spec_len"])
603+
def rejection_greedy_sample_spec_len_1_triton(
604+
output_token_ids_ptr, # [batch_size, 2]
605+
draft_token_ids_ptr, # [num_tokens]
606+
target_argmax_ptr, # [num_tokens]
607+
bonus_token_ids_ptr,
608+
vec_len,
609+
BLOCK_SIZE: tl.constexpr,
610+
):
611+
block_idx = tl.program_id(0)
612+
offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
613+
mask = offset < vec_len
614+
615+
draft_token_id = tl.load(draft_token_ids_ptr + offset, mask)
616+
target_argmax_id = tl.load(target_argmax_ptr + offset, mask)
617+
tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
618+
619+
for pos in tl.arange(0, BLOCK_SIZE):
620+
draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
621+
target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
622+
position = block_idx * BLOCK_SIZE + pos
623+
if draft_token_id1 == target_argmax1:
624+
bonus_renew_1(
625+
bonus_token_ids_ptr,
626+
position,
627+
output_token_ids_ptr,
628+
)
629+
630+
631+
@triton.jit(do_not_specialize=["max_spec_len"])
632+
def bonus_renew(
633+
bonus_token_ids_ptr,
634+
position,
635+
output_token_ids_ptr,
636+
max_spec_len,
637+
num_tokens1,
638+
):
639+
bonus_token_id = tl.load(bonus_token_ids_ptr + position)
640+
tl.store(
641+
output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1,
642+
bonus_token_id)
643+
644+
645+
@triton.jit(do_not_specialize=["max_spec_len"])
646+
def rejection_greedy_sample_triton(
563647
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
564648
cu_num_draft_tokens_ptr, # [batch_size]
565649
draft_token_ids_ptr, # [num_tokens]
566650
target_argmax_ptr, # [num_tokens]
567651
bonus_token_ids_ptr, # [batch_size]
568652
is_greedy_ptr, # [batch_size] or None
653+
vec_len,
569654
max_spec_len,
655+
BLOCK_SIZE: tl.constexpr,
570656
):
571-
req_idx = tl.program_id(0)
572-
# Because is_greedy_ptr is not None at profiling run,
573-
# re-compilation may happen during runtime when is_greedy_ptr is None.
574-
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
575-
req_idx)
576-
if not is_greedy:
577-
# Early exit for non-greedy sampling requests
578-
return
657+
block_idx = tl.program_id(0)
658+
offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
659+
mask = offset < vec_len
579660

580-
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
581-
req_idx - 1)
582-
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
661+
if is_greedy_ptr is None:
662+
is_greedy_mask = mask
663+
else:
664+
is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
665+
is_greedy_mask = mask & (is_greedy != 0)
666+
667+
start_idx = tl.where(
668+
offset == 0, 0,
669+
tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
670+
end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
583671
num_draft_tokens = end_idx - start_idx
584672

585-
rejected = False
586-
for pos in range(num_draft_tokens):
587-
if not rejected:
588-
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
589-
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
590-
tl.store(
591-
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
592-
target_argmax_id,
593-
)
594-
if draft_token_id != target_argmax_id:
595-
# Reject
596-
rejected = True
673+
for pos in tl.range(0, BLOCK_SIZE):
674+
num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
675+
if num_tokens1 != 0:
676+
rejected = False
677+
start_idx1 = tl.get_element(start_idx, (pos, ))
678+
position = block_idx * BLOCK_SIZE + pos
679+
for i in range(num_tokens1):
680+
if not rejected:
681+
draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 +
682+
i)
683+
target_argmax_id = tl.load(target_argmax_ptr + start_idx1 +
684+
i)
685+
tl.store(
686+
output_token_ids_ptr + position * (max_spec_len + 1) +
687+
i,
688+
target_argmax_id,
689+
)
690+
if draft_token_id != target_argmax_id:
691+
# Reject.
692+
rejected = True
597693

598-
if not rejected:
599-
# If all tokens are accepted, append the bonus token
600-
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
601-
tl.store(
602-
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
603-
num_draft_tokens,
604-
bonus_token_id,
605-
)
694+
if not rejected:
695+
bonus_renew(
696+
bonus_token_ids_ptr,
697+
position,
698+
output_token_ids_ptr,
699+
max_spec_len,
700+
num_tokens1,
701+
)
606702

607703

608704
@triton.jit(do_not_specialize=["max_spec_len"])
@@ -672,22 +768,30 @@ def expand_kernel(
672768
cu_num_tokens_ptr, # [batch_size]
673769
replace_from,
674770
replace_to,
771+
vec_len,
675772
MAX_NUM_TOKENS: tl.constexpr,
773+
BLOCK_SIZE: tl.constexpr,
676774
):
677775
req_idx = tl.program_id(0)
678-
if req_idx == 0:
679-
start_idx = 0
680-
else:
681-
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
682-
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
776+
offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
777+
len_mask = offset < vec_len
778+
779+
start_idx = tl.where(offset == 0, 0,
780+
tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
781+
end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
683782
num_tokens = end_idx - start_idx
684783

685-
src_val = tl.load(input_ptr + req_idx)
784+
src_val = tl.load(input_ptr + offset, len_mask)
686785
src_val = tl.where(src_val == replace_from, replace_to, src_val)
687-
offset = tl.arange(0, MAX_NUM_TOKENS)
688-
tl.store(output_ptr + start_idx + offset,
689-
src_val,
690-
mask=offset < num_tokens)
786+
787+
for i in tl.range(0, BLOCK_SIZE):
788+
num_tokens1 = tl.get_element(num_tokens, (i, ))
789+
start_idx1 = tl.get_element(start_idx, (i, ))
790+
src_val1 = tl.get_element(src_val, (i, ))
791+
offset1 = tl.arange(0, MAX_NUM_TOKENS)
792+
tl.store(output_ptr + start_idx1 + offset1,
793+
src_val1,
794+
mask=offset1 < num_tokens1)
691795

692796

693797
@triton.jit

0 commit comments

Comments
 (0)