218 changes: 166 additions & 52 deletions vllm_ascend/sample/rejection_sampler.py
@@ -20,6 +20,16 @@
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32

vectorcore_num = None
device_properties = None

if HAS_TRITON:
from triton.runtime import driver # type: ignore
device_properties = driver.active.utils.get_device_properties(
torch.npu.current_device())
vectorcore_num = device_properties['num_vectorcore']
# Get the vector core count for later tiling.


class AscendRejectionSampler(RejectionSampler, nn.Module):
"""
@@ -218,15 +228,36 @@ def rejection_sample(
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
if HAS_TRITON:
rejection_greedy_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
)
vec_len = batch_size
n = cu_num_draft_tokens.numel()
BLOCK_SIZE = 2
grid = triton.cdiv(n, BLOCK_SIZE)
if n >= vectorcore_num:
grid = vectorcore_num # Empirically tuned value
BLOCK_SIZE = triton.next_power_of_2(n // grid)

if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_triton[(grid, )](
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
vec_len,
BLOCK_SIZE=BLOCK_SIZE,
)
else:
rejection_greedy_sample_triton[(grid, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
vec_len,
max_spec_len,
BLOCK_SIZE=BLOCK_SIZE,
)
else:
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
@@ -337,13 +368,23 @@ def expand_batch_to_tokens(
assert cu_num_tokens.shape[0] == batch_size
expanded_x = x.new_empty(num_tokens)
if HAS_TRITON:
expand_kernel[(batch_size, )](
vec_len = batch_size
n = cu_num_tokens.numel()
BLOCK_SIZE = 2
grid = triton.cdiv(n, BLOCK_SIZE)
if n >= vectorcore_num:
grid = vectorcore_num
BLOCK_SIZE = triton.next_power_of_2(n // grid)

expand_kernel[(grid, )](
expanded_x,
x,
cu_num_tokens,
replace_from,
replace_to,
vec_len,
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
BLOCK_SIZE=BLOCK_SIZE,
)
else:
expand_pytorch(
@@ -626,50 +667,115 @@ def sample_recovered_tokens_pytorch(


@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
def bonus_renew_1(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
):
bonus_token_id = tl.load(bonus_token_ids_ptr + position)
tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)


@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_spec_len_1_triton(
output_token_ids_ptr, # [batch_size, 2]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr,
vec_len,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < vec_len

draft_token_id = tl.load(draft_token_ids_ptr + offset, mask)
target_argmax_id = tl.load(target_argmax_ptr + offset, mask)
tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)

for pos in tl.range(0, BLOCK_SIZE):
draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
position = block_idx * BLOCK_SIZE + pos
if draft_token_id1 == target_argmax1:
bonus_renew_1(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
)


@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
max_spec_len,
num_tokens1,
):
bonus_token_id = tl.load(bonus_token_ids_ptr + position)
tl.store(
output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1,
bonus_token_id)


@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_triton(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
vec_len,
max_spec_len,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
# Because is_greedy_ptr is not None during the profiling run,
# re-compilation may happen at runtime when is_greedy_ptr is None.
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
req_idx)
if not is_greedy:
# Early exit for non-greedy sampling requests
return
block_idx = tl.program_id(0)
offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < vec_len

start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
if is_greedy_ptr is None:
is_greedy_mask = mask
else:
is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
is_greedy_mask = mask & (is_greedy != 0)

start_idx = tl.where(
offset == 0, 0,
tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
num_draft_tokens = end_idx - start_idx

rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
target_argmax_id,
)
if draft_token_id != target_argmax_id:
# Reject
rejected = True
for pos in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
if num_tokens1 != 0:
rejected = False
start_idx1 = tl.get_element(start_idx, (pos, ))
position = block_idx * BLOCK_SIZE + pos
for i in range(num_tokens1):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 +
i)
target_argmax_id = tl.load(target_argmax_ptr + start_idx1 +
i)
tl.store(
output_token_ids_ptr + position * (max_spec_len + 1) +
i,
target_argmax_id,
)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True

if not rejected:
# If all tokens are accepted, append the bonus token
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens,
bonus_token_id,
)
if not rejected:
bonus_renew(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
max_spec_len,
num_tokens1,
Comment on lines +771 to +777
Contributor:

Please add a comment explaining why you encapsulated this method.

Author:

For certain inputs, no request takes the 'if not rejected' branch, yet profiling after the earlier optimizations showed that this branch still accounted for a significant share of MTE3 transfer time. Removing the branch does not affect accuracy and improves performance by 40%. Because the branch may still be taken for other inputs, its steps were factored into a new operator that the main kernel calls when needed. Experiments show performance gains of varying degrees across different inputs. After this optimization, the attached figure shows, on the right, the internal calls from the 'rejection_greedy_sample_kernel' operator to the 'bonus_renew' operator on the left.

)


@triton.jit(do_not_specialize=["max_spec_len"])
@@ -739,22 +845,30 @@ def expand_kernel(
cu_num_tokens_ptr, # [batch_size]
replace_from,
replace_to,
vec_len,
MAX_NUM_TOKENS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
if req_idx == 0:
start_idx = 0
else:
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
len_mask = offset < vec_len

start_idx = tl.where(offset == 0, 0,
tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
num_tokens = end_idx - start_idx

src_val = tl.load(input_ptr + req_idx)
src_val = tl.load(input_ptr + offset, len_mask)
src_val = tl.where(src_val == replace_from, replace_to, src_val)
offset = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx + offset,
src_val,
mask=offset < num_tokens)

for i in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_tokens, (i, ))
start_idx1 = tl.get_element(start_idx, (i, ))
src_val1 = tl.get_element(src_val, (i, ))
offset1 = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx1 + offset1,
src_val1,
mask=offset1 < num_tokens1)
Comment on lines +864 to +871
Contributor:

critical

This loop for i in tl.range(0, BLOCK_SIZE) serializes the processing of elements within a thread block by using tl.get_element. This is a Triton anti-pattern that underutilizes the hardware's parallel processing capabilities. While vectorizing a scatter operation with variable-sized segments is non-trivial, the current implementation is highly inefficient. A more performant, vectorized approach should be used to avoid this serialization.
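One possible direction is sketched below. This is an assumption-laden illustration rather than a tested implementation: it presumes the Ascend Triton backend supports 2-D blocks, tl.broadcast_to, and masked loads with an "other" value, and the kernel name expand_kernel_vectorized is hypothetical. It replaces the serialized per-element loop with a single broadcasted 2-D store over each request's segment.

import triton
import triton.language as tl


@triton.jit
def expand_kernel_vectorized(  # hypothetical name, not part of this PR
    output_ptr,  # [num_tokens]
    input_ptr,  # [batch_size]
    cu_num_tokens_ptr,  # [batch_size]
    replace_from,
    replace_to,
    vec_len,
    MAX_NUM_TOKENS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    block_idx = tl.program_id(0)
    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    len_mask = offset < vec_len

    # Mask out offset == 0 instead of loading at index -1.
    prev_mask = len_mask & (offset > 0)
    start_idx = tl.load(cu_num_tokens_ptr + offset - 1, mask=prev_mask, other=0)
    end_idx = tl.load(cu_num_tokens_ptr + offset, mask=len_mask, other=0)
    num_tokens = end_idx - start_idx

    src_val = tl.load(input_ptr + offset, mask=len_mask, other=0)
    src_val = tl.where(src_val == replace_from, replace_to, src_val)

    # Broadcast each request's value over its segment and issue one 2-D store;
    # the mask keeps every lane inside its own [start_idx, end_idx) range.
    tok = tl.arange(0, MAX_NUM_TOKENS)
    out_off = start_idx[:, None] + tok[None, :]
    out_mask = len_mask[:, None] & (tok[None, :] < num_tokens[:, None])
    vals = tl.broadcast_to(src_val[:, None], (BLOCK_SIZE, MAX_NUM_TOKENS))
    tl.store(output_ptr + out_off, vals, mask=out_mask)

Whether this actually outperforms the loop on NPU would need the same profiling pass used elsewhere in this PR.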



@triton.jit