
Commit 141e6a0

[Misc] Make reorder batch also separate extends (#27367)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 130aa8c commit 141e6a0

File tree

2 files changed: +164, -45 lines
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import dataclass
+
+import numpy as np
+import pytest
+
+from vllm.v1.attention.backends.utils import reorder_batch_to_split_decodes_and_prefills
+
+
+class MockInputBatch:
+    def __init__(self, req_ids, num_computed_tokens_cpu):
+        self.req_ids = req_ids
+        self.num_computed_tokens_cpu = num_computed_tokens_cpu
+
+    def swap_states(self, i, j):
+        self.req_ids[i], self.req_ids[j] = self.req_ids[j], self.req_ids[i]
+        self.num_computed_tokens_cpu[i], self.num_computed_tokens_cpu[j] = (
+            self.num_computed_tokens_cpu[j],
+            self.num_computed_tokens_cpu[i],
+        )
+
+
+class MockSchedulerOutput:
+    def __init__(self, num_scheduled_tokens):
+        self.num_scheduled_tokens = num_scheduled_tokens
+
+
+@dataclass
+class ReorderTestCase:
+    requests: list[tuple[int, int]]  # (num_scheduled_tokens, num_computed_tokens)
+    expected_order: list[int]
+    expected_modified: bool
+    decode_threshold: int = 1
+
+
+# Test cases for batch reordering
+REORDER_TEST_CASES = {
+    "all_decodes": ReorderTestCase(
+        requests=[(1, 10), (1, 20), (1, 30)],
+        expected_order=[0, 1, 2],
+        expected_modified=False,
+    ),
+    "all_prefills": ReorderTestCase(
+        requests=[(100, 100), (200, 200), (300, 300)],
+        expected_order=[0, 1, 2],
+        expected_modified=False,
+    ),
+    "mixed_interleaved": ReorderTestCase(
+        requests=[(100, 100), (1, 10), (200, 200), (1, 20)],
+        expected_order=[3, 1, 2, 0],  # Only swap 0↔3, keep 1 and 2 in place
+        expected_modified=True,
+    ),
+    "already_ordered": ReorderTestCase(
+        requests=[(1, 10), (1, 20), (100, 100), (200, 200)],
+        expected_order=[0, 1, 2, 3],
+        expected_modified=False,
+    ),
+    "single_request": ReorderTestCase(
+        requests=[(1, 10)],
+        expected_order=[0],
+        expected_modified=False,
+    ),
+    "higher_threshold": ReorderTestCase(
+        requests=[(2, 10), (3, 20), (5, 30), (6, 40)],
+        expected_order=[0, 1, 2, 3],
+        expected_modified=False,
+        decode_threshold=4,
+    ),
+    "decodes_at_end": ReorderTestCase(
+        requests=[(100, 100), (200, 200), (1, 10), (1, 20)],
+        expected_order=[2, 3, 0, 1],
+        expected_modified=True,
+    ),
+    "decode_extend_prefill": ReorderTestCase(
+        requests=[(100, 100), (10, 50), (1, 10)],
+        expected_order=[2, 1, 0],
+        expected_modified=True,
+    ),
+    "extend_prefill_only": ReorderTestCase(
+        requests=[(100, 100), (10, 50), (200, 200), (20, 75)],
+        expected_order=[3, 1, 2, 0],  # Only swap 0↔3, keep 1 and 2 in place
+        expected_modified=True,
+    ),
+}
+
+
+@pytest.mark.parametrize(
+    "test_case", REORDER_TEST_CASES.values(), ids=REORDER_TEST_CASES.keys()
+)
+def test_reorder_batch_to_split_decodes_and_prefills(test_case: ReorderTestCase):
+    req_ids = [f"r{i}" for i in range(len(test_case.requests))]
+    num_computed_tokens = np.array([r[1] for r in test_case.requests], dtype=np.int32)
+    num_scheduled_tokens = {f"r{i}": r[0] for i, r in enumerate(test_case.requests)}
+
+    input_batch = MockInputBatch(req_ids, num_computed_tokens)
+    scheduler_output = MockSchedulerOutput(num_scheduled_tokens)
+
+    modified = reorder_batch_to_split_decodes_and_prefills(
+        input_batch, scheduler_output, decode_threshold=test_case.decode_threshold
+    )
+
+    expected_req_ids = [f"r{i}" for i in test_case.expected_order]
+
+    assert modified == test_case.expected_modified, (
+        f"Expected modified={test_case.expected_modified}, got {modified}"
+    )
+    assert input_batch.req_ids == expected_req_ids, (
+        f"Expected order {expected_req_ids}, got {input_batch.req_ids}"
+    )
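As a quick illustration of what these cases encode, here is a minimal standalone sketch (not part of the commit) that applies the same decode/extend/prefill classification as the new implementation to the "decode_extend_prefill" case:

# Standalone sketch, not code from the commit: classify the three requests
# of the "decode_extend_prefill" case the way the new implementation does.
import numpy as np

decode_threshold = 1
scheduled = np.array([100, 10, 1])   # num_scheduled_tokens per request
computed = np.array([100, 50, 10])   # num_computed_tokens per request

is_decode = scheduled <= decode_threshold            # [False, False, True]
is_extend = ~is_decode & (computed > scheduled)      # [False, True,  False]
is_prefill = ~is_decode & (computed == scheduled)    # [True,  False, False]

regions = np.zeros(len(scheduled), dtype=np.int32)   # 0 = decode by default
regions[is_extend] = 1
regions[is_prefill] = 2
order = np.argsort(regions, kind="stable")
print(order)  # [2 1 0] -> matches expected_order for this test case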

vllm/v1/attention/backends/utils.py

Lines changed: 53 additions & 45 deletions
@@ -795,51 +795,59 @@ def reorder_batch_to_split_decodes_and_prefills(
     Returns:
         True if the batch was modified, False otherwise.
     """
-    # We now want to reorder the batch so that the "decode" requests are at
-    # the front and the "prefill" requests are at the back using the least
-    # amount of swaps possible. (NOTE for now we loosely use "decode" to mean
-    # requests where attention is likely memory-bound and "prefill" to mean
-    # requests where attention is likely compute-bound, TODO(lucas): figure out
-    # a better naming here)
-    decodes = []
-    prefills = []
-    num_decode_tokens = 0
-    num_prefill_tokens = 0
-
-    for i, req_id in enumerate(input_batch.req_ids):
-        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-        if num_tokens <= decode_threshold:
-            decodes.append(i)
-            num_decode_tokens += num_tokens
-        else:
-            prefills.append(i)
-            num_prefill_tokens += num_tokens
-
-    # We hope that this is fairly minimal, since decodes should be around
-    # for a number of iterations and so should be relatively stationary
-    # (and new requests are generally appended to the persistent batch, so
-    # they should already be at the back).
-    # To achieve this we loop over the decodes in descending order and
-    # the prefills in ascending order. We swap decodes from the "back",
-    # i.e. past where the last decode should be in the reordered batch,
-    # with prefills from the front of the batch.
-    # `decodes` and `prefills` are already in ascending order just based on
-    # the above loop.
-    num_decodes = len(decodes)
-    num_prefills = len(prefills)
-    modified_batch = False
-
-    for i in range(1, min(num_decodes, num_prefills) + 1):
-        # If the decode is at the "back" of the batch, we can swap it
-        # with the prefill closest to the front of the batch.
-        decode_idx = decodes[num_decodes - i]
-        if decode_idx < num_decodes:
-            break
-
-        input_batch.swap_states(prefills[i - 1], decode_idx)
-        modified_batch = True
-
-    return modified_batch
+    # We now want to reorder the batch into decode → extend → prefill order,
+    # where:
+    #   decode: request with num_scheduled_tokens <= decode_threshold
+    #   extend: non-decode request with existing context
+    #   prefill: non-decode request with no existing context
+    # NOTE: for now we loosely use "decode" to mean requests where attention
+    # is likely memory-bound and "prefill" to mean requests where attention
+    # is likely compute-bound.
+    num_reqs = len(input_batch.req_ids)
+    num_scheduled_tokens = [
+        scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
+    ]
+    num_scheduled_tokens_np = np.array(num_scheduled_tokens)
+    num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
+
+    is_decode = num_scheduled_tokens_np <= decode_threshold
+    is_extend = (~is_decode) & (num_computed_tokens_np > num_scheduled_tokens_np)
+    is_prefill = (~is_decode) & (num_computed_tokens_np == num_scheduled_tokens_np)
+
+    # Desired order: decode → extend → prefill
+    req_regions = np.zeros(is_decode.shape, dtype=np.int32)  # 0 = decode by default
+    req_regions[is_extend] = 1
+    req_regions[is_prefill] = 2
+
+    num_decodes = int(is_decode.sum())
+    num_extends = int(is_extend.sum())
+
+    target_regions = np.zeros(num_reqs, dtype=np.int32)
+    target_regions[num_decodes : num_decodes + num_extends] = 1
+    target_regions[num_decodes + num_extends :] = 2
+
+    needs_swap = req_regions != target_regions
+
+    if not needs_swap.any():
+        return False
+
+    # Extract indices that need swapping and sort by target region
+    swap_indices = np.where(needs_swap)[0]
+    sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
+    dest_indices = swap_indices[sorted_order]
+
+    src_dest_map = {int(src): int(dst) for src, dst in zip(swap_indices, dest_indices)}
+
+    for src in src_dest_map:
+        dst = src_dest_map[src]
+        while src != dst:
+            input_batch.swap_states(src, dst)
+            # Mark dst as done by updating its destination to itself
+            next_dst = src_dest_map.get(dst, dst)
+            src_dest_map[dst] = dst
+            dst = next_dst
+
+    return True
 
 
 def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor: