228 changes: 203 additions & 25 deletions vllm_ascend/sample/rejection_sampler.py
@@ -9,7 +9,11 @@
apply_sampling_constraints,
generate_uniform_probs)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
import triton
import triton.language as tl
import triton.runtime.driver as driver

import torch_npu._inductor
PLACEHOLDER_TOKEN_ID = -1
GREEDY_TEMPERATURE = -1
# Maximum number of speculative draft tokens allowed per request in a single
@@ -102,6 +106,189 @@ def forward(
)
return output_token_ids

# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
max_spec_len,
):
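# One program per request: copy the target argmax token at each draft
# position until the first mismatch with the draft token; if every draft
# token is accepted, also emit the bonus token.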
req_idx = tl.program_id(0)
# FIXME(woosuk): Because is_greedy_ptr is not None during the profiling run,
# recompilation may happen at runtime when is_greedy_ptr is None.
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx)
if not is_greedy:
# Early exit for non-greedy sampling requests.
return

start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx

rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
target_argmax_id,
)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True

if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
bonus_token_id,
)


# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
NO_DRAFT_PROBS: tl.constexpr,
):
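# One program per request: accept each draft token with probability
# min(1, target_prob / draft_prob) using the precomputed uniform sample;
# on the first rejection, substitute the recovered token and stop writing
# further positions. If every draft token is accepted, emit the bonus token.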
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
# Early exit for greedy sampling requests.
return

start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx

rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
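# With no draft distribution (e.g. n-gram proposals), treat the draft
# probability as 1, so the acceptance test reduces to
# `target_prob >= uniform_prob`.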
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
# Accept.
token_id = draft_token_id
else:
# Reject. Use recovered token.
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id
)

if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
bonus_token_id,
)


# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
output_ptr, # [num_tokens]
input_ptr, # [batch_size]
cu_num_tokens_ptr, # [batch_size]
replace_from,
replace_to,
MAX_NUM_TOKENS: tl.constexpr,
):
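# One program per request: broadcast the request's single input value to all
# of its token slots, mapping `replace_from` to `replace_to` on the way.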
req_idx = tl.program_id(0)
if req_idx == 0: # noqa: SIM108
start_idx = 0
else:
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
num_tokens = end_idx - start_idx

src_val = tl.load(input_ptr + req_idx)
src_val = tl.where(src_val == replace_from, replace_to, src_val)
offset = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens)


@triton.jit
def sample_recovered_tokens_kernel(
output_token_ids_ptr, # [num_tokens]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
q_ptr, # [batch_size, vocab_size]
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
):
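# One program per (request, draft position): sample a replacement token from
# the residual distribution max(target_prob - draft_prob, 0), or, when draft
# probabilities are unavailable, from the target distribution with the draft
# token masked out.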
req_idx = tl.program_id(0)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx

# Early exit for out-of-range positions.
pos = tl.program_id(1)
if pos >= num_draft_tokens:
return

vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)),
other=0,
)
else:
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
prob = tl.maximum(target_prob - draft_prob, 0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
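# `q` holds Exponential(1) noise drawn on the host; argmax(prob / q) selects
# a token index with probability proportional to `prob` (the exponential-race
# form of the Gumbel-max trick), so no explicit normalization or cumulative
# sampling is required.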

q = tl.load(
q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"),
)
recovered_id = tl.argmax(prob / q, axis=-1)
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)

def rejection_sample(
# [num_tokens]
@@ -149,25 +336,15 @@ def rejection_sample(
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
)
else:
rejection_greedy_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
num_draft_tokens,
max_spec_len,
is_greedy,
)
rejection_greedy_sample_kernel[(batch_size,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
)
if sampling_metadata.all_greedy:
return output_token_ids

@@ -194,7 +371,8 @@ def rejection_sample(
)

# Rejection sampling for random sampling requests.
rejection_random_sample_pytorch(
# Rejection sampling for random sampling requests.
rejection_random_sample_kernel[(batch_size,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -206,8 +384,7 @@ def rejection_sample(
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
# num_warps=1,
NO_DRAFT_PROBS=draft_probs is None,
)
return output_token_ids

@@ -241,7 +418,7 @@ def expand_batch_to_tokens(
batch_size = x.shape[0]
assert cu_num_tokens.shape[0] == batch_size
expanded_x = x.new_empty(num_tokens)
expand_pytorch(
expand_kernel[(batch_size,)](
expanded_x,
x,
cu_num_tokens,
@@ -282,15 +459,16 @@ def sample_recovered_tokens(
q[i].exponential_(generator=generator)

recovered_token_ids = torch.empty_like(draft_token_ids)
sample_recovered_tokens_pytorch(
sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
recovered_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
q,
vocab_size,
IS_NGRAM=draft_probs is None,
triton.next_power_of_2(vocab_size),
NO_DRAFT_PROBS=draft_probs is None,
)
return recovered_token_ids
