3 files changed, +14 -4 lines changed

@@ -34,15 +34,20 @@ def test_stop_by_max_tokens(max_tokens: int):
     requests = create_requests(num_requests=2, max_tokens=max_tokens)
     req0, req1 = requests
 
+    expected_total_num_scheduled_tokens = 0
     sched_outputs: deque[SchedulerOutput] = deque()
     scheduler.add_request(req0)
     sched_outputs.append(scheduler.schedule())
+    expected_total_num_scheduled_tokens += req0.num_prompt_tokens + max_tokens - 1
 
     scheduler.add_request(req1)
     sched_outputs.append(scheduler.schedule())
+    expected_total_num_scheduled_tokens += req1.num_prompt_tokens + max_tokens - 1
 
+    total_num_scheduled_tokens = 0
     while sched_outputs:
         sched_output = sched_outputs.popleft()
+        total_num_scheduled_tokens += sched_output.total_num_scheduled_tokens
         model_runner_output = _make_model_runner_output(sched_output)
         scheduler.update_from_output(sched_output, model_runner_output)
 
@@ -53,6 +58,8 @@ def test_stop_by_max_tokens(max_tokens: int):
     assert scheduler.get_num_unfinished_requests() == 0
     assert req0.num_output_tokens == max_tokens
     assert req1.num_output_tokens == max_tokens
+    # Ensure we aren't scheduling more tokens than necessary.
+    assert total_num_scheduled_tokens == expected_total_num_scheduled_tokens
 
 
 def test_abort():
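
Why num_prompt_tokens + max_tokens - 1: the prefill step that schedules the
whole prompt already yields the first output token, so only max_tokens - 1
further single-token decode steps are scheduled. A worked sketch of the
arithmetic (plain Python; the values are hypothetical, not from this PR):

    # Hypothetical values, chosen only to illustrate the expectation above.
    num_prompt_tokens = 8  # scheduled once, during the prefill step
    max_tokens = 4         # requested output tokens

    # Prefill yields output token 1; each decode step schedules one new token
    # and yields one more output token, so max_tokens - 1 decode steps remain.
    expected = num_prompt_tokens + (max_tokens - 1)
    assert expected == 11  # 8 prompt tokens + 3 decode-step tokens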
@@ -155,7 +155,6 @@ def test_suffix_decoding_acceptance(
     )
 
     # Run several times and check that the accepted tokens increase.
-    spec_llm.chat(test_prompts, sampling_config)
     num_draft = []
     num_accept = []
     for i in range(10):  # Run multiple times to warm up the cache.
@@ -217,10 +217,14 @@ def schedule(self) -> SchedulerOutput:
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)
 
-            # Make sure the input position does not exceed the max model len.
-            # This is necessary when using spec decoding.
+            # Make sure the input position does not exceed the max model len or
+            # the request's max_tokens.
+            # This is necessary when using spec decoding and/or async scheduling.
+            max_total_tokens = min(
+                request.num_prompt_tokens + request.max_tokens, self.max_model_len
+            )
             num_new_tokens = min(
-                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
+                num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
             )
 
             # Schedule encoder inputs.
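
A minimal standalone sketch of the new clamp (the variable names mirror the
diff above; all values are hypothetical):

    # Hypothetical values for illustration only.
    num_prompt_tokens = 100    # request.num_prompt_tokens
    max_tokens = 10            # request.max_tokens
    max_model_len = 2048       # self.max_model_len
    num_computed_tokens = 105  # prompt computed + 5 output tokens so far
    num_new_tokens = 8         # e.g. proposed by spec decoding / async scheduling

    # New upper bound: the request may not grow past its own max_tokens,
    # not merely past the model's context window.
    max_total_tokens = min(num_prompt_tokens + max_tokens, max_model_len)  # 110

    # The old bound (max_model_len - 1 - num_computed_tokens = 1942) imposed no
    # real limit here; the new bound clamps to 110 - 1 - 105 = 4, just enough
    # for the request to finish at exactly max_tokens outputs.
    num_new_tokens = min(num_new_tokens, max_total_tokens - 1 - num_computed_tokens)
    assert num_new_tokens == 4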