Skip to content

Commit 8bc3047

Browse files
committed
[docs] add aclgraph developer guide
Signed-off-by: zzzzwwjj <[email protected]>
1 parent f2dd5f8 commit 8bc3047

File tree

2 files changed

+174
-0
lines changed

2 files changed

+174
-0
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# ACLGraph
2+
3+
## Why we need ACLGraph?
4+
5+
When in LLM inference, each token requires nearly thousand operator executions, and when CPU launching operators are slower than GPU, it will cause host bound. In severe cases, the GPU will be idle for more than half of the time. To solve this problem, we use graph in LLM inference.
6+
7+
```
8+
eager mode:
9+
10+
cpu: | launch op1 | launch op2 | launch op3 | launch op4 | launch op5 |
11+
12+
gpu: | run op1 |free| run op2 |free| run op3 |free| run op4 |free| run op5 |
13+
14+
| <----- total time -----> |
15+
16+
graph mode:
17+
18+
cpu: | launch graph |
19+
20+
gpu: | run op1 | run op2 | run op3 | run op4 | run op5 |
21+
22+
| <----- total time -----> |
23+
24+
```
25+
26+
## How to use ACLGraph?
27+
28+
ACLGraph is enabled by default in V1 Engine, just set to use V1 Engine is enough.
29+
30+
## How it works?
31+
32+
In short, graph mode works in two steps: **capture and replay**. When engine starts, we will capture all of the ops in model forward and save it as a graph, and when req come in, we just replay the graph on gpus, and waiting for result.
33+
34+
But in reality, graph mode is not that simple.
35+
36+
### Padding and Bucketing
37+
38+
Due to graph can only replay the ops captured before, without doing tiling and checking graph input, so we need to ensure the consistency of the graph input, but we know that model input's shape depends on the request scheduled by Scheduler, we can't ensure the consistency.
39+
40+
Obviously, we can solve this problem by capturing the biggest shape and padding all of the model input to it. But it will bring a lot of redundant computing and make performance worse. So we can capture multiple graphs with different shape, and pad the model input to the nearest graph, it will greatly reduce redundant computing, but when `max_num_batched_tokens` is very large, the number of graphs that need to be captured will also become very large. But we know that when intensor's shape is large, the computing time will be very long, and graph mode is not necessary in this case. So all of things we need to do is:
41+
1. Set a threshold;
42+
2. When `num_scheduled_tokens` is bigger than the threshold, use `eager_mode`;
43+
3. Capture multiple graphs within a range below the threshold;
44+
45+
```
46+
| graph1 |
47+
| graph2 |
48+
| graph3 |
49+
| graph4 | # the threshold
50+
51+
| input1 | pad | # use graph1
52+
| input2 | # don't need pad
53+
| input3 | pad | # use graph4
54+
| input4 | # use eager mode
55+
56+
```
57+
58+
### Piecewise and Full graph
59+
60+
Due to the increasing complexity of the attention layer in current LLM, we can't ensure all types of attention can run in graph. In MLA, prefill_tokens and decode_tokens have different calculation method, so when a batch has both prefills and decodes in MLA, graph mode is difficult to handle this situation.
61+
62+
vLLM solves this problem with piecewise graph mode. We use eager mode to launch attention's ops, and use graph to deal with others. But it also bring some problems: The cost of launching ops has become large again, although much smaller than eager mode, but it will also lead to host bound when cpu is poor or `num_tokens` is small.
63+
64+
Altogether, we need to support both piecewise and full graph mode.
65+
66+
1. When attention can run in graph, we tend to choose full graph mode to achieve optimal performance;
67+
2. When full graph is not work, use piecewise graph as a substitute;
68+
3. When piecewise graph's performance is not good and full graph mode is blocked, separate prefills and decodes, and use full graph mode in **decode_only** situation. Because when a batch include prefill req, usually `num_tokens` will be quite big and not cause host bound.
69+
70+
## How it be implemented?
71+
72+
vLLM has already implemented most of the modules in graph mode, and when in graph mode, vLLM will call `current_platform.get_static_graph_wrapper_cls` to get current device's graph model wrapper,so what we need to do is to implement the graph mode wrapper on Ascend: `ACLGraphWrapper`.
73+
74+
vLLM has added `support_torch_compile` decorator to all models, this decorator will replace the `__init__` and `forward` interface of the model class, and when `forward` called, the code inside the `ACLGraphWrapper` will be executed, and it will do capture or replay as mentioned above.
75+
76+
When use piecewise graph, we just need to follow the above-mentioned process, but when in full graph, due to the complexity of the attention, sometimes we need to update attention op's param before execution. So we implement `update_attn_params` and `update_mla_attn_params` funcs for full graph mode. And when forward, memory will be reused between different ops, so we can't update attention op's param before forward. In ACLGraph, we use `torch.npu.graph_task_update_begin` and `torch.npu.graph_task_update_end` to do it, and use `torch.npu.ExternalEvent` to ensure order between params update and ops execution.
77+
78+
## DFX
79+
80+
### Stream resource constraint
81+
82+
When use piecewise graph mode, every sub module will use at least one stream in ACLGraph. Due to stream resource constraint, the number of bucketing will be restricted.
83+
84+
Currently, we calculate the maximum number of bucketing under the current case through a formula:
85+
86+
```python
87+
# NOTE: Currently, we can only capture 1800 graphs at most,
88+
# due to the limitation of ACL graph. This number is bounded by
89+
# the number of streams, which is 2048, we save 248 streams
90+
# as a buffer.
91+
# Maximum number of graphs that can be captured by ACL Graph
92+
# TODO: Find out whether we need to solve allreduce function
93+
MAX_CAPTURE_SIZE = 1800
94+
95+
# Store original configuration and temporarily clear it
96+
compilation_config = vllm_config.compilation_config
97+
original_sizes, compilation_config.cudagraph_capture_sizes = \
98+
compilation_config.cudagraph_capture_sizes, None
99+
100+
# Calculate parallel configuration factor
101+
hf_config = vllm_config.model_config.hf_config
102+
if hasattr(hf_config, 'num_hidden_layers'):
103+
num_hidden_layers = hf_config.num_hidden_layers
104+
else:
105+
num_hidden_layers = get_max_hidden_layers(hf_config)
106+
parallel_config = vllm_config.parallel_config
107+
108+
# Calculate maximum supported batch sizes considering model architecture
109+
resources_per_graph = num_hidden_layers + 1
110+
if vllm_config.speculative_config is not None:
111+
draft_model_hf_config = vllm_config.speculative_config.draft_model_config.hf_config
112+
resources_per_graph += draft_model_hf_config.num_hidden_layers + 1
113+
114+
# TODO: Find out whether we need to take into account the pp_size
115+
num_comm_groups = sum(size > 1 for size in [
116+
parallel_config.data_parallel_size,
117+
parallel_config.tensor_parallel_size,
118+
])
119+
120+
if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV':
121+
# TODO: Find out whether we need to take into account the pp_size
122+
parallel_factor = 1 + num_comm_groups + int(
123+
parallel_config.enable_expert_parallel) + int(
124+
vllm_config.additional_config.get(
125+
"multistream_overlap_shared_expert", False))
126+
if is_moe_model(vllm_config):
127+
parallel_factor += (parallel_config.data_parallel_size > 1)
128+
else:
129+
# When AIV mode is enabled, the allreduce operator of the dense
130+
# layer model will occupy additional streams, which are buffered here.
131+
MAX_CAPTURE_SIZE = MAX_CAPTURE_SIZE - parallel_factor * resources_per_graph
132+
133+
# Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device
134+
# Assume the following case:
135+
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
136+
# According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19
137+
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
138+
resources_per_graph / parallel_factor)
139+
logger.info(
140+
"Calculated maximum supported batch sizes for ACL graph: %s",
141+
max_num_batch_sizes)
142+
else:
143+
# The above describes an empirical formula applicable to the A2 hardware.
144+
# Under this configuration, HCCL employs the FFTS+ method for execution unfolding,
145+
# which adds only 1 concurrent stream without consuming collective communication execution unfolding streams.
146+
# On A3 hardware, HCCL defaults to the AICPU method.
147+
# This approach may additionally allocate up to rank_size (max 16) - 1 streams per collective communication domain on the device (worst case).
148+
# Using the default collective communication unfolding method on A3 will lead to a significant reduction in the maximum supported sizes.
149+
# Therefore, the calculation formula has been modified as follows:
150+
# Assume the following case:
151+
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
152+
# According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12
153+
max_num_batch_sizes = math.floor(
154+
(MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph /
155+
(1 + num_comm_groups * 2))
156+
logger.info(
157+
"Calculated maximum supported batch sizes for ACL graph: %s",
158+
max_num_batch_sizes)
159+
logger.warning(
160+
"Currently, communication is performed using FFTS+ method, which reduces "
161+
"the number of available streams and, as a result, limits the range of runtime "
162+
"shapes that can be handled. To both improve communication performance and "
163+
"increase the number of supported shapes, set HCCL_OP_EXPANSION_MODE=AIV."
164+
)
165+
```
166+
167+
We will expand the stream resource limitation in the future.
168+
169+
## Limitation
170+
171+
1. `FULL_AND_PIECEWISE` is not supported now;
172+
2. When use ACLGraph and MTP and `num_speculative_tokens > 1`, as vLLM don't support this case in v0.11.0, we need to set `cudagraph_capture_sizes` explicitly.
173+
3. `use_inductor` is not supported now;

docs/source/developer_guide/feature_guide/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ This section provides an overview of the features implemented in vLLM Ascend. De
77
:maxdepth: 1
88
patch
99
ModelRunner_prepare_inputs
10+
ACLGraph
1011
:::

0 commit comments

Comments
 (0)