Commit 11ea5ec

[Model Runner V2] Refactor CudaGraphManager (vllm-project#29583)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent ecb1952 commit 11ea5ec

1 file changed: 154 additions, 95 deletions

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from unittest.mock import patch
+from collections.abc import Callable, Iterable
+from typing import Any
 
 import numpy as np
 import torch
@@ -32,6 +33,7 @@ def __init__(
 
         self.max_model_len = vllm_config.model_config.max_model_len
         self.max_num_reqs = self.scheduler_config.max_num_seqs
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.compilation_config = vllm_config.compilation_config
         assert self.compilation_config is not None
@@ -40,102 +42,60 @@ def __init__(
             self.cudagraph_mode = CUDAGraphMode.NONE
         else:
             self.cudagraph_mode = self.compilation_config.cudagraph_mode
-        if self.compilation_config.cudagraph_capture_sizes is not None:
-            cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
-            # Limit the cudagraph sizes to the max decode batch size.
-            self.cudagraph_sizes = [
-                x for x in cudagraph_sizes if x <= self.max_num_reqs
-            ]
-        else:
-            self.cudagraph_sizes = []
-        self.padded_sizes = self._init_padded_sizes()
+        self.cudagraph_sizes = get_cudagraph_sizes(
+            self.compilation_config.cudagraph_capture_sizes,
+            self.max_num_reqs,
+            self.max_num_tokens,
+            self.cudagraph_mode,
+        )
 
         self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
         self.pool = torch.cuda.graph_pool_handle()
         self.hidden_states: torch.Tensor | None = None
 
-    def _init_padded_sizes(self) -> dict[int, int]:
-        if not self.cudagraph_mode.has_full_cudagraphs():
-            # Full cuda graphs are not used.
-            return {}
-        if not self.cudagraph_sizes:
-            return {}
-
-        padded_sizes: dict[int, int] = {}
-        for i in range(1, self.cudagraph_sizes[-1] + 1):
-            for x in self.cudagraph_sizes:
-                if i <= x:
-                    padded_sizes[i] = x
-                    break
-        return padded_sizes
-
     def needs_capture(self) -> bool:
-        return len(self.padded_sizes) > 0
+        return len(self.cudagraph_sizes) > 0
 
     def get_cudagraph_size(
         self,
         scheduler_output: SchedulerOutput,
         num_tokens_after_padding: int,
     ) -> int | None:
-        if not self.cudagraph_mode.has_full_cudagraphs():
-            return None
-        if self.cudagraph_mode != CUDAGraphMode.FULL:
-            # TODO(woosuk): Support uniform decode with multiple tokens (spec decoding).
-            all_decode = all(
-                x == 1 for x in scheduler_output.num_scheduled_tokens.values()
-            )
-            if not all_decode:
-                # Prefill is included.
-                return None
-        return self.padded_sizes.get(num_tokens_after_padding)
+        return get_cudagraph_size(
+            num_tokens_after_padding,
+            scheduler_output.num_scheduled_tokens.values(),
+            self.cudagraph_sizes,
+            self.cudagraph_mode,
+        )
 
     def capture_graph(
         self,
-        batch_size: int,
+        num_tokens: int,
         model: nn.Module,
         input_buffers: InputBuffers,
         block_tables: BlockTables,
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
     ) -> None:
-        assert batch_size not in self.graphs
-
-        # Prepare dummy inputs.
-        input_ids = input_buffers.input_ids.gpu[:batch_size]
-        positions = input_buffers.positions[:batch_size]
-
-        input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
-        input_buffers.query_start_loc.np[batch_size:] = batch_size
-        input_buffers.query_start_loc.copy_to_gpu()
-        # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
-        # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
-        # seq_lens_np (CPU), which might cause issues in some attention backends.
-        input_buffers.seq_lens[:batch_size] = 1
-        input_buffers.seq_lens[batch_size:] = 0
-
-        input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
-        slot_mappings = block_tables.slot_mappings[:, :batch_size]
-
-        attn_metadata = build_attn_metadata(
-            attn_metadata_builders=attn_metadata_builders,
-            num_reqs=batch_size,
-            num_tokens=batch_size,
-            query_start_loc_gpu=input_buffers.query_start_loc.gpu[: batch_size + 1],
-            query_start_loc_cpu=input_buffers.query_start_loc.cpu[: batch_size + 1],
-            seq_lens=input_buffers.seq_lens,
-            seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
-            num_computed_tokens_cpu=None,  # FIXME
-            block_tables=input_block_tables,
-            slot_mappings=slot_mappings,
-            kv_cache_config=kv_cache_config,
+        num_reqs = min(num_tokens, self.max_num_reqs)
+        input_ids = input_buffers.input_ids.gpu[:num_tokens]
+        positions = input_buffers.positions[:num_tokens]
+        attn_metadata = prepare_inputs_to_capture(
+            num_reqs,
+            num_tokens,
+            input_buffers,
+            block_tables,
+            attn_metadata_builders,
+            self.max_model_len,
+            kv_cache_config,
         )
-        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, batch_size)
+        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
 
         # Warm up.
         with set_forward_context(
             attn_metadata,
             self.vllm_config,
-            num_tokens=batch_size,
+            num_tokens=num_tokens,
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
         ):
@@ -147,13 +107,13 @@ def capture_graph(
             self.hidden_states = torch.empty_like(hidden_states)
 
         # Capture the graph.
+        assert num_tokens not in self.graphs
        graph = torch.cuda.CUDAGraph()
        with (
-            patch("torch.cuda.empty_cache", lambda: None),
             set_forward_context(
                 attn_metadata,
                 self.vllm_config,
-                num_tokens=batch_size,
+                num_tokens=num_tokens,
                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                 num_tokens_across_dp=num_tokens_across_dp,
             ),
@@ -163,8 +123,8 @@ def capture_graph(
                 input_ids=input_ids,
                 positions=positions,
             )
-            self.hidden_states[:batch_size] = hidden_states
-        self.graphs[batch_size] = graph
+            self.hidden_states[:num_tokens] = hidden_states
+        self.graphs[num_tokens] = graph
 
     @torch.inference_mode()
     def capture(
@@ -175,25 +135,124 @@ def capture(
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
     ) -> None:
-        assert self.needs_capture()
-        # Capture larger graphs first.
-        sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
-        if is_global_first_rank():
-            sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
-
-        with graph_capture(device=self.device):
-            for batch_size in sizes_to_capture:
-                self.capture_graph(
-                    batch_size,
-                    model,
-                    input_buffers,
-                    block_tables,
-                    attn_metadata_builders,
-                    kv_cache_config,
-                )
-
-    def run(self, batch_size: int) -> torch.Tensor:
-        assert batch_size in self.graphs
-        self.graphs[batch_size].replay()
+        capture_graphs(
+            self.cudagraph_sizes,
+            self.device,
+            self.capture_graph,
+            model=model,
+            input_buffers=input_buffers,
+            block_tables=block_tables,
+            attn_metadata_builders=attn_metadata_builders,
+            kv_cache_config=kv_cache_config,
+        )
+
+    def run(self, num_tokens: int) -> torch.Tensor:
+        assert num_tokens in self.graphs
+        self.graphs[num_tokens].replay()
         assert self.hidden_states is not None
-        return self.hidden_states[:batch_size]
+        return self.hidden_states[:num_tokens]
+
+
+def get_cudagraph_sizes(
+    capture_sizes: list[int] | None,
+    max_num_reqs: int,
+    max_num_tokens: int,
+    cudagraph_mode: CUDAGraphMode,
+) -> dict[int, int]:
+    if not cudagraph_mode.has_full_cudagraphs():
+        return {}
+    if not capture_sizes:
+        return {}
+
+    capture_sizes = sorted(capture_sizes)
+    # Limit the capture sizes to the max number of requests or tokens.
+    upper_bound = (
+        max_num_reqs
+        if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
+        else max_num_tokens
+    )
+    capture_sizes = [x for x in capture_sizes if x <= upper_bound]
+    if not capture_sizes:
+        return {}
+
+    cudagraph_sizes: dict[int, int] = {}
+    for i in range(1, capture_sizes[-1] + 1):
+        for x in capture_sizes:
+            if i <= x:
+                cudagraph_sizes[i] = x
+                break
+    return cudagraph_sizes
+
+
+def get_cudagraph_size(
+    num_tokens_after_dp_padding: int,
+    num_tokens_per_request: Iterable[int],
+    cudagraph_sizes: dict[int, int],
+    cudagraph_mode: CUDAGraphMode,
+) -> int | None:
+    size = cudagraph_sizes.get(num_tokens_after_dp_padding)
+    if size is None:
+        # No CUDA graph for this size.
+        return None
+    if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        all_decode = all(x == 1 for x in num_tokens_per_request)
+        if not all_decode:
+            # Prefill is included.
+            return None
+    return size
+
+
+def capture_graphs(
+    cudagraph_sizes: dict[int, int],
+    device: torch.device,
+    capture_fn: Callable,
+    **capture_kwargs,
+) -> None:
+    # Capture larger graphs first.
+    sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
+    if is_global_first_rank():
+        sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
+
+    with graph_capture(device=device):
+        for size in sizes_to_capture:
+            capture_fn(size, **capture_kwargs)
+
+
+def prepare_inputs_to_capture(
+    num_reqs: int,
+    num_tokens: int,
+    input_buffers: InputBuffers,
+    block_tables: BlockTables,
+    attn_metadata_builders: list[AttentionMetadataBuilder],
+    max_model_len: int,
+    kv_cache_config: KVCacheConfig,
+) -> dict[str, Any]:
+    num_tokens_per_req = num_tokens // num_reqs
+    query_start_loc = input_buffers.query_start_loc
+    query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
+    query_start_loc.np[num_reqs:] = num_tokens
+    query_start_loc.copy_to_gpu()
+    seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
+    # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
+    # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
+    # seq_lens_np (CPU), which might cause issues in some attention backends.
+    input_buffers.seq_lens[:num_reqs] = 1
+    input_buffers.seq_lens[num_reqs:] = 0
+
+    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
+    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
+
+    attn_metadata = build_attn_metadata(
        attn_metadata_builders=attn_metadata_builders,
+        num_reqs=num_reqs,
+        num_tokens=num_tokens,
+        query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
+        query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
+        seq_lens=input_buffers.seq_lens,
+        seq_lens_np=seq_lens_np,
+        num_computed_tokens_cpu=None,  # FIXME
+        block_tables=input_block_tables,
+        slot_mappings=slot_mappings,
+        kv_cache_config=kv_cache_config,
+    )
+    return attn_metadata
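
For quick reference, the standalone sketch below reproduces the padding-map behavior that this commit moves into the module-level get_cudagraph_sizes / get_cudagraph_size helpers. It is illustrative only: vLLM's CUDAGraphMode is reduced to a plain decode_only flag, and the demo_* names are hypothetical rather than part of the vLLM API.

# Illustrative only, not part of vLLM. Mirrors the padding-map logic of
# get_cudagraph_sizes / get_cudagraph_size above, with CUDAGraphMode
# simplified to a `decode_only` flag.
from collections.abc import Iterable


def demo_cudagraph_sizes(capture_sizes: list[int], upper_bound: int) -> dict[int, int]:
    """Map each token count up to the largest usable capture size to the
    smallest captured size that can hold it (the size to pad to and replay)."""
    sizes = sorted(x for x in capture_sizes if x <= upper_bound)
    if not sizes:
        return {}
    padded: dict[int, int] = {}
    for n in range(1, sizes[-1] + 1):
        for s in sizes:
            if n <= s:
                padded[n] = s
                break
    return padded


def demo_cudagraph_size(
    num_tokens: int,
    num_tokens_per_request: Iterable[int],
    padded: dict[int, int],
    decode_only: bool,
) -> int | None:
    """Return the graph size to replay, or None to fall back to eager execution."""
    size = padded.get(num_tokens)
    if size is None:
        return None
    if decode_only and any(n != 1 for n in num_tokens_per_request):
        # A prefill request is in the batch; decode-only graphs do not apply.
        return None
    return size


if __name__ == "__main__":
    padded = demo_cudagraph_sizes([1, 2, 4, 8, 16], upper_bound=8)
    assert padded[3] == 4 and padded[7] == 8 and 9 not in padded
    # Five single-token (decode) requests pad up to the size-8 graph.
    assert demo_cudagraph_size(5, [1, 1, 1, 1, 1], padded, decode_only=True) == 8
    # A 3-token prefill in the batch disables decode-only graph replay.
    assert demo_cudagraph_size(5, [3, 1, 1], padded, decode_only=True) is None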
