
Commit 7a228b5

Add option to use unbacked and backed size-oblivious dynamic shapes for more sound compilation. (vllm-project#26199)
Signed-off-by: Laith Sakka <[email protected]>
1 parent f716a15 commit 7a228b5

File tree

8 files changed (+442 lines, -15 lines)


docs/design/debug_vllm_compile.md

Lines changed: 70 additions & 0 deletions
@@ -151,6 +151,76 @@ To avoid this, please either:
2. wrap the branching logic into a custom operator. TorchDynamo does not
   trace into custom operators.

## Debugging constraint violations and dynamic-shape guard issues

Dynamic-shape guards are a specific category of Dynamo guards. They are constraints that `torch.compile`
attaches to dynamic dimensions (e.g., `seq_len`) to ensure the compiled artifact remains valid.
These guards typically appear when framework code, custom passes, or user code branches based on
dynamic shape values.

**Example:**

```python
if x > 10:
    ...  # path A
else:
    ...  # path B
```

This creates a guard `x > 10` or `x <= 10`, depending on which path was traced.
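
In vLLM, the guarded value is usually a dynamic tensor dimension such as `seq_len` rather than a plain Python integer. The following minimal sketch uses plain `torch.compile` (not vLLM-specific; the function and sizes are illustrative) to show how branching on a dynamic dimension produces such a guard and forces a recompile when the guard fails:

```python
import torch

def f(x):
    # Branching on a dynamic dimension makes torch.compile guard on it,
    # e.g. "x.shape[0] > 10" for the traced path.
    if x.shape[0] > 10:
        return x * 2
    return x + 1

compiled = torch.compile(f, dynamic=True)
compiled(torch.randn(16))  # traces path A and guards on shape[0] > 10
compiled(torch.randn(4))   # guard fails, so the other path is recompiled
```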

**vLLM's Assumption:**
vLLM assumes that all guards added by `torch.compile` are safe to drop and will not
constrain the compiled graph to specific input shapes. When this assumption is violated,
it can cause issues that users need to debug.
Side effects that indicate this assumption is violated include runtime errors
and `ConstraintViolationError` exceptions.

A `ConstraintViolationError` is thrown if a dynamic shape gets constrained to
a single value. If you encounter a constraint violation error or suspect that a dynamic
shape guard is being added incorrectly, you can use stricter dynamic shape modes to
help debug the issue:

```sh
# Online - using unbacked mode
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked

# Online - using backed_size_oblivious mode
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=backed_size_oblivious
```

```py
# Offline - using unbacked mode
from vllm import LLM
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType

model = "meta-llama/Llama-3.2-1B"  # any model you are debugging
LLM(model, compilation_config=CompilationConfig(
    dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.UNBACKED)
))

# Offline - using backed_size_oblivious mode
LLM(model, compilation_config=CompilationConfig(
    dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS)
))
```

These modes are stricter and reduce or eliminate the need for dynamic shape guarding, which can help isolate issues:

- `unbacked`: Uses unbacked symints, which don't allow guards, making it easier to identify where guards are being added incorrectly
- `backed_size_oblivious`: Uses a mode that is stricter about guarding.

For more details on dynamic shape modes, see [Dynamic shapes and vLLM guard dropping](torch_compile.md#dynamic-shapes-and-vllm-guard-dropping).

### Printing guards

To see all guards that are being added during compilation, you can use `TORCH_LOGS=+dynamic`:

```sh
TORCH_LOGS=+dynamic vllm serve meta-llama/Llama-3.2-1B
```

Look for `[guard added]` in the logs to see where guards are being added. This can help you identify which operations are
causing guards to be added incorrectly.
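
For example, you can filter the verbose output down to just the guard events (a minimal sketch; adapt the model and command to your setup):

```sh
TORCH_LOGS=+dynamic vllm serve meta-llama/Llama-3.2-1B 2>&1 | grep "guard added"
```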

## Debugging TorchInductor

TorchInductor takes a captured graph and then compiles it down to some Python code

docs/design/torch_compile.md

Lines changed: 104 additions & 1 deletion
@@ -29,6 +29,109 @@ A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all

By default, the cache saves compiled artifacts as binary files. If you would like to interact with the generated code for debugging purposes, set the field `compile_cache_save_format=unpacked` in the compilation config, or omit this and set the env variable `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`.

## Dynamic shapes and vLLM guard dropping

`torch.compile` is designed to guard on dynamic shapes without hesitation
when needed. This conflicts with vLLM's `torch.compile` approach of
dropping the guards, since many of those guards could be material.

`torch.compile` provides two kinds of dynamic shapes: `backed` and `unbacked`.
`torch.compile` guards on `backed` dynamic shapes and does not provide a
guarantee that no guards will be added to them. User code, Dynamo,
Inductor, and autograd can all add guards. Moreover, because of 0/1
specialization, backed symbols are specialized unconditionally to 0, 1,
or >=2 even without encountering a branch on those ranges.
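
As a minimal illustration of backed guarding and 0/1 specialization with plain `torch.compile` (outside vLLM; the function and sizes are only examples):

```python
import torch

def f(x):
    return x * 2

compiled = torch.compile(f, dynamic=True)
compiled(torch.randn(8))  # compiled with a backed dynamic dim, assumed >= 2
compiled(torch.randn(1))  # size 1 is 0/1-specialized, so this triggers a recompile
```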

In contrast, `unbacked` dynamic shapes are guaranteed not to be guarded
on and are not 0/1 specialized. However, a data-dependent error (DDE) may
be thrown when a branch that requires their value is encountered and no
explicit unbacked handling is defined. The framework is converging toward
a state where it won't throw a DDE but will instead pick general paths.
One downside of using unbacked is missed optimization opportunities, due
to performance bugs, to picking general paths, or to using a fixed hint
that is not based on the example inputs (this will be fixed soon with the
`override_hint` API). An example of picking a general path is assuming the
input is not contiguous in functions that call `contiguous()` or
`reshape()` when contiguity can't be proven symbolically, at the cost of
introducing a clone.
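
The following sketch shows the kind of data-dependent situation this refers to, using plain `torch.compile` rather than vLLM (assumes a recent PyTorch; here `torch._check_is_size` is the explicit unbacked handling that avoids the data-dependent error):

```python
import torch

torch._dynamo.config.capture_scalar_outputs = True  # let .item() produce an unbacked symint

@torch.compile(fullgraph=True)
def f(x, y):
    n = x.item()             # n is an unbacked symint
    torch._check_is_size(n)  # without this, the slice below raises a data-dependent error
    return y[:n]

print(f(torch.tensor(3), torch.arange(10)))  # tensor([0, 1, 2])
```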

`backed_size_oblivious` is a flag that enables treating backed symbols as
unbacked wherever explicit handling for unbacked is defined. With this
mode, 0/1 specializations are mostly avoided in framework code and the
default 0/1 specialization does not happen. However, there is still no
guarantee that `torch.compile` won't guard, especially due to user code or
custom passes. `backed_size_oblivious` is experimental in PyTorch compile
and could be deprecated. That said, it's a safer option than `backed`, and
the chance of reduced performance is lower than with `unbacked`.
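
For reference, a rough sketch of what this mode corresponds to at the PyTorch level; the config location (`torch.fx.experimental._config.backed_size_oblivious`) is a private, experimental knob and is an assumption here, so prefer vLLM's `DynamicShapesConfig` described in the next section rather than setting it yourself:

```python
import torch
import torch.fx.experimental._config as fx_config  # private, experimental module (assumption)

def f(x):
    return x.reshape(-1)

# Treat backed symbols size-obliviously while compiling.
with fx_config.patch(backed_size_oblivious=True):
    compiled = torch.compile(f, dynamic=True)
    compiled(torch.randn(4, 4))
```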

### Configuring Dynamic Shapes

The `DynamicShapesConfig` allows you to control the dynamic shapes behavior by
setting the `type` field. You can choose between three modes:
`BACKED` (default), `UNBACKED`, and `BACKED_SIZE_OBLIVIOUS`.

#### Offline Inference Example (Using LLM class)

When using the `LLM` class for offline inference, you can configure dynamic
shapes through the `compilation_config` parameter:

```python
from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType

# Example: Using backed_size_oblivious (experimental, safer than backed)
llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config=CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS
        )
    )
)

# Example: Using unbacked (strongest guarantee against guards)
llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config=CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.UNBACKED
        )
    )
)

# Generate outputs
prompts = ["Hello, my name is", "The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)
```

#### Online Serving Example (Using vllm serve)

When using `vllm serve` for online serving, you can configure dynamic shapes
through the `--compilation-config` flag:

```bash
# Example: Using unbacked
vllm serve meta-llama/Llama-3.2-1B \
    --compilation-config '{"dynamic_shapes_config": {"type": "unbacked"}}'

# Alternative: Using dot notation (simpler for single values)
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked
```

#### Choosing the Right Mode

- **BACKED** (default): Use when you're willing to accept potentially unsafe dropping of guards
  for maximal performance. Guards could be unsoundly added and then ignored.

- **UNBACKED**: Use when you need the strongest guarantee against guards.
  This is the most conservative option but may miss some optimization opportunities.

- **BACKED_SIZE_OBLIVIOUS**: Use when you want a balance between avoiding guards
  and performance. This experimental mode is safer than BACKED but still not as
  conservative as UNBACKED.

## Python Code Compilation

In the very verbose logs, we can see:
@@ -122,7 +225,7 @@ When all the shapes are known, `torch.compile` can compare different configs, an
triton_mm_4 0.0130 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_8 0.0134 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_12 0.0148 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
- mm 0.0160 ms 81.6%
+   mm 0.0160 ms 81.6%
triton_mm_16 0.0165 ms 78.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_3 0.0199 ms 65.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_1 0.0203 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationMode, DynamicShapesType
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.torch_utils import is_torch_equal_or_newer


def get_test_models():
    """Get the list of models to test based on the PyTorch version."""
    # TODO: "Qwen/Qwen3-4B-Instruct-2507" fails. Fix the issue and support it.
    return ["gpt2", "Qwen/Qwen2-7B-Instruct", "meta-llama/Llama-3.1-8B"]


@pytest.mark.parametrize("model_name", get_test_models())
@pytest.mark.parametrize(
    "shapes_type",
    [
        DynamicShapesType.BACKED,
        DynamicShapesType.UNBACKED,
        DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
    ],
)
@pytest.mark.parametrize("use_aot_compile", ["0"])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_dynamic_shapes_compilation(
    monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
):
    """Test that all dynamic shapes types compile successfully."""
    print(
        f"\nTesting model: {model_name} with {shapes_type.name}, "
        f"AOT compile: {use_aot_compile}, "
        f"Bytecode hook: {use_bytecode_hook}"
    )
    if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
        pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")

    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")

    prompt = "Hello, my name is"

    print(f"Testing {shapes_type.name} dynamic shapes...")

    # Initialize the model with the specific dynamic shapes configuration
    model = LLM(
        model=model_name,
        compilation_config={
            "mode": CompilationMode.VLLM_COMPILE,
            "dynamic_shapes_config": {
                "type": shapes_type.value,
            },
        },
    )

    output = model.generate(prompt)
    result = output[0].outputs[0].text
    # Ask the model to judge its own completion, restricting the output to yes/no tokens
    tokenizer = get_tokenizer(model_name)
    yes_tokens = tokenizer.encode("yes", add_special_tokens=False)
    no_tokens = tokenizer.encode("no", add_special_tokens=False)
    allowed_ids = list(set(yes_tokens + no_tokens))
    sampling_params = SamplingParams(
        max_tokens=1, temperature=0, allowed_token_ids=allowed_ids
    )

    output = model.generate(
        "answer with yes or no is " + result + " rubbish for prompt " + prompt + "?",
        sampling_params=sampling_params,
    )
    result = output[0].outputs[0].text
    assert result == "yes"

    # Clean up GPU memory
    del model
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print("GPU memory cleared")
