Commit 8bb0284

Fixed the performance degradation issue in post-processing in speculative decoding scenarios. (#4849)
### What this PR does / why we need it?

When speculative decoding is enabled and temperature > 0, `bonus_logits` and `target_logits` are sampled separately:

1. `bonus_logits` are sampled using the fused `torch_npu.npu_top_k_top_p` operator invoked inside the main sampler,
2. while `target_logits` are sampled within the rejection sampler using a less-optimized implementation composed of smaller operators.

Consequently, the `cumsum` operation in the top-p sampling for `target_logits` becomes especially time-consuming, leading to performance degradation.

![Profiling before the fix, showing the costly cumsum in target_logits top-p sampling](https://github.com/user-attachments/assets/1969f561-6aa5-41b3-9a87-1f64d4321cbf)

Applying the fused operator to the sampling of `target_logits` as well reduces this overhead.

![Profiling after applying the fused operator to target_logits](https://github.com/user-attachments/assets/1e6563da-3418-405d-b657-7bbe10dd0924)

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------

Signed-off-by: funanyang <[email protected]>
Co-authored-by: weijinqian0 <[email protected]>
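For context, the slow path filters each row with a sort followed by a cumulative sum over the sorted probabilities, and that `cumsum` is what dominates the profile above on a large vocabulary. The sketch below is loosely modeled on vLLM's `apply_top_k_top_p`; the function name and exact masking details are illustrative rather than the upstream implementation. The fused `torch_npu.npu_top_k_top_p` operator performs the same filtering in a single kernel.

```python
import torch


def top_k_top_p_reference(logits: torch.Tensor, k: torch.Tensor,
                          p: torch.Tensor) -> torch.Tensor:
    """Sort-based top-k/top-p masking with per-token k and p (illustrative)."""
    num_tokens, vocab_size = logits.shape

    # Ascending sort: the k largest logits of each row sit at the end.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=False)

    # Top-k: mask every position that is not among the last k of its row.
    positions = torch.arange(vocab_size, device=logits.device).expand(num_tokens, -1)
    top_k_mask = positions < (vocab_size - k.unsqueeze(1))
    sorted_logits.masked_fill_(top_k_mask, float("-inf"))

    # Top-p: drop the low-probability prefix whose mass stays below 1 - p.
    # The cumulative sum below is the expensive step on a large vocabulary.
    probs = sorted_logits.softmax(dim=-1)
    cum_probs = probs.cumsum(dim=-1)
    top_p_mask = cum_probs <= (1.0 - p.unsqueeze(1))
    top_p_mask[:, -1] = False  # always keep the most likely token
    sorted_logits.masked_fill_(top_p_mask, float("-inf"))

    # Scatter the filtered logits back to their original vocabulary order.
    return sorted_logits.scatter(dim=-1, index=sorted_idx, src=sorted_logits)


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(4, 32)                   # 4 tokens, toy vocab of 32
    k = torch.full((4,), 10, dtype=torch.int64)   # per-token top-k
    p = torch.full((4,), 0.9)                     # per-token top-p
    filtered = top_k_top_p_reference(logits, k, p)
    print((filtered > float("-inf")).sum(dim=-1))  # tokens surviving per row
    # The fused NPU path replaces all of the above with a single operator call,
    # e.g. torch_npu.npu_top_k_top_p(logits, p, k), which is what this PR
    # switches the target_logits path to.
```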
1 parent 5b179c5 commit 8bb0284

File tree: 1 file changed (+68, −1)

vllm_ascend/sample/rejection_sampler.py

Lines changed: 68 additions & 1 deletion
```diff
@@ -3,14 +3,17 @@
 
 import torch
 import torch.nn as nn
+import torch_npu
 import vllm.v1.sample.rejection_sampler as rs
 from vllm.triton_utils import HAS_TRITON, tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.sample.rejection_sampler import (RejectionSampler,
-                                              apply_sampling_constraints,
                                               generate_uniform_probs)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
+from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+
 PLACEHOLDER_TOKEN_ID = -1
 GREEDY_TEMPERATURE = -1
 # Maximum number of speculative draft tokens allowed per request in a single
@@ -104,6 +107,70 @@ def forward(
         return output_token_ids
 
 
+def apply_sampling_constraints(
+    logits: torch.Tensor,  # [num_tokens, vocab_size]
+    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    """Process logits based on sampling metadata.
+
+    This function applies temperature scaling to the logits,
+    as well as top-k and top-p. For greedy decoding, it returns
+    the original logits.
+
+    Args:
+        logits: Input logits tensor to be processed.
+        cu_num_draft_tokens: Cumulative number of draft tokens.
+        sampling_metadata: Metadata containing sampling parameters such as
+            temperature and whether greedy sampling is used.
+
+    Returns:
+        torch.Tensor: Processed logits if non-greedy sampling is used,
+            otherwise returns the original logits.
+    """
+    assert logits.ndim == 2
+    assert cu_num_draft_tokens.ndim == 1
+    if sampling_metadata.all_greedy:
+        return logits
+
+    num_tokens = logits.shape[0]
+    temperature = expand_batch_to_tokens(
+        sampling_metadata.temperature,
+        cu_num_draft_tokens,
+        num_tokens,
+        replace_from=GREEDY_TEMPERATURE,
+        replace_to=1,
+    )
+    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
+    logits.div_(temperature.unsqueeze(-1))
+
+    # Get expanded top_k and top_p tensors.
+    top_k = None
+    if sampling_metadata.top_k is not None:
+        top_k = expand_batch_to_tokens(
+            sampling_metadata.top_k,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+    top_p = None
+    if sampling_metadata.top_p is not None:
+        top_p = expand_batch_to_tokens(
+            sampling_metadata.top_p,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+
+    if get_ascend_device_type(
+    ) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int(
+            top_k.max()) <= 1024:
+        return torch_npu.npu_top_k_top_p(logits, top_p.to(torch.bfloat16),
+                                         top_k)
+    else:
+        # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
+        # which is slow for large vocab sizes. This may cause performance issues.
+        return apply_top_k_top_p(logits, top_k, top_p)
+
+
 def rejection_sample(
     # [num_tokens]
     draft_token_ids: torch.Tensor,
```
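As a side note, `expand_batch_to_tokens` (used above to broadcast `temperature`, `top_k`, and `top_p`) turns one value per request into one value per draft token using the cumulative draft-token counts. Below is a rough, purely illustrative PyTorch sketch of that expansion; the real helper in vLLM uses an optimized kernel and its exact signature may differ.

```python
from typing import Optional

import torch

GREEDY_TEMPERATURE = -1  # sentinel used for greedy requests (matches the diff above)


def expand_batch_to_tokens_sketch(values: torch.Tensor,
                                  cu_num_draft_tokens: torch.Tensor,
                                  num_tokens: int,
                                  replace_from: Optional[float] = None,
                                  replace_to: float = 1.0) -> torch.Tensor:
    """Repeat each request's value once per draft token (illustration only)."""
    # Recover per-request counts from the cumulative sums, e.g. [2, 5] -> [2, 3].
    prepend = torch.zeros(1, dtype=cu_num_draft_tokens.dtype,
                          device=cu_num_draft_tokens.device)
    counts = torch.diff(cu_num_draft_tokens, prepend=prepend)

    # Optionally remap a sentinel (e.g. GREEDY_TEMPERATURE -> 1) before expanding.
    if replace_from is not None:
        values = torch.where(values == replace_from,
                             torch.full_like(values, replace_to), values)

    expanded = torch.repeat_interleave(values, counts)
    assert expanded.shape[0] == num_tokens
    return expanded


# Two requests with 2 and 3 draft tokens; the second one is greedy.
temperature = torch.tensor([0.7, float(GREEDY_TEMPERATURE)])
cu_num_draft_tokens = torch.tensor([2, 5])
print(expand_batch_to_tokens_sketch(temperature, cu_num_draft_tokens, 5,
                                    replace_from=GREEDY_TEMPERATURE))
# tensor([0.7000, 0.7000, 1.0000, 1.0000, 1.0000])
```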
