Skip to content

Commit f3d9c3c

Browse files
committed
[BugFix] Fix async scheduling + chunked prefill + preemption
Signed-off-by: Nick Hill <[email protected]>
1 parent 2bb4435 commit f3d9c3c

File tree

3 files changed

+8
-9
lines changed

3 files changed

+8
-9
lines changed

tests/v1/e2e/test_async_scheduling.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,8 @@ def test_without_spec_decoding(
6565
(True, "mp", True, None, False),
6666
(True, "uni", True, None, False),
6767
(False, "mp", True, None, True),
68-
# Async scheduling + preemption + chunked prefill needs to be fixed (WIP)
69-
# (True, "mp", True, None, True),
70-
# (True, "uni", True, None, True),
68+
(True, "mp", True, None, True),
69+
(True, "uni", True, None, True),
7170
]
7271

7372
run_tests(
@@ -103,9 +102,8 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
103102
(False, "mp", True, spec_config_short, True),
104103
(True, "uni", True, spec_config, False),
105104
(True, "uni", True, spec_config_short, False),
106-
# Async scheduling + preemption + chunked prefill needs to be fixed (WIP)
107-
# (True, "mp", True, spec_config, True),
108-
# (True, "uni", True, spec_config_short, True),
105+
(True, "mp", True, spec_config, True),
106+
(True, "uni", True, spec_config_short, True),
109107
]
110108

111109
run_tests(

vllm/v1/core/sched/scheduler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -778,9 +778,7 @@ def _make_cached_request_data(
778778
assert not scheduled_in_prev_step
779779
resumed_req_ids.add(req_id)
780780
if not scheduled_in_prev_step:
781-
all_token_ids[req_id] = req.all_token_ids[
782-
: req.num_computed_tokens + num_tokens
783-
]
781+
all_token_ids[req_id] = req.all_token_ids.copy()
784782
new_block_ids.append(
785783
req_to_new_blocks[req_id].get_block_ids(allow_none=True)
786784
)

vllm/v1/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def __len__(self):
9797
def __repr__(self):
9898
return f"ConstantList({self._x})"
9999

100+
def copy(self) -> list[T]:
101+
return self._x.copy()
102+
100103

101104
class CpuGpuBuffer:
102105
"""Buffer to easily copy tensors between CPU and GPU."""

0 commit comments

Comments
 (0)