Commit 938a816

[AsyncScheduling] Don't schedule past request max_tokens (#27922)
Signed-off-by: Nick Hill <[email protected]>
1 parent c9f66da commit 938a816

File tree

3 files changed: +14 -4 lines changed

tests/v1/core/test_async_scheduler.py

Lines changed: 7 additions & 0 deletions

@@ -34,15 +34,20 @@ def test_stop_by_max_tokens(max_tokens: int):
     requests = create_requests(num_requests=2, max_tokens=max_tokens)
     req0, req1 = requests
 
+    expected_total_num_scheduled_tokens = 0
     sched_outputs: deque[SchedulerOutput] = deque()
     scheduler.add_request(req0)
     sched_outputs.append(scheduler.schedule())
+    expected_total_num_scheduled_tokens += req0.num_prompt_tokens + max_tokens - 1
 
     scheduler.add_request(req1)
     sched_outputs.append(scheduler.schedule())
+    expected_total_num_scheduled_tokens += req1.num_prompt_tokens + max_tokens - 1
 
+    total_num_scheduled_tokens = 0
     while sched_outputs:
         sched_output = sched_outputs.popleft()
+        total_num_scheduled_tokens += sched_output.total_num_scheduled_tokens
         model_runner_output = _make_model_runner_output(sched_output)
         scheduler.update_from_output(sched_output, model_runner_output)
 
@@ -53,6 +58,8 @@ def test_stop_by_max_tokens(max_tokens: int):
     assert scheduler.get_num_unfinished_requests() == 0
     assert req0.num_output_tokens == max_tokens
     assert req1.num_output_tokens == max_tokens
+    # Ensure we aren't scheduling more tokens than necessary.
+    assert total_num_scheduled_tokens == expected_total_num_scheduled_tokens
 
 
 def test_abort():
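
Note on the expected count used in the test above: the prefill step schedules every prompt token and its forward pass already yields the first output token, so a request that stops at max_tokens only needs max_tokens - 1 additional single-token decode steps. A minimal standalone sketch of that arithmetic (hypothetical helper, not part of the test suite; assumes plain one-token-per-step decoding with no spec decoding):

def expected_scheduled_tokens(num_prompt_tokens: int, max_tokens: int) -> int:
    # Prefill schedules all prompt tokens and produces output token #1;
    # each of the remaining max_tokens - 1 outputs needs one decode step
    # that schedules exactly one token.
    return num_prompt_tokens + (max_tokens - 1)

# Example: a 12-token prompt with max_tokens=3 should schedule 12 + 2 = 14
# tokens, matching req.num_prompt_tokens + max_tokens - 1 in the test.
assert expected_scheduled_tokens(12, 3) == 14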

tests/v1/e2e/test_spec_decode.py

Lines changed: 0 additions & 1 deletion

@@ -155,7 +155,6 @@ def test_suffix_decoding_acceptance(
     )
 
     # Run several times and check that the accepted tokens increase.
-    spec_llm.chat(test_prompts, sampling_config)
     num_draft = []
     num_accept = []
     for i in range(10):  # Run multiple times to warm up the cache.
vllm/v1/core/sched/scheduler.py

Lines changed: 7 additions & 3 deletions

@@ -217,10 +217,14 @@ def schedule(self) -> SchedulerOutput:
                num_new_tokens = self.scheduler_config.long_prefill_token_threshold
            num_new_tokens = min(num_new_tokens, token_budget)
 
-            # Make sure the input position does not exceed the max model len.
-            # This is necessary when using spec decoding.
+            # Make sure the input position does not exceed the max model len or
+            # request's max_tokens.
+            # This is necessary when using spec decoding and/or async scheduling.
+            max_total_tokens = min(
+                request.num_prompt_tokens + request.max_tokens, self.max_model_len
+            )
            num_new_tokens = min(
-                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
+                num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
            )
 
            # Schedule encoder inputs.
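
Effect of the clamp above, in isolation: a request can never be scheduled past min(num_prompt_tokens + max_tokens, max_model_len), which is what stops async scheduling (and spec decoding) from queueing positions the request will never fill. A minimal standalone sketch of the same computation (hypothetical helper with plain int parameters, not the scheduler's actual interface):

def clamp_num_new_tokens(
    num_new_tokens: int,
    num_prompt_tokens: int,
    num_computed_tokens: int,
    max_tokens: int,
    max_model_len: int,
) -> int:
    # Upper bound on the total positions this request may ever occupy:
    # its prompt plus its max_tokens output budget, capped at the model's
    # context length.
    max_total_tokens = min(num_prompt_tokens + max_tokens, max_model_len)
    # Keep one position in reserve for the final sampled token, mirroring
    # the "- 1" in the scheduler change above.
    return min(num_new_tokens, max_total_tokens - 1 - num_computed_tokens)

# Example: 10-token prompt, max_tokens=4, nothing computed yet -> at most
# 13 positions are scheduled; the 14th token is sampled, never scheduled.
assert clamp_num_new_tokens(100, 10, 0, 4, 2048) == 13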
