Commit 20512f7

triton code vector core rejection_sampler.py

Signed-off-by: yuxingcyx <[email protected]>
1 parent eac72f5 commit 20512f7

File tree

1 file changed
+161 -120 lines changed

vllm_ascend/sample/rejection_sampler.py

Lines changed: 161 additions & 120 deletions
@@ -3,23 +3,24 @@

 import torch
 import torch.nn as nn
-import torch_npu
+import triton.runtime.driver as driver
 import vllm.v1.sample.rejection_sampler as rs
 from vllm.triton_utils import HAS_TRITON, tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.sample.rejection_sampler import (RejectionSampler,
+                                              apply_sampling_constraints,
                                               generate_uniform_probs)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

-from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
-
 PLACEHOLDER_TOKEN_ID = -1
 GREEDY_TEMPERATURE = -1
 # Maximum number of speculative draft tokens allowed per request in a single
 # step. This value is chosen to be large enough to handle typical use cases.
 MAX_SPEC_LEN = 32

+device_properties = driver.active.utils.get_device_properties(torch.npu.current_device())
+vectorcore_num = device_properties['num_vectorcore']
+# Get the vector core count, used below to tile the kernel launches.

 class AscendRejectionSampler(RejectionSampler, nn.Module):
     """
@@ -107,70 +108,6 @@ def forward(
         return output_token_ids


-def apply_sampling_constraints(
-    logits: torch.Tensor,  # [num_tokens, vocab_size]
-    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    """Process logits based on sampling metadata.
-
-    This function applies temperature scaling to the logits,
-    as well as top-k and top-p. For greedy decoding, it returns
-    the original logits.
-
-    Args:
-        logits: Input logits tensor to be processed.
-        cu_num_draft_tokens: Cumulative number of draft tokens.
-        sampling_metadata: Metadata containing sampling parameters such as
-            temperature and whether greedy sampling is used.
-
-    Returns:
-        torch.Tensor: Processed logits if non-greedy sampling is used,
-        otherwise returns the original logits.
-    """
-    assert logits.ndim == 2
-    assert cu_num_draft_tokens.ndim == 1
-    if sampling_metadata.all_greedy:
-        return logits
-
-    num_tokens = logits.shape[0]
-    temperature = expand_batch_to_tokens(
-        sampling_metadata.temperature,
-        cu_num_draft_tokens,
-        num_tokens,
-        replace_from=GREEDY_TEMPERATURE,
-        replace_to=1,
-    )
-    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
-    logits.div_(temperature.unsqueeze(-1))
-
-    # Get expanded top_k and top_p tensors.
-    top_k = None
-    if sampling_metadata.top_k is not None:
-        top_k = expand_batch_to_tokens(
-            sampling_metadata.top_k,
-            cu_num_draft_tokens,
-            num_tokens,
-        )
-    top_p = None
-    if sampling_metadata.top_p is not None:
-        top_p = expand_batch_to_tokens(
-            sampling_metadata.top_p,
-            cu_num_draft_tokens,
-            num_tokens,
-        )
-
-    if get_ascend_device_type(
-    ) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int(
-            top_k.max()) <= 1024:
-        return torch_npu.npu_top_k_top_p(logits, top_p.to(torch.bfloat16),
-                                         top_k)
-    else:
-        # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
-        # which is slow for large vocab sizes. This may cause performance issues.
-        return apply_top_k_top_p(logits, top_k, top_p)
-
-
 def rejection_sample(
     # [num_tokens]
     draft_token_ids: torch.Tensor,
@@ -218,15 +155,36 @@ def rejection_sample(
     # Rejection sampling for greedy sampling requests.
     target_argmax = target_probs.argmax(dim=-1)
     if HAS_TRITON:
-        rejection_greedy_sample_kernel[(batch_size, )](
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            target_argmax,
-            bonus_token_ids,
-            is_greedy,
-            max_spec_len,
-        )
+        vec_len = batch_size
+        n = cu_num_draft_tokens.numel()
+        BLOCK_SIZE = 2
+        grid = triton.cdiv(n, BLOCK_SIZE)
+        if n >= vectorcore_num:
+            grid = vectorcore_num  # Empirically tuned value
+            BLOCK_SIZE = triton.next_power_of_2(n // grid)
+
+        if min(num_draft_tokens) == 1 and max(
+                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+            rejection_greedy_sample_spec_len_1_triton[(grid, )](
+                output_token_ids,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                vec_len,
+                BLOCK_SIZE=BLOCK_SIZE,
+            )
+        else:
+            rejection_greedy_sample_triton[(grid, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                is_greedy,
+                vec_len,
+                max_spec_len,
+                BLOCK_SIZE=BLOCK_SIZE,
+            )
     else:
         if min(num_draft_tokens) == 1 and max(
                 num_draft_tokens) == 1 and sampling_metadata.all_greedy:
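
The grid/BLOCK_SIZE heuristic above reappears in `expand_batch_to_tokens` below; pulled out as a host-side helper it is easier to read and test (a sketch; the wrapper function is illustrative, the names are the commit's):

import triton

def launch_tiling(n: int, vectorcore_num: int) -> tuple[int, int]:
    # Small batches: a fixed block of 2 requests per program, grid sized to n.
    BLOCK_SIZE = 2
    grid = triton.cdiv(n, BLOCK_SIZE)
    if n >= vectorcore_num:
        # Large batches: pin one program to each vector core and grow the
        # block to a power of two so each program covers several requests.
        grid = vectorcore_num
        BLOCK_SIZE = triton.next_power_of_2(n // grid)
    return grid, BLOCK_SIZE

One caveat worth checking: `n // grid` floors, so for n just above the core count (e.g. n = 41 with 40 cores gives grid = 40, BLOCK_SIZE = 1) the launch covers only grid * BLOCK_SIZE = 40 of the 41 requests; `triton.next_power_of_2(triton.cdiv(n, grid))` would guarantee full coverage.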
@@ -337,13 +295,23 @@ def expand_batch_to_tokens(
     assert cu_num_tokens.shape[0] == batch_size
     expanded_x = x.new_empty(num_tokens)
     if HAS_TRITON:
-        expand_kernel[(batch_size, )](
+        vec_len = batch_size
+        n = cu_num_tokens.numel()
+        BLOCK_SIZE = 2
+        grid = triton.cdiv(n, BLOCK_SIZE)
+        if n >= vectorcore_num:
+            grid = vectorcore_num
+            BLOCK_SIZE = triton.next_power_of_2(n // grid)
+
+        expand_kernel[(grid, )](
             expanded_x,
             x,
             cu_num_tokens,
             replace_from,
             replace_to,
+            vec_len,
             MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+            BLOCK_SIZE=BLOCK_SIZE,
         )
     else:
         expand_pytorch(
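
For intuition, what `expand_kernel` computes has a compact PyTorch equivalent; a reference sketch of the semantics (not the shipped `expand_pytorch` fallback):

import torch

def expand_batch_reference(x: torch.Tensor, cu_num_tokens: torch.Tensor,
                           replace_from: float, replace_to: float) -> torch.Tensor:
    # Per-request token counts recovered from the cumulative sums.
    counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))
    # Map the sentinel (e.g. GREEDY_TEMPERATURE -> 1) before expanding.
    x = torch.where(x == replace_from, torch.full_like(x, replace_to), x)
    # Repeat each request's value once per draft token.
    return torch.repeat_interleave(x, counts)

For example, temperatures [1.5, -1.0] with cu_num_tokens [2, 5] and the greedy sentinel -1 mapped to 1 expand to [1.5, 1.5, 1.0, 1.0, 1.0].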
@@ -626,50 +594,115 @@ def sample_recovered_tokens_pytorch(


 @triton.jit(do_not_specialize=["max_spec_len"])
-def rejection_greedy_sample_kernel(
+def bonus_renew_1(
+    bonus_token_ids_ptr,
+    position,
+    output_token_ids_ptr,
+):
+    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
+    tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_spec_len_1_triton(
+    output_token_ids_ptr,  # [batch_size, 2]
+    draft_token_ids_ptr,  # [num_tokens]
+    target_argmax_ptr,  # [num_tokens]
+    bonus_token_ids_ptr,
+    vec_len,
+    BLOCK_SIZE: tl.constexpr,
+):
+    block_idx = tl.program_id(0)
+    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < vec_len
+
+    draft_token_id = tl.load(draft_token_ids_ptr + offset, mask)
+    target_argmax_id = tl.load(target_argmax_ptr + offset, mask)
+    tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
+
+    for pos in tl.range(0, BLOCK_SIZE):
+        draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
+        target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
+        position = block_idx * BLOCK_SIZE + pos
+        if draft_token_id1 == target_argmax1:
+            bonus_renew_1(
+                bonus_token_ids_ptr,
+                position,
+                output_token_ids_ptr,
+            )
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def bonus_renew(
+    bonus_token_ids_ptr,
+    position,
+    output_token_ids_ptr,
+    max_spec_len,
+    num_tokens1,
+):
+    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
+    tl.store(
+        output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1,
+        bonus_token_id)
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_triton(
     output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
     cu_num_draft_tokens_ptr,  # [batch_size]
     draft_token_ids_ptr,  # [num_tokens]
     target_argmax_ptr,  # [num_tokens]
     bonus_token_ids_ptr,  # [batch_size]
     is_greedy_ptr,  # [batch_size] or None
+    vec_len,
     max_spec_len,
+    BLOCK_SIZE: tl.constexpr,
 ):
-    req_idx = tl.program_id(0)
-    # Because is_greedy_ptr is not Nonr at profiling run,
-    # re-comilation may happen during runtime when is_greedy_ptr is None.
-    is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
-                                                           req_idx)
-    if not is_greedy:
-        # Early exit for non-greedy sampling requests
-        return
+    block_idx = tl.program_id(0)
+    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < vec_len

-    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
-                                               req_idx - 1)
-    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    if is_greedy_ptr is None:
+        is_greedy_mask = mask
+    else:
+        is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
+        is_greedy_mask = mask & (is_greedy != 0)
+
+    start_idx = tl.where(
+        offset == 0, 0,
+        tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
+    end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
     num_draft_tokens = end_idx - start_idx

-    rejected = False
-    for pos in range(num_draft_tokens):
-        if not rejected:
-            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
-            tl.store(
-                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
-                target_argmax_id,
-            )
-            if draft_token_id != target_argmax_id:
-                # Reject
-                rejected = True
+    for pos in tl.range(0, BLOCK_SIZE):
+        num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
+        if num_tokens1 != 0:
+            rejected = False
+            start_idx1 = tl.get_element(start_idx, (pos, ))
+            position = block_idx * BLOCK_SIZE + pos
+            for i in range(num_tokens1):
+                if not rejected:
+                    draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 +
+                                             i)
+                    target_argmax_id = tl.load(target_argmax_ptr + start_idx1 +
+                                               i)
+                    tl.store(
+                        output_token_ids_ptr + position * (max_spec_len + 1) +
+                        i,
+                        target_argmax_id,
+                    )
+                    if draft_token_id != target_argmax_id:
+                        # Reject.
+                        rejected = True

-    if not rejected:
-        # If all tokens are accepted, append the bonus token
-        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
-        tl.store(
-            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
-            num_draft_tokens,
-            bonus_token_id,
-        )
+            if not rejected:
+                bonus_renew(
+                    bonus_token_ids_ptr,
+                    position,
+                    output_token_ids_ptr,
+                    max_spec_len,
+                    num_tokens1,
+                )


 @triton.jit(do_not_specialize=["max_spec_len"])
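
Both kernels above implement the same greedy rule, split only on the all-spec-len-1 fast path. A plain PyTorch reference of that rule (a semantic sketch with an illustrative function name, not the shipped non-Triton fallback):

import torch

PLACEHOLDER_TOKEN_ID = -1

def greedy_rejection_reference(draft_token_ids, target_argmax, bonus_token_ids,
                               cu_num_draft_tokens, max_spec_len):
    batch_size = cu_num_draft_tokens.numel()
    out = torch.full((batch_size, max_spec_len + 1), PLACEHOLDER_TOKEN_ID,
                     dtype=torch.long)
    start = 0
    for req in range(batch_size):
        end = int(cu_num_draft_tokens[req])
        rejected = False
        for pos in range(end - start):
            # The target's argmax is written even at the rejecting position.
            out[req, pos] = target_argmax[start + pos]
            if draft_token_ids[start + pos] != target_argmax[start + pos]:
                rejected = True  # everything after this stays PLACEHOLDER
                break
        if not rejected:
            # All draft tokens accepted: append the bonus token.
            out[req, end - start] = bonus_token_ids[req]
        start = end
    return out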
@@ -739,22 +772,30 @@ def expand_kernel(
     cu_num_tokens_ptr,  # [batch_size]
     replace_from,
     replace_to,
+    vec_len,
     MAX_NUM_TOKENS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
-    if req_idx == 0:
-        start_idx = 0
-    else:
-        start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
-    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
+    offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    len_mask = offset < vec_len
+
+    start_idx = tl.where(offset == 0, 0,
+                         tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
+    end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
     num_tokens = end_idx - start_idx

-    src_val = tl.load(input_ptr + req_idx)
+    src_val = tl.load(input_ptr + offset, len_mask)
     src_val = tl.where(src_val == replace_from, replace_to, src_val)
-    offset = tl.arange(0, MAX_NUM_TOKENS)
-    tl.store(output_ptr + start_idx + offset,
-             src_val,
-             mask=offset < num_tokens)
+
+    for i in tl.range(0, BLOCK_SIZE):
+        num_tokens1 = tl.get_element(num_tokens, (i, ))
+        start_idx1 = tl.get_element(start_idx, (i, ))
+        src_val1 = tl.get_element(src_val, (i, ))
+        offset1 = tl.arange(0, MAX_NUM_TOKENS)
+        tl.store(output_ptr + start_idx1 + offset1,
+                 src_val1,
+                 mask=offset1 < num_tokens1)


 @triton.jit
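
Finally, the blocked `expand_kernel` can be emulated in plain Python to make the two masks explicit; `tl.get_element`, used above to pull one lane out of a block vector, becomes ordinary indexing (the emulation is illustrative only):

def expand_program(block_idx, BLOCK_SIZE, vec_len, x, cu_num_tokens,
                   replace_from, replace_to, out):
    # One grid program handles BLOCK_SIZE consecutive requests.
    for lane in range(BLOCK_SIZE):
        req = block_idx * BLOCK_SIZE + lane
        if req >= vec_len:  # the `offset < vec_len` mask
            continue
        start = 0 if req == 0 else cu_num_tokens[req - 1]
        end = cu_num_tokens[req]
        val = replace_to if x[req] == replace_from else x[req]
        for j in range(start, end):  # the `offset1 < num_tokens1` store mask
            out[j] = val

Every request is covered exactly when grid * BLOCK_SIZE >= vec_len, which ties back to the tiling caveat noted earlier.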
