14 changes: 14 additions & 0 deletions tests/ut/sample/test_rejection_sampler.py
@@ -26,9 +26,18 @@
GREEDY_TEMPERATURE = 0.0
MAX_SPEC_LEN = 8 # Used as MAX_NUM_TOKENS in expand_batch_to_tokens

original_tensor = torch.tensor
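# The real torch.tensor is saved so the mock below can delegate to it after
# dropping pin_memory=True (pinned host memory may not be available in the
# environment where these unit tests run).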


def mock_tensor_pin_memory(*args, **kwargs):
if kwargs.get('pin_memory', False):
kwargs['pin_memory'] = False
return original_tensor(*args, **kwargs)


class TestAscendRejectionSampler(TestBase):

@patch('torch.tensor', new=mock_tensor_pin_memory)
def test_rejection_greedy_sample_pytorch(self):
"""Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
batch_size = 2
@@ -60,6 +69,7 @@ def test_rejection_greedy_sample_pytorch(self):
assert output_token_ids[1, 0].item() == 20
assert output_token_ids[1, 2].item() == PLACEHOLDER_TOKEN_ID

@patch('torch.tensor', new=mock_tensor_pin_memory)
def test_rejection_random_sample_pytorch(self):
"""Test random rejection sampling: accept based on uniform probability"""
batch_size = 2
@@ -104,6 +114,7 @@ def test_rejection_random_sample_pytorch(self):
assert output_token_ids[0, 1].item() == 0
assert output_token_ids[0, 2].item() == 100

@patch('torch.tensor', new=mock_tensor_pin_memory)
def test_expand_pytorch(self):
"""Test expand_pytorch functionality"""
input_ptr = torch.tensor([10, 20, 30], dtype=torch.int32)
@@ -122,6 +133,7 @@ def test_expand_pytorch(self):
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
assert torch.equal(output_ptr, expected)

@patch('torch.tensor', new=mock_tensor_pin_memory)
def test_expand_batch_to_tokens(self):
"""Test expand_batch_to_tokens wrapper"""
x = torch.tensor([10, 20, 30])
@@ -141,6 +153,7 @@ def test_expand_batch_to_tokens(self):
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
assert torch.equal(result, expected)

@patch('torch.tensor', new=mock_tensor_pin_memory)
def test_sample_recovered_tokens_pytorch_ngram(self):
"""Test recovered token sampling under n-gram mode"""
output_token_ids = torch.empty(2, dtype=torch.int32)
@@ -171,6 +184,7 @@ def test_sample_recovered_tokens_pytorch_ngram(self):
assert output_token_ids[0].item() == 0
assert output_token_ids[1].item() == 1

@patch('torch.tensor', new=mock_tensor_pin_memory)
def test_sample_recovered_tokens_pytorch_autoregressive(self):
"""Test recovered token sampling for autoregressive models"""
output_token_ids = torch.empty(2, dtype=torch.int32)
259 changes: 178 additions & 81 deletions vllm_ascend/sample/rejection_sampler.py
@@ -268,31 +268,32 @@ def expand_batch_to_tokens(
def sample_recovered_tokens(
max_spec_len: int,
num_draft_tokens: list[int],
# [batch_size]
cu_num_draft_tokens: torch.Tensor,
# [num_tokens]
draft_token_ids: torch.Tensor,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
device: torch.device,
) -> torch.Tensor:
# NOTE(woosuk): Create only one distribution for each request.
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]

q = torch.empty(
(batch_size, vocab_size),
dtype=torch.float32,
device=device,
)
q.exponential_()

num_draft_tensor = torch.tensor(num_draft_tokens,
pin_memory=True).to(device,
non_blocking=True)
has_draft_mask = num_draft_tensor > 0
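# has_draft_mask marks requests that actually have draft tokens; the
# generator-seeded values drawn in the loop below are kept only for those rows of q.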

for i, generator in sampling_metadata.generators.items():
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if num_draft_tokens[i] > 0:
q[i].exponential_(generator=generator)
temp_q = torch.empty_like(q[i])
temp_q.exponential_(generator=generator)
q[i] = torch.where(has_draft_mask[i], temp_q, q[i])

recovered_token_ids = torch.empty_like(draft_token_ids)
sample_recovered_tokens_pytorch(
@@ -409,44 +410,104 @@ def rejection_random_sample_pytorch(
IS_NGRAM=False,
):
batch_size = output_token_ids.shape[0]
device = output_token_ids.device

for req_idx in range(batch_size):
if is_greedy[req_idx]:
continue
zero_cpu = torch.tensor([0], pin_memory=True)
zero_device = zero_cpu.to(device, non_blocking=True)

if req_idx == 0:
start_idx = 0
else:
start_idx = cu_num_draft_tokens[req_idx - 1].item()
end_idx = cu_num_draft_tokens[req_idx].item()
num_draft_tokens = end_idx - start_idx
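# Vectorized per-request offsets: the start/end of each request's slice in the
# flat token dimension, plus the number of draft tokens per request.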
cu_start = torch.cat([zero_device, cu_num_draft_tokens[:-1]])
cu_end = cu_num_draft_tokens
num_draft_per_batch = cu_end - cu_start

max_draft_len = max_spec_len
pos_indices_cpu = torch.arange(max_draft_len, pin_memory=True)
pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :]
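# valid_mask hides positions beyond each request's draft length; indices are
# clamped before gathering the draft token id for every (request, position) pair.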

valid_mask = pos_indices < num_draft_per_batch[:, None]
global_token_indices = cu_start[:, None] + pos_indices
global_token_indices = global_token_indices.clamp(
0, draft_token_ids.shape[0] - 1)
draft_tokens = draft_token_ids[
global_token_indices] # [batch_size, max_draft_len]

if IS_NGRAM:
ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
draft_token_probs = ones_cpu.to(
device, non_blocking=True).expand_as(draft_tokens)
else:
flat_indices = global_token_indices.flatten()
flat_draft_tokens = draft_tokens.flatten()
flat_draft_probs = draft_probs[flat_indices, flat_draft_tokens]
draft_token_probs = flat_draft_probs.view(batch_size, max_draft_len)
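# Gather the target-model probability assigned to each draft token at its
# position; the uniform samples and recovered tokens are gathered the same way.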

flat_indices = global_token_indices.flatten()
flat_draft_tokens = draft_tokens.flatten()
flat_target_probs = target_probs[flat_indices, flat_draft_tokens]
target_token_probs = flat_target_probs.view(batch_size, max_draft_len)

uniform_token_probs = uniform_probs[global_token_indices]
recovered_tokens = recovered_token_ids[global_token_indices]

zero_threshold_cpu = torch.tensor([0.0],
pin_memory=True,
dtype=torch.float32)
zero_threshold = zero_threshold_cpu.to(device, non_blocking=True)
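# Rejection test: a draft token is accepted when draft_prob > 0 and
# target_prob / draft_prob >= the uniform sample; the first rejection in a
# request discards all later draft tokens.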

acceptance_condition = (draft_token_probs > zero_threshold) & (
target_token_probs / draft_token_probs >= uniform_token_probs)

first_rejection = (~acceptance_condition) & valid_mask

Review comment: 400-500 lines of code without comments; it is recommended to add comments for the core functionality to improve readability.


default_pos_cpu = torch.full([batch_size, 1],
max_draft_len,
pin_memory=True)
default_pos = default_pos_cpu.to(device, non_blocking=True)

first_reject_pos = torch.where(
first_rejection.any(dim=1, keepdim=True),
first_rejection.float().argmax(dim=1, keepdim=True), default_pos)
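# Write accepted draft tokens back, substitute the recovered token at the first
# rejected position, and leave every position after it untouched.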
pos_mask = pos_indices >= first_reject_pos
should_skip = pos_mask & valid_mask

final_acceptance = acceptance_condition & (~should_skip)
non_greedy_mask = ~is_greedy
update_mask = non_greedy_mask[:, None] & valid_mask & (~should_skip)

first_reject_mask = (pos_indices == first_reject_pos
) & valid_mask & non_greedy_mask[:, None]
final_update_mask = update_mask | first_reject_mask
final_tokens = torch.where(
first_reject_mask, recovered_tokens,
torch.where(final_acceptance, draft_tokens,
output_token_ids[:, :max_draft_len]))

output_token_ids[:, :max_draft_len] = torch.where(
final_update_mask, final_tokens, output_token_ids[:, :max_draft_len])

no_rejection = first_reject_pos.squeeze(1) >= num_draft_per_batch
should_add_bonus = non_greedy_mask & no_rejection
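# If no draft token was rejected, the bonus token is written right after the
# request's last draft position.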

rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = draft_token_ids[start_idx + pos].item()
bonus_positions = num_draft_per_batch # [batch_size]

if IS_NGRAM:
draft_prob = 1.0
else:
draft_prob = draft_probs[start_idx + pos,
draft_token_id].item()
seq_len = output_token_ids.shape[1]
all_positions_cpu = torch.arange(seq_len, pin_memory=True)
all_positions = all_positions_cpu.to(
device, non_blocking=True)[None, :] # [1, seq_len]

target_prob = target_probs[start_idx + pos,
draft_token_id].item()
uniform_prob = uniform_probs[start_idx + pos].item()
batch_bonus_positions = bonus_positions[:, None] # [batch_size, 1]

if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
token_id = draft_token_id
else:
rejected = True
token_id = recovered_token_ids[start_idx + pos].item()
max_spec_len_cpu = torch.tensor([max_spec_len], pin_memory=True)
max_spec_len_device = max_spec_len_cpu.to(device, non_blocking=True)

output_token_ids[req_idx, pos] = token_id
valid_bonus_pos = bonus_positions < (max_spec_len_device + 1)
final_bonus_mask = should_add_bonus & valid_bonus_pos

if not rejected:
bonus_token_id = bonus_token_ids[req_idx].item()
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
bonus_pos_match = (all_positions == batch_bonus_positions)
bonus_pos_mask = bonus_pos_match & final_bonus_mask[:, None]

bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, seq_len)
output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded,
output_token_ids)


def expand_pytorch(
@@ -457,21 +518,35 @@ def expand_pytorch(
replace_to,
MAX_NUM_TOKENS,
):
batch_size = len(input_ptr)
# Vectorized implementation: replaces the per-request Python loop.
device = cu_num_tokens_ptr.device
batch_size = input_ptr.shape[0]
num_tokens = output_ptr.shape[0]

if batch_size == 0 or num_tokens == 0:
return

cu_start = torch.cat([
torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
cu_num_tokens_ptr[:-1]
])
cu_end = cu_num_tokens_ptr

for req_idx in range(batch_size):
start_idx = 0 if req_idx == 0 else cu_num_tokens_ptr[req_idx - 1]
end_idx = cu_num_tokens_ptr[req_idx]
num_tokens = end_idx - start_idx
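# For every output slot, find the request whose [cu_start, cu_end) range
# contains it, then write that request's (possibly replaced) input value.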
token_indices = torch.arange(num_tokens,
device=device)[:, None] # [num_tokens, 1]
cu_start_exp = cu_start[None, :] # [1, batch_size]
cu_end_exp = cu_end[None, :] # [1, batch_size]

src_val = input_ptr[req_idx]
src_val = replace_to if src_val == replace_from else src_val
in_range = (token_indices >= cu_start_exp) & (token_indices < cu_end_exp)

offset = torch.arange(MAX_NUM_TOKENS, device=num_tokens.device)
mask = offset < num_tokens
replaced_input = torch.where(input_ptr == replace_from, replace_to,
input_ptr).float()

output_slice = start_idx + offset[mask]
output_ptr[output_slice] = src_val
token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input)

needs_update = in_range.any(dim=1)

output_ptr[:] = torch.where(needs_update, token_values, output_ptr)


def sample_recovered_tokens_pytorch(
@@ -484,37 +559,59 @@
vocab_size,
IS_NGRAM=False,
):
batch_size = len(cu_num_draft_tokens)

for req_idx in range(batch_size):
start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1]
end_idx = cu_num_draft_tokens[req_idx]
num_draft_tokens = end_idx - start_idx

for pos in range(num_draft_tokens):
token_idx = start_idx + pos

if IS_NGRAM:
draft_token_id = draft_token_ids[token_idx]
orig_prob = target_probs[token_idx, draft_token_id].item()
target_probs[token_idx, draft_token_id] = 0
prob = target_probs[token_idx].clone()
else:
draft_p = draft_probs[token_idx].clone()
target_p = target_probs[token_idx].clone()
prob = torch.maximum(target_p - draft_p,
torch.tensor(0.0, device=target_p.device))

q_values = torch.full((vocab_size, ),
float('-inf'),
device=q.device)
q_values[:vocab_size] = q[req_idx, :vocab_size]

recovered_id = torch.argmax(prob / q_values).item()
output_token_ids[token_idx] = recovered_id

if IS_NGRAM:
target_probs[token_idx, draft_token_id] = orig_prob
device = output_token_ids.device
num_tokens = output_token_ids.shape[0]

if num_tokens == 0:
return

cu_start = torch.cat([
torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
cu_num_draft_tokens[:-1],
])
cu_end = cu_num_draft_tokens
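# Map each flat token index to the request that owns it by checking which
# [cu_start, cu_end) interval contains the index.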

token_indices = torch.arange(num_tokens, device=device) # [num_tokens]

token_indices_expanded = token_indices[:, None] # [num_tokens, 1]
cu_start_expanded = cu_start[None, :] # [1, batch_size]
cu_end_expanded = cu_end[None, :] # [1, batch_size]

in_range_mask = (token_indices_expanded >= cu_start_expanded) & (
token_indices_expanded < cu_end_expanded)

token_to_batch = torch.argmax(in_range_mask.int(), dim=1)

has_match = in_range_mask.any(dim=1)
token_to_batch = torch.where(has_match, token_to_batch, 0)
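# In n-gram mode the draft token's probability is zeroed in a copy of
# target_probs so the recovered token can never equal the draft token;
# otherwise the residual distribution max(target_probs - draft_probs, 0) is used.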

if IS_NGRAM:
token_indices = torch.arange(num_tokens, device=device)

modified_target_probs = target_probs.clone()
modified_target_probs[token_indices, draft_token_ids] = 0
prob = modified_target_probs

else:
prob = torch.maximum(
target_probs - draft_probs,
torch.tensor(0.0, pin_memory=True).to(device, non_blocking=True),
)

q_values = q[token_to_batch] # [num_tokens, vocab_size]
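# Guard against zero or infinite entries in q before dividing; argmax(prob / q)
# with q ~ Exp(1) then selects a token with probability proportional to prob.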

epsilon = 1e-10
q_values_safe = torch.where(q_values == 0, epsilon, q_values)
q_values_safe = torch.where(torch.isinf(q_values), epsilon, q_values_safe)

prob_over_q = prob / q_values_safe

prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10,
prob_over_q)

recovered_ids = torch.argmax(prob_over_q, dim=1)

output_token_ids[:] = recovered_ids


rs.expand_batch_to_tokens = expand_batch_to_tokens
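
For context on the recovered-token sampling above, here is a minimal, self-contained sketch (illustrative only, not part of this diff; the probability values are made up) of why argmax(prob / q) with q drawn from Exponential(1) selects an index with probability proportional to prob:

import torch

# Exponential-race sampling: with q_i ~ Exp(1), argmax(prob / q) returns index i
# with probability prob[i] / prob.sum().
prob = torch.tensor([0.1, 0.6, 0.3])
counts = torch.zeros(3)
for _ in range(10_000):
    q = torch.empty(3).exponential_()  # q_i ~ Exp(1)
    counts[torch.argmax(prob / q)] += 1
print(counts / counts.sum())  # approximately [0.10, 0.60, 0.30]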