
Commit b9ff4f2

jiangkuaixue123 and root authored
[feature] extend DBO to XBO (#30120)
Signed-off-by: jiangkuaixue123 <[email protected]> Co-authored-by: root <[email protected]>
1 parent c881db3 commit b9ff4f2

File tree

10 files changed: +133 -73 lines

tests/v1/attention/test_attention_splitting.py

Lines changed: 1 addition & 0 deletions

@@ -323,6 +323,7 @@ def test_prefill_split_across_ubatches(
         num_tokens,
         batch_spec.batch_size,
         split_point=split_point,
+        num_ubatches=2,
     )
     assert ubatch_slices is not None and len(ubatch_slices) == 2

vllm/config/parallel.py

Lines changed: 10 additions & 0 deletions

@@ -156,6 +156,8 @@ class ParallelConfig:

     enable_dbo: bool = False
     """Enable dual batch overlap for the model executor."""
+    ubatch_size: int = 0
+    """Number of ubatches to split the batch into."""

     dbo_decode_token_threshold: int = 32
     """The threshold for dual batch overlap for batches only containing decodes.

@@ -325,6 +327,14 @@ def world_size_across_dp(self) -> int:
         including data parallelism."""
         return self.world_size * self.data_parallel_size

+    @property
+    def use_ubatching(self) -> bool:
+        return self.enable_dbo or self.ubatch_size > 1
+
+    @property
+    def num_ubatches(self) -> int:
+        return 2 if self.enable_dbo else self.ubatch_size
+
     def get_next_dp_init_port(self) -> int:
         """
         We might need to initialize process groups in multiple
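
To make the intent of the two new ParallelConfig knobs concrete, here is a minimal standalone sketch (not the vLLM class itself, just the same two expressions from the hunk above): --enable-dbo keeps the legacy two-way split, while ubatch_size > 1 generalizes it to an N-way split (XBO).

# Minimal sketch mirroring the two properties added above; the real logic
# lives in vllm/config/parallel.py (ParallelConfig).
from dataclasses import dataclass


@dataclass
class UBatchKnobs:
    enable_dbo: bool = False  # legacy dual batch overlap flag
    ubatch_size: int = 0      # new: number of ubatches for XBO

    @property
    def use_ubatching(self) -> bool:
        return self.enable_dbo or self.ubatch_size > 1

    @property
    def num_ubatches(self) -> int:
        return 2 if self.enable_dbo else self.ubatch_size


# DBO keeps its old meaning (two ubatches); XBO picks up ubatch_size.
assert UBatchKnobs(enable_dbo=True).num_ubatches == 2
assert UBatchKnobs(ubatch_size=4).num_ubatches == 4
assert not UBatchKnobs().use_ubatching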

vllm/config/vllm.py

Lines changed: 5 additions & 2 deletions

@@ -870,9 +870,12 @@ def has_blocked_weights():
                 f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
             )

-        if self.parallel_config.enable_dbo:
+        if self.parallel_config.use_ubatching:
             a2a_backend = self.parallel_config.all2all_backend
-            assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
+            assert a2a_backend in [
+                "deepep_low_latency",
+                "deepep_high_throughput",
+            ], (
                 "Microbatching currently only supports the deepep_low_latency and "
                 f"deepep_high_throughput all2all backend. {a2a_backend} is not "
                 "supported. To fix use --all2all-backend=deepep_low_latency or "

vllm/engine/arg_utils.py

Lines changed: 6 additions & 0 deletions

@@ -408,6 +408,7 @@ class EngineArgs:
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     all2all_backend: str | None = ParallelConfig.all2all_backend
     enable_dbo: bool = ParallelConfig.enable_dbo
+    ubatch_size: int = ParallelConfig.ubatch_size
     dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
     dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
     disable_nccl_for_dp_synchronization: bool = (

@@ -841,6 +842,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "--all2all-backend", **parallel_kwargs["all2all_backend"]
         )
         parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
+        parallel_group.add_argument(
+            "--ubatch-size",
+            **parallel_kwargs["ubatch_size"],
+        )
         parallel_group.add_argument(
             "--dbo-decode-token-threshold",
             **parallel_kwargs["dbo_decode_token_threshold"],

@@ -1557,6 +1562,7 @@ def create_engine_config(
             enable_expert_parallel=self.enable_expert_parallel,
             all2all_backend=self.all2all_backend,
             enable_dbo=self.enable_dbo,
+            ubatch_size=self.ubatch_size,
            dbo_decode_token_threshold=self.dbo_decode_token_threshold,
            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
            disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
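
Hedged usage note: the hunk above only wires --ubatch-size through EngineArgs; the snippet below is an illustrative argparse stand-in (not vLLM's FlexibleArgumentParser) showing how the flag maps to a micro-batch count.

# Illustrative only: a tiny argparse stand-in for the new --ubatch-size flag
# and how it relates to --enable-dbo; the real wiring is the hunk above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-dbo", action="store_true")
parser.add_argument("--ubatch-size", type=int, default=0)

ns = parser.parse_args(["--ubatch-size", "4"])
num_ubatches = 2 if ns.enable_dbo else ns.ubatch_size
print(num_ubatches)  # 4 -> four-way micro-batching (XBO)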

vllm/v1/attention/backends/utils.py

Lines changed: 9 additions & 5 deletions

@@ -201,10 +201,11 @@ def _make_metadata_with_slice(
     )
     # NOTE: last token can be outside of the last request if we have CG padding.

-    # If the "middle" request has tokens in both ubatches, we have to split it.
-    # If ubatch_slice is the first ubatch then we will be splitting the last
-    # request. If it's the second microbatch, then we will be splitting the
-    # first request
+    # If a request is split across ubatches, we have to adjust the metadata.
+    # splits_first_request: the first request in this slice is the continuation of
+    # a request that started in a previous slice.
+    # splits_last_request: the last request in this slice continues into the
+    # next slice.
     splits_first_request = first_tok > start_locs[first_req]
     splits_last_request = last_tok < start_locs[last_req + 1] - 1

@@ -225,7 +226,10 @@ def _make_metadata_with_slice(
     seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]

     if splits_last_request:
-        tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop
+        # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
+        # the tokens skipped because query_start_loc_cpu might have been modified
+        # if splits_first_request is True.
+        tokens_skipped = start_locs[last_req + 1] - token_slice.stop
         query_start_loc[-1] -= tokens_skipped
         query_start_loc_cpu[-1] -= tokens_skipped
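
A toy numeric recreation of the splits_last_request fix, with made-up offsets in place of vLLM's attention metadata: once the per-slice query_start_loc_cpu has been rebased for a slice that also splits its first request, only the original start_locs give the correct number of skipped tokens.

# Toy sketch only; numbers and the "rebase" step are illustrative, not the
# actual metadata code in vllm/v1/attention/backends/utils.py.
import numpy as np

# Cumulative token start offsets for three requests of 4 tokens each.
start_locs = np.array([0, 4, 8, 12])

# Suppose this ubatch covers tokens [2, 10): it starts mid-request 0 and ends
# mid-request 2, so both splits_first_request and splits_last_request hold.
token_slice = slice(2, 10)
first_req, last_req = 0, 2

# Model the splits_first_request adjustment: the per-slice query_start_loc_cpu
# is rebased so the slice starts at token 0.
query_start_loc_cpu = start_locs[first_req : last_req + 2] - token_slice.start

# Old formula, applied after the rebase, under-counts: 10 - 10 = 0.
old_tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop
# New formula from the diff uses the untouched offsets: 12 - 10 = 2, i.e. the
# last two tokens of request 2 spill into the next ubatch.
new_tokens_skipped = start_locs[last_req + 1] - token_slice.stop

print(int(old_tokens_skipped), int(new_tokens_skipped))  # 0 2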

vllm/v1/worker/dp_utils.py

Lines changed: 4 additions & 4 deletions

@@ -11,7 +11,7 @@
 from vllm.logger import init_logger
 from vllm.v1.worker.ubatch_utils import (
     check_ubatch_thresholds,
-    is_second_ubatch_empty,
+    is_last_ubatch_empty,
 )

 logger = init_logger(__name__)

@@ -56,7 +56,7 @@ def _run_ar(
     return tensor


-def _post_process_ubatch(tensor: torch.Tensor) -> bool:
+def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool:
     orig_num_tokens_tensor = tensor[0, :]
     padded_num_tokens_tensor = tensor[1, :]

@@ -68,7 +68,7 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool:
     # there are no "empty" second ubatches
     orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
     padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
-    if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
+    if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches):
         logger.debug(
             "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
         )

@@ -146,7 +146,7 @@ def _synchronize_dp_ranks(
     assert should_attempt_dp_padding == should_dp_pad

     # Check conditions for microbatching
-    should_ubatch = _post_process_ubatch(tensor)
+    should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)

     if should_ubatch and not should_dp_pad:
         logger.debug_once(
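
The abort decision above reduces to pure integer arithmetic; is_last_ubatch_empty is reproduced verbatim from vllm/v1/worker/ubatch_utils.py below so the check can be exercised standalone (the token counts here are illustrative; in dp_utils they are the min/max gathered across DP ranks).

def is_last_ubatch_empty(
    orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
) -> bool:
    # If the first (num_ubatches - 1) equal padded slices already cover every
    # real token, the last ubatch would contain only padding -> abort ubatching.
    return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens


# num_ubatches=2 reproduces the old is_second_ubatch_empty behavior.
assert is_last_ubatch_empty(orig_num_tokens=8, padded_num_tokens=16, num_ubatches=2)
# With 4 ubatches of 4 padded tokens, 13 real tokens still reach the last slice...
assert not is_last_ubatch_empty(orig_num_tokens=13, padded_num_tokens=16, num_ubatches=4)
# ...but 12 real tokens fit entirely in the first 3 slices, so ubatching is aborted.
assert is_last_ubatch_empty(orig_num_tokens=12, padded_num_tokens=16, num_ubatches=4)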

vllm/v1/worker/gpu_model_runner.py

Lines changed: 26 additions & 7 deletions

@@ -2987,7 +2987,7 @@ def execute_model(

         cascade_attn_prefix_lens = None
         # Disable cascade attention when using microbatching (DBO)
-        if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
+        if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
             # Pre-compute cascade attention prefix lengths
             cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
                 num_scheduled_tokens_np,

@@ -3028,6 +3028,13 @@ def execute_model(
             num_scheduled_tokens_np,
             num_tokens_padded,
             num_reqs_padded,
+            self.parallel_config.num_ubatches,
+        )
+
+        logger.debug(
+            "ubatch_slices: %s, ubatch_slices_padded: %s",
+            ubatch_slices,
+            ubatch_slices_padded,
         )

         pad_attn = cudagraph_mode == CUDAGraphMode.FULL

@@ -3710,11 +3717,14 @@ def load_model(self, eep_scale_up: bool = False) -> None:
         # wrap the model with full cudagraph wrapper if needed.
         cudagraph_mode = self.compilation_config.cudagraph_mode
         assert cudagraph_mode is not None
-        if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo:
+        if (
+            cudagraph_mode.has_full_cudagraphs()
+            and not self.parallel_config.use_ubatching
+        ):
             self.model = CUDAGraphWrapper(
                 self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
             )
-        elif self.parallel_config.enable_dbo:
+        elif self.parallel_config.use_ubatching:
             if cudagraph_mode.has_full_cudagraphs():
                 self.model = UBatchWrapper(
                     self.model, self.vllm_config, CUDAGraphMode.FULL, self.device

@@ -4095,7 +4105,16 @@ def _dummy_run(
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
         ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
-            should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded
+            should_ubatch,
+            num_scheduled_tokens,
+            num_tokens_padded,
+            num_reqs_padded,
+            self.vllm_config.parallel_config.num_ubatches,
+        )
+        logger.debug(
+            "ubatch_slices: %s, ubatch_slices_padded: %s",
+            ubatch_slices,
+            ubatch_slices_padded,
         )

         attn_metadata: PerLayerAttnMetadata | None = None

@@ -4644,7 +4663,7 @@ def _capture_cudagraphs(
             # is above the threshold. Otherwise we just capture a non-ubatched
             # version of the graph
             allow_microbatching = (
-                self.parallel_config.enable_dbo
+                self.parallel_config.use_ubatching
                 and cudagraph_runtime_mode == CUDAGraphMode.FULL
                 and uniform_decode
                 and check_ubatch_thresholds(

@@ -4779,8 +4798,8 @@ def initialize_metadata_builders(
                 if kv_cache_group_id < len(kernel_block_sizes)
                 else None,
                 num_metadata_builders=1
-                if not self.parallel_config.enable_dbo
-                else 2,
+                if not self.parallel_config.use_ubatching
+                else self.parallel_config.num_ubatches,
             )
             # Calculate reorder batch threshold (if needed)
             # Note (tdoublep): do this *after* constructing builders,

vllm/v1/worker/gpu_ubatch_wrapper.py

Lines changed: 19 additions & 16 deletions

@@ -103,8 +103,10 @@ def __init__(
         self.vllm_config = vllm_config
         self.compilation_config = vllm_config.compilation_config
         self.comm_stream = torch.cuda.Stream(device=device)
-        # Two ubatch threads plus the main thread
-        self.ready_barrier = threading.Barrier(3)
+        # Ubatch threads plus the main thread
+        self.ready_barrier = threading.Barrier(
+            self.vllm_config.parallel_config.num_ubatches + 1
+        )

         self.cudagraphs: dict[int, CUDAGraphMetaData] = {}

@@ -309,7 +311,7 @@ def _make_ubatch_metadata(
             create_forward_context(
                 attn_metadata[i] if attn_metadata is not None else None,
                 self.vllm_config,
-                dp_metadata=dp_metadata,
+                dp_metadata=dp_metadata[i],
                 batch_descriptor=batch_descriptor,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
             )

@@ -417,18 +419,19 @@ def __call__(self, *args, **kwargs):

         # We shouldn't be here unless we are running with multiple DP ranks
         assert dp_metadata is not None
-        num_tokens_per_ubatch = (
-            ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
-        )
-        dp_size = self.vllm_config.parallel_config.data_parallel_size
-        ubatch_num_tokens_across_dp = torch.tensor(
-            [num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32
-        )
-        ubatch_dp_metadata = DPMetadata.make(
-            self.vllm_config.parallel_config,
-            num_tokens_per_ubatch,
-            ubatch_num_tokens_across_dp,
-        )
+        ubatch_dp_metadata = []
+        for ubatch_slice in ubatch_slices:
+            dp_size = self.vllm_config.parallel_config.data_parallel_size
+            ubatch_num_tokens_across_dp = torch.tensor(
+                [ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32
+            )
+            ubatch_dp_metadata.append(
+                DPMetadata.make(
+                    self.vllm_config.parallel_config,
+                    ubatch_slice.num_tokens,
+                    ubatch_num_tokens_across_dp,
+                )
+            )

         if (
             num_tokens not in self.cudagraphs

@@ -464,7 +467,7 @@ def __call__(self, *args, **kwargs):
             intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
             compute_stream=compute_stream,
-            dp_metadata=dp_metadata,
+            dp_metadata=ubatch_dp_metadata,
             batch_descriptor=batch_descriptor,
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
         )
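
The barrier resize is easiest to see in isolation: one worker thread per ubatch plus the main thread rendezvous before the overlapped run starts. A standalone sketch (not the UBatchWrapper itself), assuming a three-way split:

import threading

num_ubatches = 3  # e.g. XBO with three micro-batches
ready_barrier = threading.Barrier(num_ubatches + 1)  # ubatch threads + main thread

def ubatch_worker(idx: int) -> None:
    # ... per-ubatch forward context would be built here ...
    ready_barrier.wait()  # signal this ubatch thread is ready

threads = [
    threading.Thread(target=ubatch_worker, args=(i,)) for i in range(num_ubatches)
]
for t in threads:
    t.start()
ready_barrier.wait()  # main thread joins the rendezvous
for t in threads:
    t.join()
print("all ubatch threads ready")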

vllm/v1/worker/ubatch_utils.py

Lines changed: 37 additions & 34 deletions

@@ -27,36 +27,34 @@ def num_tokens(self) -> int:
 UBatchSlices: TypeAlias = list[UBatchSlice]


-def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool:
-    return (padded_num_tokens // 2) >= orig_num_tokens
+def is_last_ubatch_empty(
+    orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
+) -> bool:
+    return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens


 def check_ubatch_thresholds(
     config: ParallelConfig, num_tokens: int, uniform_decode: bool
 ) -> bool:
-    if not config.enable_dbo:
+    if not config.use_ubatching:
         return False
     if uniform_decode:
         return num_tokens >= config.dbo_decode_token_threshold
     else:
         return num_tokens >= config.dbo_prefill_token_threshold


-# This just pads the second ubatch slice out to the total number of tokens
+# This pads the last ubatch slice out to the total number of tokens
 # (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
 def _pad_out_ubatch_slices(
     ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
 ) -> UBatchSlices:
-    # TODO(lucas): handle empty second ubatch
-    padded_second_request_slice = slice(
-        ubatch_slices[1].request_slice.start, num_reqs_padded
-    )
-    padded_second_token_slice = slice(
-        ubatch_slices[1].token_slice.start, num_total_tokens
-    )
-    return [
-        ubatch_slices[0],
-        UBatchSlice(padded_second_request_slice, padded_second_token_slice),
+    last_slice = ubatch_slices[-1]
+    padded_last_request_slice = slice(last_slice.request_slice.start, num_reqs_padded)
+    padded_last_token_slice = slice(last_slice.token_slice.start, num_total_tokens)
+
+    return ubatch_slices[:-1] + [
+        UBatchSlice(padded_last_request_slice, padded_last_token_slice)
     ]


@@ -65,40 +63,45 @@ def maybe_create_ubatch_slices(
     num_scheduled_tokens: np.ndarray,
     num_tokens_padded: int,
     num_reqs_padded: int,
-    split_point: int | None = None,
+    num_ubatches: int,
+    split_point: list[int] | int | None = None,
 ) -> tuple[UBatchSlices | None, UBatchSlices | None]:
     if not should_ubatch:
         return None, None

     if split_point is None:
-        split_point = int(num_tokens_padded) // 2
+        split_point = int(num_tokens_padded) // num_ubatches
+
+    token_split_points = [split_point * i for i in range(1, num_ubatches)]

     # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
     # in cu_num_tokens directly (i.e. query_start_loc)
     cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
     np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])

-    first_ubatch_token_slice = slice(0, split_point)
-    second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1])
+    ubatch_slices = []
+    start_token = 0

-    # Determine request slices using exclusive stop semantics
-    # First ubatch includes requests whose tokens overlap [0, split_point)
-    first_ubatch_req_stop = int(
-        np.searchsorted(cu_num_tokens, split_point, side="left")
-    )
-    first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
+    # Add the end point to the split points to make iteration easier
+    all_points = token_split_points + [cu_num_tokens[-1]]

-    # Second ubatch starts at the request that contains the split_point
-    # or the request starting exactly at split_point (if on boundary)
-    second_ubatch_req_start = int(
-        np.searchsorted(cu_num_tokens, split_point, side="right") - 1
-    )
-    second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
+    for end_token in all_points:
+        token_slice = slice(start_token, end_token)

-    ubatch_slices = [
-        UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
-        UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
-    ]
+        # Determine request slices using exclusive stop semantics
+        # Ubatch includes requests whose tokens overlap [start_token, end_token)
+
+        # Start at the request that contains the start_token
+        # or the request starting exactly at start_token (if on boundary)
+        req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1)
+
+        # Stop at the request that starts at or after end_token
+        req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left"))
+
+        req_slice = slice(req_start, req_stop)
+        ubatch_slices.append(UBatchSlice(req_slice, token_slice))
+
+        start_token = end_token

     ubatch_slices_padded = _pad_out_ubatch_slices(
         ubatch_slices, num_tokens_padded, num_reqs_padded
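
The reworked maybe_create_ubatch_slices is the core of the DBO-to-XBO generalization. Below is a self-contained sketch of the same slicing loop, with a plain tuple standing in for vLLM's UBatchSlice; the request and token counts are illustrative.

# Standalone sketch of the N-way slicing above; (request_slice, token_slice)
# tuples stand in for UBatchSlice.
import numpy as np

def make_ubatch_slices(
    num_scheduled_tokens: np.ndarray, num_tokens_padded: int, num_ubatches: int
):
    split_point = num_tokens_padded // num_ubatches
    token_split_points = [split_point * i for i in range(1, num_ubatches)]

    cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
    np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])

    slices = []
    start_token = 0
    for end_token in token_split_points + [int(cu_num_tokens[-1])]:
        # Each ubatch covers requests whose tokens overlap [start_token, end_token)
        req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1)
        req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left"))
        slices.append((slice(req_start, req_stop), slice(start_token, end_token)))
        start_token = end_token
    return slices

# Three requests with 5, 3 and 4 tokens (12 total), split into 3 ubatches of 4:
print(make_ubatch_slices(np.array([5, 3, 4]), num_tokens_padded=12, num_ubatches=3))
# -> ubatch 0: requests [0,1), tokens [0,4)
#    ubatch 1: requests [0,2), tokens [4,8)   (request 0 is split across ubatches)
#    ubatch 2: requests [2,3), tokens [8,12)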
