# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from unittest.mock import patch
+from collections.abc import Callable, Iterable
+from typing import Any

import numpy as np
import torch
@@ -32,6 +33,7 @@ def __init__(

        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.dp_size = vllm_config.parallel_config.data_parallel_size
        self.compilation_config = vllm_config.compilation_config
        assert self.compilation_config is not None
@@ -40,102 +42,60 @@ def __init__(
            self.cudagraph_mode = CUDAGraphMode.NONE
        else:
            self.cudagraph_mode = self.compilation_config.cudagraph_mode
-        if self.compilation_config.cudagraph_capture_sizes is not None:
-            cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
-            # Limit the cudagraph sizes to the max decode batch size.
-            self.cudagraph_sizes = [
-                x for x in cudagraph_sizes if x <= self.max_num_reqs
-            ]
-        else:
-            self.cudagraph_sizes = []
-        self.padded_sizes = self._init_padded_sizes()
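+        # Maps each possible token count to the smallest captured graph size
+        # that can serve it; empty if full CUDA graphs are disabled or no
+        # capture sizes are eligible.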
+        self.cudagraph_sizes = get_cudagraph_sizes(
+            self.compilation_config.cudagraph_capture_sizes,
+            self.max_num_reqs,
+            self.max_num_tokens,
+            self.cudagraph_mode,
+        )

        self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
        self.pool = torch.cuda.graph_pool_handle()
        self.hidden_states: torch.Tensor | None = None

-    def _init_padded_sizes(self) -> dict[int, int]:
-        if not self.cudagraph_mode.has_full_cudagraphs():
-            # Full cuda graphs are not used.
-            return {}
-        if not self.cudagraph_sizes:
-            return {}
-
-        padded_sizes: dict[int, int] = {}
-        for i in range(1, self.cudagraph_sizes[-1] + 1):
-            for x in self.cudagraph_sizes:
-                if i <= x:
-                    padded_sizes[i] = x
-                    break
-        return padded_sizes
-
    def needs_capture(self) -> bool:
-        return len(self.padded_sizes) > 0
+        return len(self.cudagraph_sizes) > 0

    def get_cudagraph_size(
        self,
        scheduler_output: SchedulerOutput,
        num_tokens_after_padding: int,
    ) -> int | None:
-        if not self.cudagraph_mode.has_full_cudagraphs():
-            return None
-        if self.cudagraph_mode != CUDAGraphMode.FULL:
-            # TODO(woosuk): Support uniform decode with multiple tokens (spec decoding).
-            all_decode = all(
-                x == 1 for x in scheduler_output.num_scheduled_tokens.values()
-            )
-            if not all_decode:
-                # Prefill is included.
-                return None
-        return self.padded_sizes.get(num_tokens_after_padding)
+        return get_cudagraph_size(
+            num_tokens_after_padding,
+            scheduler_output.num_scheduled_tokens.values(),
+            self.cudagraph_sizes,
+            self.cudagraph_mode,
+        )

    def capture_graph(
        self,
-        batch_size: int,
+        num_tokens: int,
        model: nn.Module,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_metadata_builders: list[AttentionMetadataBuilder],
        kv_cache_config: KVCacheConfig,
    ) -> None:
-        assert batch_size not in self.graphs
-
-        # Prepare dummy inputs.
-        input_ids = input_buffers.input_ids.gpu[:batch_size]
-        positions = input_buffers.positions[:batch_size]
-
-        input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
-        input_buffers.query_start_loc.np[batch_size:] = batch_size
-        input_buffers.query_start_loc.copy_to_gpu()
-        # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
-        # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
-        # seq_lens_np (CPU), which might cause issues in some attention backends.
-        input_buffers.seq_lens[:batch_size] = 1
-        input_buffers.seq_lens[batch_size:] = 0
-
-        input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
-        slot_mappings = block_tables.slot_mappings[:, :batch_size]
-
-        attn_metadata = build_attn_metadata(
-            attn_metadata_builders=attn_metadata_builders,
-            num_reqs=batch_size,
-            num_tokens=batch_size,
-            query_start_loc_gpu=input_buffers.query_start_loc.gpu[: batch_size + 1],
-            query_start_loc_cpu=input_buffers.query_start_loc.cpu[: batch_size + 1],
-            seq_lens=input_buffers.seq_lens,
-            seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
-            num_computed_tokens_cpu=None,  # FIXME
-            block_tables=input_block_tables,
-            slot_mappings=slot_mappings,
-            kv_cache_config=kv_cache_config,
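+        # The dummy batch may cover more tokens than there are request slots,
+        # so cap the number of dummy requests at max_num_reqs.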
+        num_reqs = min(num_tokens, self.max_num_reqs)
+        input_ids = input_buffers.input_ids.gpu[:num_tokens]
+        positions = input_buffers.positions[:num_tokens]
+        attn_metadata = prepare_inputs_to_capture(
+            num_reqs,
+            num_tokens,
+            input_buffers,
+            block_tables,
+            attn_metadata_builders,
+            self.max_model_len,
+            kv_cache_config,
        )
-        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, batch_size)
+        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

        # Warm up.
        with set_forward_context(
            attn_metadata,
            self.vllm_config,
-            num_tokens=batch_size,
+            num_tokens=num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            num_tokens_across_dp=num_tokens_across_dp,
        ):
@@ -147,13 +107,13 @@ def capture_graph(
            self.hidden_states = torch.empty_like(hidden_states)

        # Capture the graph.
+        assert num_tokens not in self.graphs
        graph = torch.cuda.CUDAGraph()
        with (
-            patch("torch.cuda.empty_cache", lambda: None),
            set_forward_context(
                attn_metadata,
                self.vllm_config,
-                num_tokens=batch_size,
+                num_tokens=num_tokens,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                num_tokens_across_dp=num_tokens_across_dp,
            ),
@@ -163,8 +123,8 @@ def capture_graph(
                input_ids=input_ids,
                positions=positions,
            )
-            self.hidden_states[:batch_size] = hidden_states
-        self.graphs[batch_size] = graph
+            self.hidden_states[:num_tokens] = hidden_states
+        self.graphs[num_tokens] = graph

    @torch.inference_mode()
    def capture(
@@ -175,25 +135,124 @@ def capture(
        attn_metadata_builders: list[AttentionMetadataBuilder],
        kv_cache_config: KVCacheConfig,
    ) -> None:
-        assert self.needs_capture()
-        # Capture larger graphs first.
-        sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
-        if is_global_first_rank():
-            sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
-
-        with graph_capture(device=self.device):
-            for batch_size in sizes_to_capture:
-                self.capture_graph(
-                    batch_size,
-                    model,
-                    input_buffers,
-                    block_tables,
-                    attn_metadata_builders,
-                    kv_cache_config,
-                )
-
-    def run(self, batch_size: int) -> torch.Tensor:
-        assert batch_size in self.graphs
-        self.graphs[batch_size].replay()
+        capture_graphs(
+            self.cudagraph_sizes,
+            self.device,
+            self.capture_graph,
+            model=model,
+            input_buffers=input_buffers,
+            block_tables=block_tables,
+            attn_metadata_builders=attn_metadata_builders,
+            kv_cache_config=kv_cache_config,
+        )
+
+    def run(self, num_tokens: int) -> torch.Tensor:
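+        # Replay the graph captured for this padded token count; the result is
+        # read back from the persistent hidden_states output buffer.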
+        assert num_tokens in self.graphs
+        self.graphs[num_tokens].replay()
        assert self.hidden_states is not None
-        return self.hidden_states[:batch_size]
+        return self.hidden_states[:num_tokens]
+
+
+def get_cudagraph_sizes(
+    capture_sizes: list[int] | None,
+    max_num_reqs: int,
+    max_num_tokens: int,
+    cudagraph_mode: CUDAGraphMode,
+) -> dict[int, int]:
+    if not cudagraph_mode.has_full_cudagraphs():
+        return {}
+    if not capture_sizes:
+        return {}
+
+    capture_sizes = sorted(capture_sizes)
+    # Limit the capture sizes to the max number of requests or tokens.
+    upper_bound = (
+        max_num_reqs
+        if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
+        else max_num_tokens
+    )
+    capture_sizes = [x for x in capture_sizes if x <= upper_bound]
+    if not capture_sizes:
+        return {}
+
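+    # For every token count up to the largest capture size, record the
+    # smallest captured size that can serve it, e.g. capture sizes [1, 2, 4]
+    # yield {1: 1, 2: 2, 3: 4, 4: 4}.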
+    cudagraph_sizes: dict[int, int] = {}
+    for i in range(1, capture_sizes[-1] + 1):
+        for x in capture_sizes:
+            if i <= x:
+                cudagraph_sizes[i] = x
+                break
+    return cudagraph_sizes
+
+
+def get_cudagraph_size(
+    num_tokens_after_dp_padding: int,
+    num_tokens_per_request: Iterable[int],
+    cudagraph_sizes: dict[int, int],
+    cudagraph_mode: CUDAGraphMode,
+) -> int | None:
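+    # Look up the padded graph size for this batch; decode-only graphs are
+    # skipped as soon as any request has more than one scheduled token.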
+    size = cudagraph_sizes.get(num_tokens_after_dp_padding)
+    if size is None:
+        # No CUDA graph for this size.
+        return None
+    if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        all_decode = all(x == 1 for x in num_tokens_per_request)
+        if not all_decode:
+            # Prefill is included.
+            return None
+    return size
+
+
+def capture_graphs(
+    cudagraph_sizes: dict[int, int],
+    device: torch.device,
+    capture_fn: Callable,
+    **capture_kwargs,
+) -> None:
+    # Capture larger graphs first.
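+    # Only the distinct padded target sizes (the dict values) need a graph.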
+    sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
+    if is_global_first_rank():
+        sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
+
+    with graph_capture(device=device):
+        for size in sizes_to_capture:
+            capture_fn(size, **capture_kwargs)
+
+
+def prepare_inputs_to_capture(
+    num_reqs: int,
+    num_tokens: int,
+    input_buffers: InputBuffers,
+    block_tables: BlockTables,
+    attn_metadata_builders: list[AttentionMetadataBuilder],
+    max_model_len: int,
+    kv_cache_config: KVCacheConfig,
+) -> dict[str, Any]:
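+    # Spread the dummy tokens evenly across the dummy requests; the final
+    # query_start_loc entry is pinned to num_tokens so the last request
+    # absorbs any remainder.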
+    num_tokens_per_req = num_tokens // num_reqs
+    query_start_loc = input_buffers.query_start_loc
+    query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
+    query_start_loc.np[num_reqs:] = num_tokens
+    query_start_loc.copy_to_gpu()
+    seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
+    # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
+    # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
+    # seq_lens_np (CPU), which might cause issues in some attention backends.
+    input_buffers.seq_lens[:num_reqs] = 1
+    input_buffers.seq_lens[num_reqs:] = 0
+
+    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
+    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
+
+    attn_metadata = build_attn_metadata(
+        attn_metadata_builders=attn_metadata_builders,
+        num_reqs=num_reqs,
+        num_tokens=num_tokens,
+        query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
+        query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
+        seq_lens=input_buffers.seq_lens,
+        seq_lens_np=seq_lens_np,
+        num_computed_tokens_cpu=None,  # FIXME
+        block_tables=input_block_tables,
+        slot_mappings=slot_mappings,
+        kv_cache_config=kv_cache_config,
+    )
+    return attn_metadata