Skip to content

Commit 353facf

Browse files
committed
Triton kernel vector-core tiling change in rejection_sampler.py
Signed-off-by: yuxingcyx <[email protected]>
1 parent eac72f5 commit 353facf

File tree

1 file changed

+167
-52
lines changed

1 file changed

+167
-52
lines changed

vllm_ascend/sample/rejection_sampler.py

Lines changed: 167 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@
2121
MAX_SPEC_LEN = 32
2222

2323

24+
vectorcore_num = None
25+
device_properties = None
26+
27+
28+
if HAS_TRITON:
29+
from triton.runtime import driver
30+
device_properties = driver.active.utils.get_device_properties(torch.npu.current_device())
31+
vectorcore_num = device_properties['num_vectorcore']
32+
#get vector core number in order for later tiling
33+
34+
2435
class AscendRejectionSampler(RejectionSampler, nn.Module):
2536
"""
2637
The implementation strictly follows the algorithm described in
@@ -218,15 +229,36 @@ def rejection_sample(
218229
# Rejection sampling for greedy sampling requests.
219230
target_argmax = target_probs.argmax(dim=-1)
220231
if HAS_TRITON:
221-
rejection_greedy_sample_kernel[(batch_size, )](
222-
output_token_ids,
223-
cu_num_draft_tokens,
224-
draft_token_ids,
225-
target_argmax,
226-
bonus_token_ids,
227-
is_greedy,
228-
max_spec_len,
229-
)
232+
vec_len = batch_size
233+
n = cu_num_draft_tokens.numel()
234+
BLOCK_SIZE = 2
235+
grid = triton.cdiv(n, BLOCK_SIZE)
236+
if n >= vectorcore_num:
237+
grid = vectorcore_num # Empirically tuned value
238+
BLOCK_SIZE = triton.next_power_of_2(n // grid)
239+
240+
if min(num_draft_tokens) == 1 and max(
241+
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
242+
rejection_greedy_sample_spec_len_1_triton[(grid,)](
243+
output_token_ids,
244+
draft_token_ids,
245+
target_argmax,
246+
bonus_token_ids,
247+
vec_len,
248+
BLOCK_SIZE=BLOCK_SIZE,
249+
)
250+
else:
251+
rejection_greedy_sample_triton[(grid,)](
252+
output_token_ids,
253+
cu_num_draft_tokens,
254+
draft_token_ids,
255+
target_argmax,
256+
bonus_token_ids,
257+
is_greedy,
258+
vec_len,
259+
max_spec_len,
260+
BLOCK_SIZE=BLOCK_SIZE,
261+
)
230262
else:
231263
if min(num_draft_tokens) == 1 and max(
232264
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
@@ -337,13 +369,23 @@ def expand_batch_to_tokens(
337369
assert cu_num_tokens.shape[0] == batch_size
338370
expanded_x = x.new_empty(num_tokens)
339371
if HAS_TRITON:
340-
expand_kernel[(batch_size, )](
372+
vec_len = batch_size
373+
n = cu_num_tokens.numel()
374+
BLOCK_SIZE = 2
375+
grid = triton.cdiv(n, BLOCK_SIZE)
376+
if n >= vectorcore_num:
377+
grid = vectorcore_num
378+
BLOCK_SIZE = triton.next_power_of_2(n // grid)
379+
380+
expand_kernel[(grid,)](
341381
expanded_x,
342382
x,
343383
cu_num_tokens,
344384
replace_from,
345385
replace_to,
386+
vec_len,
346387
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
388+
BLOCK_SIZE=BLOCK_SIZE,
347389
)
348390
else:
349391
expand_pytorch(
@@ -626,50 +668,115 @@ def sample_recovered_tokens_pytorch(
626668

627669

628670
@triton.jit
def bonus_renew_1(
    bonus_token_ids_ptr,  # [batch_size]
    position,  # scalar request index within the batch
    output_token_ids_ptr,  # [batch_size, 2], flattened
):
    """Write the bonus token into slot 1 of request `position`'s output row.

    Only used on the spec-len-1 path, where each output row holds exactly
    one draft slot and one bonus slot, hence the fixed row stride of 2.
    """
    # Fix: the previous decorator carried do_not_specialize=["max_spec_len"],
    # but this kernel has no max_spec_len parameter (copy-paste from the
    # general kernel) — dropped.
    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
    tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)
678+
679+
680+
@triton.jit
def rejection_greedy_sample_spec_len_1_triton(
    output_token_ids_ptr,  # [batch_size, 2]
    draft_token_ids_ptr,  # [num_tokens] == [batch_size] here
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    vec_len,  # batch_size
    BLOCK_SIZE: tl.constexpr,
):
    """Greedy rejection sampling specialized for exactly one draft token
    per request.  Each program processes BLOCK_SIZE consecutive requests.
    """
    block_idx = tl.program_id(0)
    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < vec_len

    draft_token_id = tl.load(draft_token_ids_ptr + offset, mask)
    target_argmax_id = tl.load(target_argmax_ptr + offset, mask)
    # The target argmax token always lands in slot 0 of the output row.
    tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)

    for pos in tl.range(0, BLOCK_SIZE):
        position = block_idx * BLOCK_SIZE + pos
        # Fix: guard the scalar tail.  Lanes with offset >= vec_len hold
        # unspecified values from the masked loads; if they happened to
        # compare equal, the unguarded loop stored a bonus token out of
        # bounds at position * 2 + 1.
        if position < vec_len:
            draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
            target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
            if draft_token_id1 == target_argmax1:
                # Draft accepted: append the bonus token in slot 1.
                bonus_renew_1(
                    bonus_token_ids_ptr,
                    position,
                    output_token_ids_ptr,
                )
707+
708+
709+
@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew(
    bonus_token_ids_ptr,  # [batch_size]
    position,  # scalar request index within the batch
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1], flattened
    max_spec_len,
    num_tokens1,  # number of accepted draft tokens for this request
):
    """Append the bonus token for request `position` immediately after its
    accepted draft tokens in the output row."""
    row_stride = max_spec_len + 1
    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
    tl.store(output_token_ids_ptr + position * row_stride + num_tokens1,
             bonus_token_id)
721+
722+
723+
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_triton(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    vec_len,  # batch_size
    max_spec_len,
    BLOCK_SIZE: tl.constexpr,
):
    """Greedy rejection sampling over BLOCK_SIZE requests per program.

    Accepts draft tokens while they match the target argmax, rejects the
    rest of the sequence on the first mismatch, and appends the bonus
    token when every draft token was accepted.
    """
    block_idx = tl.program_id(0)
    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < vec_len

    # is_greedy_ptr is not None at the profiling run, so recompilation may
    # happen at runtime when it becomes None (constexpr specialization).
    if is_greedy_ptr is None:
        is_greedy_mask = mask
    else:
        is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
        is_greedy_mask = mask & (is_greedy != 0)

    # Fix 1: exclude offset == 0 from the shifted load's mask — the old
    # tl.where(offset == 0, ...) discarded the value but still performed an
    # out-of-bounds read of cu_num_draft_tokens_ptr[-1] on lane 0.
    # Fix 2: other=0 so masked-out lanes (batch tail, non-greedy requests)
    # see num_draft_tokens == 0 and are skipped by the scalar loop below,
    # instead of operating on unspecified loaded values.
    start_idx = tl.load(cu_num_draft_tokens_ptr + offset - 1,
                        mask=is_greedy_mask & (offset > 0),
                        other=0)
    end_idx = tl.load(cu_num_draft_tokens_ptr + offset,
                      mask=is_greedy_mask,
                      other=0)
    num_draft_tokens = end_idx - start_idx

    for pos in tl.range(0, BLOCK_SIZE):
        num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
        # NOTE(review): a greedy request with zero draft tokens is skipped
        # here and receives no bonus token, unlike the non-Triton path —
        # confirm that num_draft_tokens >= 1 always holds for active rows.
        if num_tokens1 != 0:
            rejected = False
            start_idx1 = tl.get_element(start_idx, (pos, ))
            position = block_idx * BLOCK_SIZE + pos
            for i in range(num_tokens1):
                if not rejected:
                    draft_token_id = tl.load(draft_token_ids_ptr +
                                             start_idx1 + i)
                    target_argmax_id = tl.load(target_argmax_ptr +
                                               start_idx1 + i)
                    tl.store(
                        output_token_ids_ptr + position * (max_spec_len + 1) +
                        i,
                        target_argmax_id,
                    )
                    if draft_token_id != target_argmax_id:
                        # Reject: drop everything after the first mismatch.
                        rejected = True

            if not rejected:
                # All draft tokens accepted: append the bonus token.
                bonus_renew(
                    bonus_token_ids_ptr,
                    position,
                    output_token_ids_ptr,
                    max_spec_len,
                    num_tokens1,
                )
673780

674781

675782
@triton.jit(do_not_specialize=["max_spec_len"])
@@ -739,22 +846,30 @@ def expand_kernel(
739846
cu_num_tokens_ptr, # [batch_size]
740847
replace_from,
741848
replace_to,
849+
vec_len,
742850
MAX_NUM_TOKENS: tl.constexpr,
851+
BLOCK_SIZE: tl.constexpr,
743852
):
744853
req_idx = tl.program_id(0)
745-
if req_idx == 0:
746-
start_idx = 0
747-
else:
748-
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
749-
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
854+
offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
855+
len_mask = offset < vec_len
856+
857+
start_idx = tl.where(offset == 0, 0,
858+
tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
859+
end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
750860
num_tokens = end_idx - start_idx
751861

752-
src_val = tl.load(input_ptr + req_idx)
862+
src_val = tl.load(input_ptr + offset, len_mask)
753863
src_val = tl.where(src_val == replace_from, replace_to, src_val)
754-
offset = tl.arange(0, MAX_NUM_TOKENS)
755-
tl.store(output_ptr + start_idx + offset,
756-
src_val,
757-
mask=offset < num_tokens)
864+
865+
for i in tl.range(0, BLOCK_SIZE):
866+
num_tokens1 = tl.get_element(num_tokens, (i, ))
867+
start_idx1 = tl.get_element(start_idx, (i, ))
868+
src_val1 = tl.get_element(src_val, (i, ))
869+
offset1 = tl.arange(0, MAX_NUM_TOKENS)
870+
tl.store(output_ptr + start_idx1 + offset1,
871+
src_val1,
872+
mask=offset1 < num_tokens1)
758873

759874

760875
@triton.jit

0 commit comments

Comments
 (0)