docs/design/torch_compile.md

By default, the cache saves compiled artifacts as binary files. If you would like to interact with the generated code for debugging purposes, set the field `compile_cache_save_format=unpacked` in the compilation config, or omit this and set the env variable `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`.
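
For example, a minimal sketch of the environment-variable route (this assumes the variable is read before any compilation happens, so it must be set before the engine is constructed):

```python
import os

# Save compiled artifacts as readable source files ("unpacked") instead of
# binaries, so the generated code in the compile cache can be inspected.
# Set this before creating the LLM/engine, i.e. before compilation starts.
os.environ["VLLM_COMPILE_CACHE_SAVE_FORMAT"] = "unpacked"
```
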
## Dynamic shapes and vLLM guard dropping

`torch.compile` is designed to guard on dynamic shapes without hesitation whenever needed. This conflicts with vLLM's `torch.compile` approach of dropping those guards, since many of the dropped guards could be material (i.e., dropping them could change behavior).

`torch.compile` provides two kinds of dynamic shapes: `backed` and `unbacked`. `torch.compile` guards on `backed` dynamic shapes and does not guarantee that no guards will be added to them; user code, Dynamo, Inductor, and autograd can all add guards. Moreover, backed symbols are unconditionally 0/1-specialized: they are specialized to 0, 1, or >=2 even without encountering a branch over those ranges.
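
For illustration, a minimal sketch of backed behavior using the private `torch._dynamo.mark_dynamic` API (available in recent PyTorch releases; exact recompile behavior may vary by version):

```python
import torch

def f(x):
    # With a backed symbol, this branch is resolved at compile time and
    # may add a guard instead of raising an error.
    if x.shape[0] == 1:
        return x * 2
    return x + 1

x = torch.randn(4)
# Mark dim 0 as (backed) dynamic; guards can still be added to it.
torch._dynamo.mark_dynamic(x, 0)
compiled = torch.compile(f)
compiled(x)               # compiled assuming s0 >= 2 (0/1 specialization)
compiled(torch.randn(1))  # outside s0 >= 2, so this triggers a recompile
```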

By contrast, `unbacked` dynamic shapes are guaranteed not to be guarded on and are not 0/1-specialized. However, a data-dependent error (DDE) can be thrown when a branch that requires their value is encountered and no explicit unbacked handling is defined. The framework is converging toward a state where it won't throw a DDE but will instead pick a general path. One downside of using unbacked shapes is missed optimization opportunities, due either to performance bugs, to picking general paths, or to using a fixed hint that is not derived from the example input (this will be fixed soon with the `override_hint` API). An example of picking a general path: `contiguous()` and `reshape()` assume the input is not contiguous when contiguity cannot be symbolically proven, at the cost of possibly introducing a clone.
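
As a concrete illustration, a minimal sketch of the DDE case using the private `torch._dynamo.decorators.mark_unbacked` API (present in recent PyTorch releases; the exact error type and message may vary):

```python
import torch

def f(x):
    # No explicit unbacked handling here: deciding this branch requires the
    # actual value of the unbacked size, which is exactly what can raise a
    # data-dependent error (DDE) at compile time.
    if x.shape[0] == 1:
        return x * 2
    return x + 1

x = torch.randn(4)
# Mark dim 0 as unbacked: no guards and no 0/1 specialization on it.
torch._dynamo.decorators.mark_unbacked(x, 0)
compiled = torch.compile(f, fullgraph=True)
compiled(x)  # expected to fail with a data-dependent error on the branch
```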

`backed_size_oblivious` is a flag that treats backed symbols as unbacked wherever explicit handling for unbacked symbols is defined. With this mode, 0/1 specializations are mostly avoided in framework code, and the default 0/1 specialization does not happen. However, there is still no guarantee that `torch.compile` won't guard, especially due to user code or custom passes. `backed_size_oblivious` is experimental in PyTorch and could be deprecated. That said, it is a safer option than `backed`, and it is less likely to hurt performance than `unbacked`.
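
At the PyTorch level, a sketch of enabling the flag directly (assuming it still lives in the experimental `torch.fx.experimental._config` module, which may move or be removed; within vLLM you would normally select it through `DynamicShapesConfig` as shown below):

```python
import torch.fx.experimental._config as fx_config

# Treat backed symbols as unbacked wherever explicit unbacked handling
# exists: avoids most 0/1 specialization, but guards are still possible.
fx_config.backed_size_oblivious = True
```
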
### Configuring Dynamic Shapes

The `DynamicShapesConfig` lets you control dynamic-shapes behavior through its `type` field. You can choose between three modes: `BACKED` (default), `UNBACKED`, and `BACKED_SIZE_OBLIVIOUS`.

#### Offline Inference Example (Using LLM class)

When using the `LLM` class for offline inference, you can configure dynamic shapes through the `compilation_config` parameter:

```python
from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType

# Example: Using backed_size_oblivious (experimental, safer than backed)
llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config=CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS
        )
    )
)

# Example: Using unbacked (strongest guarantee against guards)
llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config=CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.UNBACKED
        )
    )
)

# Generate outputs
prompts = ["Hello, my name is", "The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r} -> {output.outputs[0].text!r}")
```