33import dataclasses
44from unittest .mock import Mock
55
6- import numpy as np
76import pytest
87import torch
98
@@ -170,7 +169,7 @@ def test_schedule_partial_requests():
170169 req_id_to_index = req_to_index ,
171170 # Only the first request has a sampled token id because
172171 # the remaining requests are still being prefilled.
173- sampled_token_ids = [np . array ( [0 ]), np . array ([]), np . array ([]) ],
172+ sampled_token_ids = [[0 ], [], [] ],
174173 logprobs = None ,
175174 prompt_logprobs_dict = {},
176175 pooler_output = [],
@@ -217,7 +216,7 @@ def test_no_mm_input_chunking():
217216 model_runner_output = ModelRunnerOutput (
218217 req_ids = [request .request_id for request in requests ],
219218 req_id_to_index = req_to_index ,
220- sampled_token_ids = [np . array ([]) for _ in range (len (requests ))],
219+ sampled_token_ids = [[] for _ in range (len (requests ))],
221220 logprobs = None ,
222221 prompt_logprobs_dict = {},
223222 pooler_output = [],
@@ -277,7 +276,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
277276 model_runner_output = ModelRunnerOutput (
278277 req_ids = [request .request_id for request in requests ],
279278 req_id_to_index = req_to_index ,
280- sampled_token_ids = [np . array ([]) for _ in range (len (requests ))],
279+ sampled_token_ids = [[] for _ in range (len (requests ))],
281280 logprobs = None ,
282281 prompt_logprobs_dict = {},
283282 pooler_output = [],
@@ -301,8 +300,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
301300 model_runner_output = ModelRunnerOutput (
302301 req_ids = [request .request_id for request in requests ],
303302 req_id_to_index = req_to_index ,
304- sampled_token_ids = [np .array ([0 ]), np .array ([0 ])]
305- + [np .array ([]) for _ in range (len (requests ) - 2 )],
303+ sampled_token_ids = [[0 ], [0 ]] + [[] for _ in range (len (requests ) - 2 )],
306304 logprobs = None ,
307305 prompt_logprobs_dict = {},
308306 pooler_output = [],
@@ -349,8 +347,8 @@ def test_stop_via_update_from_output():
349347 req_ids = [req .request_id for req in requests ],
350348 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
351349 sampled_token_ids = [
352- np . array ( [EOS_TOKEN_ID ]) ,
353- np . array ( [10 , 11 ]) ,
350+ [EOS_TOKEN_ID ],
351+ [10 , 11 ],
354352 ], # First request hits EOS, second continues
355353 logprobs = None ,
356354 prompt_logprobs_dict = {},
@@ -394,10 +392,7 @@ def test_stop_via_update_from_output():
394392 model_output = ModelRunnerOutput (
395393 req_ids = [req .request_id for req in requests ],
396394 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
397- sampled_token_ids = [
398- np .array ([10 , 42 , 12 ]),
399- np .array ([13 , 14 ]),
400- ], # First request hits stop token
395+ sampled_token_ids = [[10 , 42 , 12 ], [13 , 14 ]], # First request hits stop token
401396 logprobs = None ,
402397 prompt_logprobs_dict = {},
403398 pooler_output = [],
@@ -441,10 +436,7 @@ def test_stop_via_update_from_output():
441436 model_output = ModelRunnerOutput (
442437 req_ids = [req .request_id for req in requests ],
443438 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
444- sampled_token_ids = [
445- np .array ([10 , 11 , 12 ]),
446- np .array ([13 ]),
447- ], # First request exceeds max_tokens
439+ sampled_token_ids = [[10 , 11 , 12 ], [13 ]], # First request exceeds max_tokens
448440 logprobs = None ,
449441 prompt_logprobs_dict = {},
450442 pooler_output = [],
@@ -483,7 +475,7 @@ def test_stop_via_update_from_output():
483475 model_output = ModelRunnerOutput (
484476 req_ids = [requests [0 ].request_id ],
485477 req_id_to_index = {requests [0 ].request_id : 0 },
486- sampled_token_ids = [np . array ( [EOS_TOKEN_ID , 10 , 11 ]) ],
478+ sampled_token_ids = [[EOS_TOKEN_ID , 10 , 11 ]],
487479 logprobs = None ,
488480 prompt_logprobs_dict = {},
489481 pooler_output = [],
@@ -624,7 +616,7 @@ def test_schedule_concurrent_batches(
624616 model_runner_output = ModelRunnerOutput (
625617 req_ids = [requests [0 ].request_id ],
626618 req_id_to_index = {requests [0 ].request_id : 0 },
627- sampled_token_ids = [np . array ( [0 ]) ],
619+ sampled_token_ids = [[0 ]],
628620 logprobs = None ,
629621 prompt_logprobs_dict = {},
630622 pooler_output = [],
@@ -641,7 +633,7 @@ def test_schedule_concurrent_batches(
641633 model_runner_output = ModelRunnerOutput (
642634 req_ids = [requests [1 ].request_id ],
643635 req_id_to_index = {requests [1 ].request_id : 0 },
644- sampled_token_ids = [np . array ( [0 ]) ],
636+ sampled_token_ids = [[0 ]],
645637 logprobs = None ,
646638 prompt_logprobs_dict = {},
647639 pooler_output = [],
@@ -678,7 +670,7 @@ def test_preempt_during_execution():
678670 model_runner_output0 = ModelRunnerOutput (
679671 req_ids = [requests [0 ].request_id ],
680672 req_id_to_index = {requests [0 ].request_id : 0 },
681- sampled_token_ids = [np . array ( [0 ]) ],
673+ sampled_token_ids = [[0 ]],
682674 logprobs = None ,
683675 prompt_logprobs_dict = {},
684676 pooler_output = [],
@@ -695,7 +687,7 @@ def test_preempt_during_execution():
695687 model_runner_output1 = ModelRunnerOutput (
696688 req_ids = [requests [1 ].request_id ],
697689 req_id_to_index = {requests [1 ].request_id : 0 },
698- sampled_token_ids = [np . array ( [42 ]) ],
690+ sampled_token_ids = [[42 ]],
699691 logprobs = None ,
700692 prompt_logprobs_dict = {},
701693 pooler_output = [],
@@ -712,18 +704,14 @@ def test_preempt_during_execution():
712704@pytest .mark .parametrize (
713705 "spec_tokens,output_tokens,expected" ,
714706 [
715- ([[1 , 2 , 3 ]], [np .array ([1 , 2 , 3 , 4 ])], (1 , 3 , 3 , [1 , 1 , 1 ])), # perfect match
716- ([[1 , 2 , 3 ]], [np .array ([1 , 5 ])], (1 , 3 , 1 , [1 , 0 , 0 ])), # early mismatch
717- (
718- [[1 , 2 ], [3 ]],
719- [np .array ([1 , 2 , 5 ]), np .array ([3 , 4 ])],
720- (2 , 3 , 3 , [2 , 1 ]),
721- ), # multiple sequences
722- ([[1 ]], [np .array ([1 , 2 ])], (1 , 1 , 1 , [1 ])), # single token sequence
723- ([[]], [np .array ([5 ])], (0 , 0 , 0 , [0 ])), # empty sequence
707+ ([[1 , 2 , 3 ]], [[1 , 2 , 3 , 4 ]], (1 , 3 , 3 , [1 , 1 , 1 ])), # perfect match
708+ ([[1 , 2 , 3 ]], [[1 , 5 ]], (1 , 3 , 1 , [1 , 0 , 0 ])), # early mismatch
709+ ([[1 , 2 ], [3 ]], [[1 , 2 , 5 ], [3 , 4 ]], (2 , 3 , 3 , [2 , 1 ])), # multiple sequences
710+ ([[1 ]], [[1 , 2 ]], (1 , 1 , 1 , [1 ])), # single token sequence
711+ ([[]], [[5 ]], (0 , 0 , 0 , [0 ])), # empty sequence
724712 (
725713 [[1 , 2 , 3 ], [4 , 5 , 6 ]],
726- [np . array ( [1 , 2 , 7 ]), np . array ( [4 , 8 ]) ],
714+ [[1 , 2 , 7 ], [4 , 8 ]],
727715 (2 , 6 , 3 , [2 , 1 , 0 ]),
728716 ), # multiple mismatches
729717 ],
@@ -757,7 +745,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
757745 model_runner_output = ModelRunnerOutput (
758746 req_ids = req_ids ,
759747 req_id_to_index = req_to_index ,
760- sampled_token_ids = [np . array ( [0 ]) for _ in range (len (requests ))],
748+ sampled_token_ids = [[0 ] for _ in range (len (requests ))],
761749 logprobs = None ,
762750 prompt_logprobs_dict = {},
763751 pooler_output = [],
@@ -984,7 +972,7 @@ def test_kv_connector_basic(is_async: bool):
984972 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
985973 req_ids = req_ids ,
986974 req_id_to_index = req_to_index ,
987- sampled_token_ids = [np . array ( [1000 ]) ] * len (req_ids ),
975+ sampled_token_ids = [[1000 ]] * len (req_ids ),
988976 logprobs = None ,
989977 prompt_logprobs_dict = {},
990978 pooler_output = [],
@@ -1037,7 +1025,7 @@ def test_kv_connector_basic(is_async: bool):
10371025 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
10381026 req_ids = req_ids ,
10391027 req_id_to_index = req_to_index ,
1040- sampled_token_ids = [np . array ( [1000 ]) ] * len (req_ids ),
1028+ sampled_token_ids = [[1000 ]] * len (req_ids ),
10411029 logprobs = None ,
10421030 prompt_logprobs_dict = {},
10431031 pooler_output = [],
@@ -1100,7 +1088,7 @@ def test_external_prefix_cache_metrics():
11001088 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
11011089 req_ids = [r .request_id for r in requests ],
11021090 req_id_to_index = {r .request_id : i for i , r in enumerate (requests )},
1103- sampled_token_ids = [np . array ( [1000 ]) ] * NUM_REQUESTS ,
1091+ sampled_token_ids = [[1000 ]] * NUM_REQUESTS ,
11041092 logprobs = None ,
11051093 prompt_logprobs_dict = {},
11061094 pooler_output = [],
@@ -1166,7 +1154,7 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
11661154 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
11671155 req_ids = req_ids ,
11681156 req_id_to_index = req_to_index ,
1169- sampled_token_ids = [np . array ( [1000 ]) ] * len (req_ids ),
1157+ sampled_token_ids = [[1000 ]] * len (req_ids ),
11701158 logprobs = None ,
11711159 prompt_logprobs_dict = {},
11721160 pooler_output = [],
@@ -1251,7 +1239,7 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
12511239 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
12521240 req_ids = req_ids ,
12531241 req_id_to_index = req_to_index ,
1254- sampled_token_ids = [np . array ( [1000 ]) ] * len (req_ids ),
1242+ sampled_token_ids = [[1000 ]] * len (req_ids ),
12551243 logprobs = None ,
12561244 prompt_logprobs_dict = {},
12571245 pooler_output = [],
@@ -1344,7 +1332,7 @@ def make_output(scheduler: Scheduler):
13441332 return ModelRunnerOutput (
13451333 req_ids = [req .request_id for req in scheduler .running ],
13461334 req_id_to_index = {req .request_id : i for i , req in enumerate (scheduler .running )},
1347- sampled_token_ids = [np . array ( [1000 ]) ] * len (scheduler .running ),
1335+ sampled_token_ids = [[1000 ]] * len (scheduler .running ),
13481336 logprobs = None ,
13491337 prompt_logprobs_dict = {},
13501338 pooler_output = [],
@@ -1761,7 +1749,7 @@ def test_priority_scheduling_preemption():
17611749 req_id_to_index = {
17621750 req .request_id : i for i , req in enumerate (low_priority_requests )
17631751 },
1764- sampled_token_ids = [np . array ( [100 ]) for _ in low_priority_requests ],
1752+ sampled_token_ids = [[100 ] for _ in low_priority_requests ],
17651753 logprobs = None ,
17661754 prompt_logprobs_dict = {},
17671755 pooler_output = [],
@@ -1830,7 +1818,7 @@ def test_priority_scheduling_no_preemption_when_space_available():
18301818 req_id_to_index = {
18311819 req .request_id : i for i , req in enumerate (low_priority_requests )
18321820 },
1833- sampled_token_ids = [np . array ( [100 ]) for _ in low_priority_requests ],
1821+ sampled_token_ids = [[100 ] for _ in low_priority_requests ],
18341822 logprobs = None ,
18351823 prompt_logprobs_dict = {},
18361824 pooler_output = [],
@@ -2076,7 +2064,7 @@ def test_priority_scheduling_heap_property():
20762064 model_output = ModelRunnerOutput (
20772065 req_ids = [req .req_id ],
20782066 req_id_to_index = {req .req_id : 0 },
2079- sampled_token_ids = [np . array ( [100 ]) ],
2067+ sampled_token_ids = [[100 ]],
20802068 logprobs = None ,
20812069 prompt_logprobs_dict = {},
20822070 pooler_output = [],
@@ -2162,7 +2150,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
21622150 model_output = ModelRunnerOutput (
21632151 req_ids = [request_low .request_id ],
21642152 req_id_to_index = {request_low .request_id : 0 },
2165- sampled_token_ids = [np . array ( [100 ]) ],
2153+ sampled_token_ids = [[100 ]],
21662154 # spec_token_ids=None,
21672155 logprobs = None ,
21682156 prompt_logprobs_dict = {},
@@ -2193,7 +2181,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
21932181 model_output = ModelRunnerOutput (
21942182 req_ids = [req .request_id for req in requests ],
21952183 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
2196- sampled_token_ids = [np . array ( [100 ]) for _ in requests ],
2184+ sampled_token_ids = [[100 ] for _ in requests ],
21972185 # spec_token_ids=None,
21982186 logprobs = None ,
21992187 prompt_logprobs_dict = {},
@@ -2219,7 +2207,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(
22192207 model_output = ModelRunnerOutput (
22202208 req_ids = [req .request_id for req in requests ],
22212209 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
2222- sampled_token_ids = [np . array ([]), np . array ( [100 ]) ],
2210+ sampled_token_ids = [[], [100 ]],
22232211 # spec_token_ids=None,
22242212 logprobs = None ,
22252213 prompt_logprobs_dict = {},
@@ -2636,7 +2624,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
26362624 model_output = ModelRunnerOutput (
26372625 req_ids = [request1 .request_id ],
26382626 req_id_to_index = {request1 .request_id : 0 },
2639- sampled_token_ids = [np . array ( [100 ]) ],
2627+ sampled_token_ids = [[100 ]],
26402628 # spec_token_ids=None,
26412629 logprobs = None ,
26422630 prompt_logprobs_dict = {},
@@ -2842,7 +2830,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector):
28422830 MODEL_RUNNER_OUTPUT = ModelRunnerOutput (
28432831 req_ids = req_ids ,
28442832 req_id_to_index = req_to_index ,
2845- sampled_token_ids = [np . array ( [1000 ]) ] * len (req_ids ),
2833+ sampled_token_ids = [[1000 ]] * len (req_ids ),
28462834 logprobs = None ,
28472835 prompt_logprobs_dict = {},
28482836 pooler_output = [],
@@ -2955,7 +2943,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
29552943 model_output = ModelRunnerOutput (
29562944 req_ids = [request_low .request_id ],
29572945 req_id_to_index = {request_low .request_id : 0 },
2958- sampled_token_ids = [np . array ( [100 ]) ],
2946+ sampled_token_ids = [[100 ]],
29592947 # spec_token_ids=None,
29602948 logprobs = None ,
29612949 prompt_logprobs_dict = {},
@@ -3006,7 +2994,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
30062994 model_output = ModelRunnerOutput (
30072995 req_ids = [req .request_id for req in requests ],
30082996 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
3009- sampled_token_ids = [np . array ( [100 ]) for _ in requests ],
2997+ sampled_token_ids = [[100 ] for _ in requests ],
30102998 # spec_token_ids=None,
30112999 logprobs = None ,
30123000 prompt_logprobs_dict = {},
@@ -3041,7 +3029,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
30413029 model_output = ModelRunnerOutput (
30423030 req_ids = [req .request_id for req in requests ],
30433031 req_id_to_index = {req .request_id : i for i , req in enumerate (requests )},
3044- sampled_token_ids = [np . array ( [100 ]), np . array ( [100 , 200 ]) ],
3032+ sampled_token_ids = [[100 ], [100 , 200 ]],
30453033 # spec_token_ids=None,
30463034 logprobs = None ,
30473035 prompt_logprobs_dict = {},
@@ -3227,7 +3215,7 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connecto
32273215 model_output = ModelRunnerOutput (
32283216 req_ids = [request1 .request_id , request2 .request_id ],
32293217 req_id_to_index = {request1 .request_id : 0 , request2 .request_id : 1 },
3230- sampled_token_ids = [np . array ( [100 ]), np . array ( [121 ]) ],
3218+ sampled_token_ids = [[100 ], [121 ]],
32313219 # spec_token_ids=None,
32323220 logprobs = None ,
32333221 prompt_logprobs_dict = {},
0 commit comments