
Commit d9249c9

bugfix for mtp in fullgraph (#3878)
### What this PR does / why we need it?
Bugfix for MTP in fullgraph.

### Does this PR introduce _any_ user-facing change?
No.

Signed-off-by: zouyida2052 <[email protected]>
1 parent 19f49ec commit d9249c9

File tree

6 files changed: +58 -38 lines changed

docs/source/community/versioning_policy.md

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ vLLM Ascend includes two branches: main and dev.
 Commits should typically be merged into the main branch first, and only then backported to the dev branch, to reduce maintenance costs as much as possible.
 
 ### Maintenance branch and EOL
-The table below lists branch states.
+The table below lists branch states.
 
 | Branch | Time Frame | Summary |
 | ----------------- | -------------------------------- | --------------------------------------------------------- |

docs/source/developer_guide/feature_guide/ModelRunner_prepare_inputs.md

Lines changed: 13 additions & 13 deletions
@@ -92,7 +92,7 @@ As the maximum number of tokens that can be schedules is 10, the scheduled token
 ##### 1. Get token positions:
 First, determine which request each token belongs to: tokens 0–2 are assigned to **request_0**, tokens 3–4 to **request_1**, and tokens 5–9 to **request_2**. To represent this mapping, we use `request indices`, for example, `request indices`: `[0, 0, 0, 1, 1, 2, 2, 2, 2, 2]`.
 
-For each request, use **the number of computed tokens** + **the relative position of current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1,..., 0 + 4]`) and then concatenate them together (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`).
+For each request, use **the number of computed tokens** + **the relative position of current scheduled tokens** (`request_0: [0 + 0, 0 + 1, 0 + 2]`, `request_1: [0 + 0, 0 + 1]`, `request_2: [0 + 0, 0 + 1,..., 0 + 4]`) and then concatenate them together (`[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]`).
 
 Note: there is more efficient way (using `request indices`) to create positions in actual code.
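The position construction touched by this hunk can be sketched in a few lines of NumPy. This is illustrative only — the array names are made up and it is not the actual ModelRunner code — but it reproduces the numbers used in the doc:

```python
import numpy as np

# Doc example: three requests with 3, 2 and 5 scheduled tokens, 0 computed tokens each.
num_scheduled_tokens = np.array([3, 2, 5])
num_computed_tokens = np.array([0, 0, 0])

# request indices: which request each scheduled token belongs to.
request_indices = np.repeat(np.arange(3), num_scheduled_tokens)   # [0,0,0,1,1,2,2,2,2,2]

# Relative position of every token inside its own request, built without a Python loop.
starts = np.cumsum(num_scheduled_tokens) - num_scheduled_tokens   # [0, 3, 5]
relative_positions = (np.arange(num_scheduled_tokens.sum())
                      - np.repeat(starts, num_scheduled_tokens))

positions = num_computed_tokens[request_indices] + relative_positions
print(positions.tolist())   # [0, 1, 2, 0, 1, 0, 1, 2, 3, 4]
```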

@@ -152,33 +152,33 @@ The KV cache block in the device memory is like:
 Let's say `K = max model len / block size = 6`, and we can get token `device block number`.
 
 The workflow of achieving slot mapping:
-1. Get `block table indices` using `K`, `positions` and `request indices`.
+1. Get `block table indices` using `K`, `positions` and `request indices`.
 
 Purpose: For each token, it could be used to select `device block number` from `block table`.
 
-2. Get `device block number` using `block table indices`.
+2. Get `device block number` using `block table indices`.
 
 Purpose: `device block number` indicates which device block each token belongs to.
 
-3. Get `block offsets` using `positions` and `block size`.
+3. Get `block offsets` using `positions` and `block size`.
 
 Purpose: `block offsets` indicates the offsets of each token within a block.
 
-4. construct `slot mapping` using `device block number` and `block offsets`.
+4. construct `slot mapping` using `device block number` and `block offsets`.
 
 Purpose: we can use `slot mapping` to store Token IDs into token slots.
 
 Details:
-1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size`. So it equal to `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This could be used to select `device block number` from `block table`.
+1. (**Token level**) Use a simple formula to calculate `block table indices`: `request indices * K + positions / block size`. So it equal to `[0 * 6 + 0 / 2, 0 * 6 + 1 / 2, 0 * 6 + 2 / 2, 1 * 6 + 0 / 2, 1 * 6 + 1 / 2, 2 * 6 + 0 / 2, 2 * 6 + 1 / 2, 2 * 6 + 2 / 2, 2 * 6 + 3 / 2, 2 * 6 + 4 / 2] = [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]`. This could be used to select `device block number` from `block table`.
 2. (**Token level**) Use `block table indices` to select out `device block number` for each scheduled token. The Pseudocode is `block_numbers = block_table[block_table_indices]`. So `device block number=[1, 1, 2, 3, 3, 4, 4, 5, 5, 6]`
-3. (**Token level**) `block offsets` could be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`.
+3. (**Token level**) `block offsets` could be computed by `block offsets = positions % block size = [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]`.
 4. At last, use `block offsets` and `device block number` to create `slot mapping`: `device block number * block size + block_offsets = [2, 3, 4, 6, 7, 8, 9, 10, 11, 12]`
 
-(**Request level**) As we know the scheduled token count is `[3, 2, 5]`:
+(**Request level**) As we know the scheduled token count is `[3, 2, 5]`:
 
-- (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`.
-- (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`.
-- (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`.
+- (**Request level**) Use prefix sum to calculate `query start location`: `[0, 3, 5, 10]`.
+- (**Request level**) All tokens in step 1 are in the prefill stage, and the computed tokens count is 0; then `sequence length` = `[3, 2, 5]`.
+- (**Request level**) As mentioned above, `number of computed tokens` are all 0s: `[0, 0, 0]`.
 - `number of requests`: `3`
 - (**Request level**) `number of tokens`: `[3, 2, 5]`
 - `max query len`: `5`
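The slot-mapping arithmetic walked through in this hunk can be reproduced with a short NumPy sketch. It is illustrative only: the flattened `block_table` below is made up so that the lookups return the doc's example values, and it is not the actual ModelRunner implementation.

```python
import numpy as np

BLOCK_SIZE = 2
K = 6  # max model len / block size

request_indices = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2, 2])
positions = np.array([0, 1, 2, 0, 1, 0, 1, 2, 3, 4])

# Hypothetical flattened block table (3 requests x K entries) chosen to match the doc.
block_table = np.array([1, 2, 0, 0, 0, 0,
                        3, 0, 0, 0, 0, 0,
                        4, 5, 6, 0, 0, 0])

block_table_indices = request_indices * K + positions // BLOCK_SIZE
print(block_table_indices.tolist())   # [0, 0, 1, 6, 6, 12, 12, 13, 13, 14]

device_block_number = block_table[block_table_indices]
print(device_block_number.tolist())   # [1, 1, 2, 3, 3, 4, 4, 5, 5, 6]

block_offsets = positions % BLOCK_SIZE
print(block_offsets.tolist())         # [0, 1, 0, 0, 1, 0, 1, 0, 1, 0]

slot_mapping = device_block_number * BLOCK_SIZE + block_offsets
print(slot_mapping.tolist())          # [2, 3, 4, 6, 7, 8, 9, 10, 11, 12]

# Request-level query start location via prefix sum of the scheduled token counts.
query_start_loc = np.concatenate(([0], np.cumsum([3, 2, 5])))
print(query_start_loc.tolist())       # [0, 3, 5, 10]
```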
@@ -235,7 +235,7 @@ KV cache block in the device memory:
 1. (**Token level**) `block table indices`: `[1, 7, 14, 15, 15]`
 2. (**Token level**) `device block number`: `[2, 7, 6, 8, 8]`
 3. (**Token level**) `block offsets`: `[1, 0, 1, 0, 1]`
-4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]`
+4. (**Token level**) `slot mapping`: `[5, 14, 13, 16, 17]`
 
 Scheduled token count:`[1, 1, 3]`
 - `query start location`: `[0, 1, 2, 5]`
@@ -250,7 +250,7 @@ Scheduled token count:`[1, 1, 3]`
 
 - `slot mapping`: `[5, 14, 13, 16, 17]`
 
-- `attention mask`: `5 * 8`
+- `attention mask`: `5 * 8`
 
 Each token has a `1 * 8` vector, and there are 5 scheduled tokens.
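The same formula checks out against the step-2 numbers in the two hunks above (again just a sketch with the doc's values, not project code):

```python
import numpy as np

# Step-2 decode example from the doc, block size 2.
device_block_number = np.array([2, 7, 6, 8, 8])
block_offsets = np.array([1, 0, 1, 0, 1])
slot_mapping = device_block_number * 2 + block_offsets
print(slot_mapping.tolist())   # [5, 14, 13, 16, 17]

# Scheduled token count [1, 1, 3] gives query start location [0, 1, 2, 5].
print(np.concatenate(([0], np.cumsum([1, 1, 3]))).tolist())
```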

docs/source/user_guide/support_matrix/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -80,4 +80,4 @@ Get the latest info here: https://github.com/vllm-project/vllm-ascend/issues/160
 | GLM-4V || [2260](https://github.com/vllm-project/vllm-ascend/issues/2260) |
 | InternVL2.0/2.5/3.0<br>InternVideo2.5/Mono-InternVL || [2064](https://github.com/vllm-project/vllm-ascend/issues/2064) |
 | Whisper || [2262](https://github.com/vllm-project/vllm-ascend/issues/2262) |
-| Ultravox | 🟡 | Need test |
+| Ultravox | 🟡 | Need test |

vllm_ascend/platform.py

Lines changed: 1 addition & 0 deletions
@@ -263,6 +263,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             **********************************************************************************\033[0m
             """
             logger.warning(warning_message)
+            update_aclgraph_sizes(vllm_config)
         else:
             logger.info(
                 "%s cudagraph_mode is not support on NPU. falling back to NONE",

vllm_ascend/utils.py

Lines changed: 41 additions & 16 deletions
@@ -314,6 +314,13 @@ def _rec_find(d):
 
 def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     """Update ACL graph capture sizes based on hardware limitations"""
+    from vllm.config.compilation import CUDAGraphMode
+    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        if vllm_config.speculative_config is not None and \
+                vllm_config.speculative_config.num_speculative_tokens > 1:
+            _update_spec_aclgraph_sizes(vllm_config)
+        return
+
     # NOTE: Currently, we can only capture 1800 graphs at most,
     # due to the limitation of ACL graph. This number is bounded by
     # the number of streams, which is 2048, we save 248 streams
@@ -421,25 +428,43 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         vllm_config.model_config.architectures[0], num_hidden_layers,
         len(original_sizes))
 
+    if vllm_config.speculative_config is not None and \
+            vllm_config.speculative_config.num_speculative_tokens > 1:
+        _update_spec_aclgraph_sizes(vllm_config)
+
+
+def _update_spec_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     # default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario
     # the maximum size cudagraph_capture_sizes[0] should be greater or equal than
     # (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode
-    if vllm_config.speculative_config is not None and \
-            vllm_config.speculative_config.num_speculative_tokens > 1:
-        num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
-        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
-        original_sizes, compilation_config.cudagraph_capture_sizes = \
-            compilation_config.cudagraph_capture_sizes, None
-        assert len(original_sizes) > 0
-        if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
-            enlarged_sizes = [(num_speculative_tokens + 1) * size
-                              for size in original_sizes]
-            compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
-            logger.info(
-                "Adjusted ACL graphs: %s → %s for speculative decoding",
-                original_sizes, enlarged_sizes)
-        else:
-            compilation_config.cudagraph_capture_sizes = original_sizes
+    from vllm.config.compilation import CUDAGraphMode
+    compilation_config = vllm_config.compilation_config
+    num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+    uniform_decode_query_len = num_speculative_tokens + 1
+    max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+    max_num_tokens = max_num_seqs * uniform_decode_query_len
+    original_sizes, compilation_config.cudagraph_capture_sizes = \
+        compilation_config.cudagraph_capture_sizes, None
+    assert len(original_sizes) > 0
+
+    if vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and \
+            not all(size % uniform_decode_query_len == 0 for size in original_sizes):
+        enlarged_sizes = [
+            size * uniform_decode_query_len for size in original_sizes
+            if max_num_tokens >= size >= uniform_decode_query_len
+        ]
+        compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
+        logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
+                    original_sizes, enlarged_sizes)
+    elif original_sizes[0] < max_num_tokens:
+        enlarged_sizes = [
+            size * uniform_decode_query_len for size in original_sizes
+        ]
+        compilation_config.init_with_cudagraph_sizes(enlarged_sizes)
+        logger.info("Adjusted ACL graphs: %s → %s for speculative decoding",
+                    original_sizes, enlarged_sizes)
+    else:
+        compilation_config.cudagraph_capture_sizes = original_sizes
 
 
 # TODO(wxy): Move to ops module
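To make the effect of the new `_update_spec_aclgraph_sizes` concrete, here is a rough standalone sketch of the resizing rule. The helper `adjust_sizes` and all of the numbers are hypothetical (num_speculative_tokens = 2, max_num_seqs = 256); only the arithmetic mirrors the diff.

```python
def adjust_sizes(sizes, num_speculative_tokens, max_num_seqs, full_decode_only=True):
    """Sketch of the capture-size adjustment; `sizes` is assumed to be sorted descending."""
    q = num_speculative_tokens + 1     # uniform_decode_query_len: tokens per request per decode step
    max_num_tokens = max_num_seqs * q  # largest uniform decode batch the graph must cover
    if full_decode_only and not all(s % q == 0 for s in sizes):
        # Keep only sizes that fit in [q, max_num_tokens] and scale them to token counts.
        return [s * q for s in sizes if max_num_tokens >= s >= q]
    if sizes[0] < max_num_tokens:
        # Largest size would fall short of (num_speculative_tokens + 1) * max_num_seqs: scale all.
        return [s * q for s in sizes]
    return sizes


print(adjust_sizes([512, 256, 128, 8, 4, 2, 1], num_speculative_tokens=2, max_num_seqs=256))
# -> [1536, 768, 384, 24, 12]
```

In the FULL_DECODE_ONLY path the capture sizes are reinterpreted as request counts, scaled to token counts, and clamped to the largest uniform decode batch; otherwise the original sizes are kept whenever they already cover (num_speculative_tokens + 1) * max_num_seqs, as the comment in the diff states.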

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 7 deletions
@@ -3529,14 +3529,8 @@ def _capture_model(self):
 
         if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
                 aclgraph_mode.separate_routine():
-            max_num_tokens = self.scheduler_config.max_num_seqs * \
-                self.uniform_decode_query_len
-            decode_cudagraph_batch_sizes = [
-                x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
-                and x >= self.uniform_decode_query_len
-            ]
             compilation_cases_decode = list(
-                reversed(decode_cudagraph_batch_sizes))
+                reversed(self.aclgraph_batch_sizes))
             self._capture_aclgraphs(
                 compilation_cases=compilation_cases_decode,
                 aclgraph_runtime_mode=CUDAGraphMode.FULL,
0 commit comments