
Commit 6d17974

Mohammad Othman authored and committed
[Perf] Optimize multi-token incremental detokenization
This commit optimizes incremental detokenization when handling multiple tokens per update, which is increasingly common with speculative decoding methods (EAGLE, Medusa, the n-gram proposer, etc.).

**Problem:** The original implementation in BaseIncrementalDetokenizer.update() processed tokens one by one in a loop, calling decode_next() for each token. This is inefficient when speculative decoding generates multiple tokens per step (up to 128 tokens in MAX_SPEC_LEN scenarios). For SlowIncrementalDetokenizer it was particularly costly: each decode_next() call invoked detokenize_incrementally() with the full token list, creating O(n) work per token, i.e. O(n^2) total for n new tokens.

**Solution:**

1. Refactored BaseIncrementalDetokenizer.update() to batch-process tokens when possible, via a new _decode_tokens_batch() method.
2. Special handling for the min_tokens edge case: when a batch crosses the min_tokens threshold, update() falls back to one-by-one processing so that stop_check_offset is tracked accurately for stop string detection.
3. Added a SlowIncrementalDetokenizer._decode_tokens_batch() override that processes tokens more efficiently while maintaining correct incremental state updates.
4. FastIncrementalDetokenizer keeps the default implementation (calling decode_next() per token), since DecodeStream requires per-token state updates.

Fixes TODO in vllm/v1/engine/detokenizer.py:115-116

Signed-off-by: Mohammad Othman <[email protected]>
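As a rough illustration of the cost this addresses, the toy model below (plain Python, not vLLM code) counts token reads under the commit's stated assumption that each per-token decode re-reads the full token list, while an ideal batched pass reads the prefix once:

```python
# Toy cost model, not vLLM code. Assumption (from the commit message): for
# SlowIncrementalDetokenizer, each per-token decode touches every token seen
# so far, while an ideal batched pass touches the existing prefix only once.

def per_token_cost(num_new: int, prefix_len: int) -> int:
    # Decode step i re-reads all prefix_len + i + 1 tokens.
    return sum(prefix_len + i + 1 for i in range(num_new))

def batched_cost(num_new: int, prefix_len: int) -> int:
    # One pass over the existing prefix plus the new tokens.
    return prefix_len + num_new

# 128 speculative tokens (the MAX_SPEC_LEN case) on a 1000-token prefix:
print(per_token_cost(128, 1000))  # 136256 token reads
print(batched_cost(128, 1000))    # 1128 token reads
```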
1 parent 2bb4435 commit 6d17974

File tree

1 file changed (+66, -6)

vllm/v1/engine/detokenizer.py

Lines changed: 66 additions & 6 deletions
```diff
@@ -112,13 +112,32 @@ def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
         skipped_stop_token_id = None
 
         # 1) Detokenize the new token ids incrementally.
-        # TODO(woosuk): This method becomes very inefficient when the number of
-        # new_token_ids is more than 1. We need to optimize this.
+        # Optimization: batch process multiple tokens for efficiency.
         stop_check_offset = len(self.output_text)
-        for new_token_id in new_token_ids:
-            self.token_ids.append(new_token_id)
-            self.output_text += self.decode_next(new_token_id)
-        # Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014
+
+        # Check if we need special handling for min_tokens.
+        # If we might cross the min_tokens threshold, process tokens one by one
+        # to accurately track the stop_check_offset position.
+        if (
+            self.min_tokens
+            and len(self.output_token_ids) < self.min_tokens
+            and len(self.output_token_ids) + len(new_token_ids) > self.min_tokens
+        ):
+            # We will cross min_tokens during this batch.
+            # Process one by one to track the exact position.
+            for new_token_id in new_token_ids:
+                self.token_ids.append(new_token_id)
+                self.output_text += self.decode_next(new_token_id)
+                # Update stop_check_offset while we're still under min_tokens.
+                if self.min_tokens and len(self.output_token_ids) <= self.min_tokens:
+                    stop_check_offset = len(self.output_text)
+        else:
+            # Fast path: batch process all tokens.
+            self.token_ids.extend(new_token_ids)
+            new_text = self._decode_tokens_batch(new_token_ids)
+            self.output_text += new_text
+
+        # Update stop_check_offset if still under min_tokens.
         if self.min_tokens and len(self.output_token_ids) <= self.min_tokens:
             stop_check_offset = len(self.output_text)
 
```
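As a worked example of the guard introduced above, the standalone restatement below (plain Python, hypothetical numbers) shows exactly when an update falls back to per-token processing:

```python
# Standalone restatement of the new min_tokens guard; numbers are hypothetical.

def crosses_min_tokens(num_output: int, num_new: int, min_tokens: int) -> bool:
    return (
        min_tokens > 0
        and num_output < min_tokens
        and num_output + num_new > min_tokens
    )

# 3 tokens emitted so far, min_tokens=5, a 4-token speculative batch arrives:
# the batch crosses the threshold mid-update -> per-token fallback path.
assert crosses_min_tokens(num_output=3, num_new=4, min_tokens=5)

# Landing exactly on min_tokens does not cross it (strict '>'), and a request
# already at or past the threshold stays on the batched fast path.
assert not crosses_min_tokens(num_output=3, num_new=2, min_tokens=5)
assert not crosses_min_tokens(num_output=5, num_new=4, min_tokens=5)
```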

```diff
@@ -142,6 +161,17 @@ def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
 
         return stop_string
 
+    def _decode_tokens_batch(self, token_ids: list[int]) -> str:
+        """Decode a batch of tokens efficiently.
+
+        Default implementation processes tokens one by one.
+        Subclasses can override for more efficient batch processing.
+        """
+        result = ""
+        for token_id in token_ids:
+            result += self.decode_next(token_id)
+        return result
+
     @abstractmethod
     def decode_next(self, next_token_id: int) -> str:
         raise NotImplementedError
```
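This default also pins down the contract an override must keep: batch decoding must equal the in-order concatenation of per-token decode_next() results, which is why FastIncrementalDetokenizer can keep it unchanged. A minimal stub (hypothetical class, not vLLM's) making that invariant explicit:

```python
# Hypothetical stub illustrating the _decode_tokens_batch override contract:
# the batch result must match per-token decode_next() output, concatenated.

class TinyDetok:
    VOCAB = {1: "Hel", 2: "lo", 3: "!"}

    def decode_next(self, token_id: int) -> str:
        return self.VOCAB[token_id]

    def _decode_tokens_batch(self, token_ids: list[int]) -> str:
        # Same shape as the default implementation shown in the diff.
        result = ""
        for token_id in token_ids:
            result += self.decode_next(token_id)
        return result

d = TinyDetok()
ids = [1, 2, 3]
assert d._decode_tokens_batch(ids) == "".join(d.decode_next(t) for t in ids)
```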
```diff
@@ -312,6 +342,36 @@ def decode_next(self, next_token_id: int) -> str:
 
         return decoded_text
 
+    def _decode_tokens_batch(self, token_ids: list[int]) -> str:
+        """Optimized batch decoding for SlowIncrementalDetokenizer.
+
+        Processes multiple tokens more efficiently by calling
+        detokenize_incrementally once per token but with properly
+        accumulated state.
+        """
+        result = ""
+        base_len = len(self.token_ids) - len(token_ids)
+
+        for i, token_id in enumerate(token_ids):
+            new_tokens, decoded_text, prefix_offset, read_offset = (
+                detokenize_incrementally(
+                    tokenizer=self.tokenizer,
+                    all_input_ids=self.token_ids[: base_len + i + 1],
+                    prev_tokens=self.tokens,
+                    prefix_offset=self.prefix_offset,
+                    read_offset=self.read_offset,
+                    skip_special_tokens=self.skip_special_tokens,
+                    spaces_between_special_tokens=self.spaces_between_special_tokens,
+                )
+            )
+
+            self.tokens.extend(new_tokens)
+            self.prefix_offset = prefix_offset
+            self.read_offset = read_offset
+            result += decoded_text
+
+        return result
+
 
 def check_stop_strings(
     output_text: str,
```
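The override works by threading prefix_offset and read_offset through successive detokenize_incrementally calls so that each call emits only newly stable text. The toy below models that offset contract with a trivial "tokenizer" whose tokens are plain string fragments (an illustration only, not vLLM's helper; real BPE needs the prefix window to resolve merges):

```python
# Toy model of the prefix_offset / read_offset threading pattern. Tokens are
# plain string fragments here; this is not vLLM's detokenize_incrementally.

def detok_incremental(tokens: list[str], prefix_offset: int, read_offset: int):
    """Return (new_text, new_prefix_offset, new_read_offset)."""
    prefix_text = "".join(tokens[prefix_offset:read_offset])
    full_text = "".join(tokens[prefix_offset:])
    # Emit only the suffix not already visible at the previous read_offset.
    return full_text[len(prefix_text):], read_offset, len(tokens)

tokens: list[str] = []
prefix_offset = read_offset = 0
out = ""
for frag in ["Hel", "lo", ", ", "wor", "ld"]:
    tokens.append(frag)
    new_text, prefix_offset, read_offset = detok_incremental(
        tokens, prefix_offset, read_offset
    )
    out += new_text

assert out == "Hello, world"
```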
