Commit 752062d

[Performance] Improve the inference performance of Eagle3.

vLLM version: v0.11.0
vLLM main: vllm-project/vllm
Signed-off-by: liumail202512 <[email protected]>

1 parent 1b137d6 commit 752062d
File tree

1 file changed (+203 -25 lines changed)
vllm_ascend/sample/rejection_sampler.py

Lines changed: 203 additions & 25 deletions
@@ -9,7 +9,11 @@
                                               apply_sampling_constraints,
                                               generate_uniform_probs)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+import triton
+import triton.language as tl
+import triton.runtime.driver as driver
 
+import torch_npu._inductor
 PLACEHOLDER_TOKEN_ID = -1
 GREEDY_TEMPERATURE = -1
 # Maximum number of speculative draft tokens allowed per request in a single
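
The new imports bring in the Triton frontend (triton, triton.language) and, through torch_npu._inductor, the hook that lets @triton.jit kernels compile for Ascend NPUs. A minimal availability probe, assuming only the modules named in the diff (the helper name is hypothetical, not part of the commit):

    def triton_on_npu_available() -> bool:
        # True when Triton and the Ascend inductor hook both import, i.e.
        # when the @triton.jit kernels below can actually be compiled.
        try:
            import triton  # noqa: F401
            import triton.language as tl  # noqa: F401
            import torch_npu._inductor  # noqa: F401  # registers the NPU backend
        except ImportError:
            return False
        return True
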
@@ -102,6 +106,189 @@ def forward(
         )
         return output_token_ids
 
+# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_kernel(
+    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    target_argmax_ptr,  # [num_tokens]
+    bonus_token_ids_ptr,  # [batch_size]
+    is_greedy_ptr,  # [batch_size] or None
+    max_spec_len,
+):
+    req_idx = tl.program_id(0)
+    # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
+    # re-compilation may happen during runtime when is_greedy_ptr is None.
+    is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx)
+    if not is_greedy:
+        # Early exit for non-greedy sampling requests.
+        return
+
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    rejected = False
+    for pos in range(num_draft_tokens):
+        if not rejected:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                target_argmax_id,
+            )
+            if draft_token_id != target_argmax_id:
+                # Reject.
+                rejected = True
+
+    if not rejected:
+        # If all tokens are accepted, append the bonus token.
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
+            bonus_token_id,
+        )
+
+
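
In plain terms, the greedy kernel writes the target model's argmax at each position, accepts draft tokens while they match it, stops at the first mismatch, and appends the bonus token only when every draft token was accepted. A rough single-request PyTorch equivalent (hypothetical helper, not part of the commit; output_row is one row of output_token_ids):

    def greedy_rejection_reference(draft_ids, target_argmax, bonus_id, output_row):
        # Walk the draft window; the target's token is always written out.
        for pos in range(draft_ids.numel()):
            output_row[pos] = target_argmax[pos]
            if draft_ids[pos] != target_argmax[pos]:
                return  # first mismatch rejects the rest of the window
        # Every draft token matched: append the bonus token.
        output_row[draft_ids.numel()] = bonus_id
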
+# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_random_sample_kernel(
+    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    draft_probs_ptr,  # [num_tokens, vocab_size] or None
+    target_probs_ptr,  # [num_tokens, vocab_size]
+    bonus_token_ids_ptr,  # [batch_size]
+    recovered_token_ids_ptr,  # [num_tokens]
+    uniform_probs_ptr,  # [num_tokens]
+    is_greedy_ptr,  # [batch_size]
+    max_spec_len,
+    vocab_size,
+    NO_DRAFT_PROBS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    is_greedy = tl.load(is_greedy_ptr + req_idx)
+    if is_greedy:
+        # Early exit for greedy sampling requests.
+        return
+
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    rejected = False
+    for pos in range(num_draft_tokens):
+        if not rejected:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            if NO_DRAFT_PROBS:
+                draft_prob = 1
+            else:
+                draft_prob = tl.load(
+                    draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
+                )
+            target_prob = tl.load(
+                target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
+            )
+            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+            # NOTE(woosuk): While the draft probability should never be 0,
+            # we check it to avoid NaNs. If it happens to be 0, we reject.
+            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
+                # Accept.
+                token_id = draft_token_id
+            else:
+                # Reject. Use recovered token.
+                rejected = True
+                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id
+            )
+
+    if not rejected:
+        # If all tokens are accepted, append the bonus token.
+        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+        tl.store(
+            output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
+            bonus_token_id,
+        )
+
+
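
The random path applies the standard speculative-decoding test: a draft token is accepted when u <= p_target(token) / p_draft(token) for u ~ Uniform(0, 1), with p_draft taken as 1 when no draft distribution exists; on rejection, the precomputed recovered token is written and the rest of the window is dropped. A per-token PyTorch sketch of the decision (hypothetical helper; scalar tensors assumed):

    import torch

    def accepts(target_prob: torch.Tensor, draft_prob: torch.Tensor,
                uniform: torch.Tensor) -> bool:
        # Accept with probability min(1, target_prob / draft_prob); the
        # draft_prob > 0 guard mirrors the kernel's NaN check.
        return bool((draft_prob > 0) & (target_prob / draft_prob >= uniform))
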
+# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
+@triton.jit(do_not_specialize=["replace_from", "replace_to"])
+def expand_kernel(
+    output_ptr,  # [num_tokens]
+    input_ptr,  # [batch_size]
+    cu_num_tokens_ptr,  # [batch_size]
+    replace_from,
+    replace_to,
+    MAX_NUM_TOKENS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    if req_idx == 0:  # noqa: SIM108
+        start_idx = 0
+    else:
+        start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
+    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
+    num_tokens = end_idx - start_idx
+
+    src_val = tl.load(input_ptr + req_idx)
+    src_val = tl.where(src_val == replace_from, replace_to, src_val)
+    offset = tl.arange(0, MAX_NUM_TOKENS)
+    tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens)
+
+
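
expand_kernel broadcasts one value per request across that request's span of token slots, optionally substituting replace_from with replace_to; MAX_NUM_TOKENS is a tl.constexpr because tl.arange requires compile-time bounds, and the mask clips the vectorized store to the request's true length. A dense PyTorch reference for the whole batch (hypothetical helper, not part of the commit):

    import torch

    def expand_reference(x: torch.Tensor, cu_num_tokens: torch.Tensor,
                         replace_from: int, replace_to: int) -> torch.Tensor:
        # Recover per-request counts from the cumulative sums, then repeat
        # each request's (possibly substituted) value across its token span.
        counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))
        x = torch.where(x == replace_from, torch.full_like(x, replace_to), x)
        return torch.repeat_interleave(x, counts)
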
+@triton.jit
+def sample_recovered_tokens_kernel(
+    output_token_ids_ptr,  # [num_tokens]
+    cu_num_draft_tokens_ptr,  # [batch_size]
+    draft_token_ids_ptr,  # [num_tokens]
+    draft_probs_ptr,  # [num_tokens, vocab_size] or None
+    target_probs_ptr,  # [num_tokens, vocab_size]
+    q_ptr,  # [batch_size, vocab_size]
+    vocab_size,
+    PADDED_VOCAB_SIZE: tl.constexpr,
+    NO_DRAFT_PROBS: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
+    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    num_draft_tokens = end_idx - start_idx
+
+    # Early exit for out-of-range positions.
+    pos = tl.program_id(1)
+    if pos >= num_draft_tokens:
+        return
+
+    vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
+    if NO_DRAFT_PROBS:
+        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+        prob = tl.load(
+            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)),
+            other=0,
+        )
+    else:
+        draft_prob = tl.load(
+            draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=vocab_offset < vocab_size,
+            other=0,
+        )
+        target_prob = tl.load(
+            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
+            mask=vocab_offset < vocab_size,
+            other=0,
+        )
+        prob = tl.maximum(target_prob - draft_prob, 0)
+        # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
+        # `tl.argmax` will select the maximum value.
+
+    q = tl.load(
+        q_ptr + req_idx * vocab_size + vocab_offset,
+        mask=vocab_offset < vocab_size,
+        other=float("-inf"),
+    )
+    recovered_id = tl.argmax(prob / q, axis=-1)
+    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
 
 def rejection_sample(
     # [num_tokens]
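
sample_recovered_tokens_kernel never materializes a normalized residual distribution: with q_i ~ Exp(1), argmax_i(p_i / q_i) selects index i with probability p_i / sum(p), so the unnormalized max(p_target - p_draft, 0) residual (or the draft-token-masked target distribution in the NO_DRAFT_PROBS case) can be used directly. A one-position PyTorch sketch of the same trick (hypothetical helper):

    import torch

    def sample_unnormalized(prob: torch.Tensor, generator=None) -> int:
        # Exponential-race sampling: equivalent to multinomial sampling from
        # prob / prob.sum(), without an explicit normalization pass.
        q = torch.empty_like(prob).exponential_(generator=generator)
        return int(torch.argmax(prob / q))
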
@@ -149,25 +336,15 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
-        if min(num_draft_tokens) == 1 and max(
-                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
-            rejection_greedy_sample_spec_len_1_pytorch(
-                output_token_ids,
-                draft_token_ids,
-                target_argmax,
-                bonus_token_ids,
-            )
-        else:
-            rejection_greedy_sample_pytorch(
-                output_token_ids,
-                cu_num_draft_tokens,
-                draft_token_ids,
-                target_argmax,
-                bonus_token_ids,
-                num_draft_tokens,
-                max_spec_len,
-                is_greedy,
-            )
+        rejection_greedy_sample_kernel[(batch_size,)](
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            target_argmax,
+            bonus_token_ids,
+            is_greedy,
+            max_spec_len,
+        )
     if sampling_metadata.all_greedy:
         return output_token_ids
 
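
The bracketed tuple in rejection_greedy_sample_kernel[(batch_size,)](...) is Triton's launch grid: one program instance per request, each recovering its request index from tl.program_id(0). A toy, self-contained illustration of the convention (not from the commit; assumes an Ascend build of PyTorch so "npu" is a valid device):

    import torch
    import torch_npu  # noqa: F401  # assumption: Ascend build of PyTorch
    import triton
    import triton.language as tl

    @triton.jit
    def fill_kernel(out_ptr, value):
        # Grid axis 0 enumerates program instances, one per output slot.
        idx = tl.program_id(0)
        tl.store(out_ptr + idx, value)

    out = torch.empty(4, dtype=torch.float32, device="npu")
    fill_kernel[(out.numel(),)](out, 3.0)  # grid = (4,): four instances
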

@@ -194,7 +371,8 @@ def rejection_sample(
     )
 
     # Rejection sampling for random sampling requests.
-    rejection_random_sample_pytorch(
+    # Rejection sampling for random sampling requests.
+    rejection_random_sample_kernel[(batch_size,)](
         output_token_ids,
         cu_num_draft_tokens,
         draft_token_ids,
@@ -206,8 +384,7 @@
         is_greedy,
         max_spec_len,
         vocab_size,
-        IS_NGRAM=draft_probs is None,
-        # num_warps=1,
+        NO_DRAFT_PROBS=draft_probs is None,
     )
     return output_token_ids
 
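
Renaming IS_NGRAM to NO_DRAFT_PROBS also documents what the flag does mechanically: as a tl.constexpr it is baked in at compile time, so Triton emits one specialization per value and the `if NO_DRAFT_PROBS:` branch costs nothing at runtime. A toy sketch of constexpr specialization (not from the commit; same device assumption as above):

    import torch
    import torch_npu  # noqa: F401  # assumption: Ascend build of PyTorch
    import triton
    import triton.language as tl

    @triton.jit
    def scale_kernel(x_ptr, out_ptr, DOUBLE: tl.constexpr):
        # DOUBLE is resolved at compile time: each value of the flag gets
        # its own compiled variant with this branch folded away.
        idx = tl.program_id(0)
        v = tl.load(x_ptr + idx)
        if DOUBLE:
            v = v * 2
        tl.store(out_ptr + idx, v)

    x = torch.arange(4, dtype=torch.float32, device="npu")
    out = torch.empty_like(x)
    scale_kernel[(x.numel(),)](x, out, DOUBLE=True)   # one specialization
    scale_kernel[(x.numel(),)](x, out, DOUBLE=False)  # a second one
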

@@ -241,7 +418,7 @@ def expand_batch_to_tokens(
     batch_size = x.shape[0]
     assert cu_num_tokens.shape[0] == batch_size
     expanded_x = x.new_empty(num_tokens)
-    expand_pytorch(
+    expand_kernel[(batch_size,)](
         expanded_x,
         x,
         cu_num_tokens,
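
For a concrete feel of expand_batch_to_tokens semantics, two requests with 2 and 3 tokens expand as follows (values hypothetical, shown with the dense torch.repeat_interleave equivalent):

    import torch

    x = torch.tensor([7, 9])              # one value per request
    cu_num_tokens = torch.tensor([2, 5])  # cumulative token counts
    counts = torch.diff(cu_num_tokens, prepend=torch.tensor([0]))
    print(torch.repeat_interleave(x, counts))  # tensor([7, 7, 9, 9, 9])
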
@@ -282,15 +459,16 @@ def sample_recovered_tokens(
             q[i].exponential_(generator=generator)
 
     recovered_token_ids = torch.empty_like(draft_token_ids)
-    sample_recovered_tokens_pytorch(
+    sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
         recovered_token_ids,
         cu_num_draft_tokens,
         draft_token_ids,
         draft_probs,
         target_probs,
         q,
         vocab_size,
-        IS_NGRAM=draft_probs is None,
+        triton.next_power_of_2(vocab_size),
+        NO_DRAFT_PROBS=draft_probs is None,
     )
     return recovered_token_ids
 
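
The new positional argument triton.next_power_of_2(vocab_size) feeds PADDED_VOCAB_SIZE, which must be a power of two because tl.arange requires it; the padded lanes are harmless since the loads mask them to 0 (probabilities) or -inf (q), so they can never win the argmax. The rounding rule itself:

    import triton

    # e.g. a 151,936-token vocabulary pads to 2**18 lanes (the vocabulary
    # size here is an illustrative value, not one taken from the commit).
    assert triton.next_power_of_2(151_936) == 262_144
    assert triton.next_power_of_2(4_096) == 4_096
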
