Commit 1bf7915

[main] Support MTP shape with ACLgraph
Signed-off-by: lilinsiman <[email protected]>
1 parent 755b635 commit 1bf7915

File tree

vllm_ascend/utils.py
vllm_ascend/worker/model_runner_v1.py

2 files changed: +55 -12 lines changed


vllm_ascend/utils.py

Lines changed: 55 additions & 2 deletions
@@ -452,6 +452,51 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     update_cudagraph_capture_sizes(vllm_config,
                                    new_cudagraph_capture_sizes)
 
+    # Modify the default capture sizes for the num_speculative_tokens >= 1 scenario.
+    # When MTP is combined with the full graph, the FIA operator performs padding
+    # to adapt to its actual_seq_lengths parameter: each request is padded to the
+    # maximum decode length under MTP, i.e. (k + 1) tokens, so every captured input
+    # shape must be a multiple of the MTP step count (k + 1). Assuming k = 2,
+    # capture_sizes = [3, 6, 9, 12, 15, 18, ...]. Consequently, the default captured
+    # graph shapes of the full graph must be adjusted to meet this requirement.
+    # TODO: Initialization of MTP-aware full-graph capture shapes for the FIA
+    # operator belongs in the vLLM community; this section will be removed once
+    # it is migrated there.
+    from vllm.config.compilation import CUDAGraphMode
+    aclgraph_mode = vllm_config.compilation_config.cudagraph_mode
+    if vllm_config.speculative_config is not None and \
+            aclgraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        target_sizes = (num_speculative_tokens + 1) * max_num_seqs
+        original_sizes, vllm_config.compilation_config.cudagraph_capture_sizes = \
+            vllm_config.compilation_config.cudagraph_capture_sizes, None
+        assert len(original_sizes) > 0
+        assert max_num_seqs > 0
+        assert num_speculative_tokens > 0
+        if num_speculative_tokens > 1:
+            if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
+                new_original_sizes = sorted(set(list(range(1, min(10, max_num_seqs + 1), 2)) + list(range(8, max_num_seqs + 1, 4))))
+                enlarged_sizes = [(num_speculative_tokens + 1) * size for size in new_original_sizes]
+                if enlarged_sizes[-1] < target_sizes:
+                    enlarged_sizes.append(target_sizes)
+                update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
+                logger.info(
+                    "Adjusted ACL full graphs: %s → %s for speculative decoding",
+                    original_sizes, enlarged_sizes)
+            else:
+                vllm_config.compilation_config.cudagraph_capture_sizes = original_sizes
+        if num_speculative_tokens == 1:
+            padding_sizes = original_sizes.copy()
+            if padding_sizes[-1] < target_sizes:
+                padding_sizes.append(target_sizes)
+                update_cudagraph_capture_sizes(vllm_config, padding_sizes)
+                logger.info(
+                    "Adjusted ACL full graphs: %s → %s for speculative decoding",
+                    original_sizes, padding_sizes)
+            else:
+                vllm_config.compilation_config.cudagraph_capture_sizes = original_sizes
+
 
 def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     """Update ACL graph capture sizes based on hardware limitations"""
@@ -571,13 +616,21 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         max_num_seqs = vllm_config.scheduler_config.max_num_seqs
         original_sizes, compilation_config.cudagraph_capture_sizes = \
             compilation_config.cudagraph_capture_sizes, None
+        new_original_sizes = sorted(set(list(range(1, min(10, max_num_seqs + 1), 2)) + list(range(8, max_num_seqs + 1, 4))))
+        step = (len(new_original_sizes) - 1) / (max_num_batch_sizes - 1)
+        indices = [round(i * step) for i in range(max_num_batch_sizes)]
+        indices[0], indices[-1] = 0, len(new_original_sizes) - 1
+        new_sampled_sizes = [new_original_sizes[i] for i in indices]
+        target_sizes = (num_speculative_tokens + 1) * max_num_seqs
         assert len(original_sizes) > 0
         if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
             enlarged_sizes = [(num_speculative_tokens + 1) * size
-                              for size in original_sizes]
+                              for size in new_sampled_sizes]
+            if enlarged_sizes[-1] < target_sizes:
+                enlarged_sizes[-1] = target_sizes
             update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
             logger.info(
-                "Adjusted ACL graphs: %s → %s for speculative decoding",
+                "Adjusted PieceWise ACL graphs: %s → %s for speculative decoding",
                 original_sizes, enlarged_sizes)
         else:
             compilation_config.cudagraph_capture_sizes = original_sizes
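
The piecewise path additionally thins the candidate list down to max_num_batch_sizes entries spread evenly across the range. A minimal sketch of that sampling follows; max_num_batch_sizes = 4 is an assumed hardware cap, with the same illustrative max_num_seqs = 16 and num_speculative_tokens = 2 as above:

    # Sketch of the even sampling above; max_num_batch_sizes = 4 is an assumed
    # cap, not a value taken from this commit.
    max_num_batch_sizes = 4
    max_num_seqs = 16
    num_speculative_tokens = 2

    new_original_sizes = sorted(
        set(list(range(1, min(10, max_num_seqs + 1), 2)) +
            list(range(8, max_num_seqs + 1, 4))))  # [1, 3, 5, 7, 8, 9, 12, 16]

    # Keep the first and last candidates, spacing the rest evenly between them.
    step = (len(new_original_sizes) - 1) / (max_num_batch_sizes - 1)
    indices = [round(i * step) for i in range(max_num_batch_sizes)]
    indices[0], indices[-1] = 0, len(new_original_sizes) - 1
    new_sampled_sizes = [new_original_sizes[i] for i in indices]  # [1, 5, 9, 16]

    enlarged_sizes = [(num_speculative_tokens + 1) * s for s in new_sampled_sizes]
    print(enlarged_sizes)  # [3, 15, 27, 48]; the tail is raised to target_sizes if short
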

vllm_ascend/worker/model_runner_v1.py

Lines changed: 0 additions & 10 deletions
@@ -4010,16 +4010,6 @@ def _capture_model(self):
             and x >= self.uniform_decode_query_len
         ]
         compilation_cases_decode = sorted(decode_cudagraph_batch_sizes)
-        # TODO: refactor this when vLLM supports mtp>1
-        if not all(x % self.uniform_decode_query_len == 0
-                   for x in decode_cudagraph_batch_sizes):
-            raise ValueError(
-                "In the MTP fullgraph scenario, each graph size must be an integer multiple of "
-                f"(num_speculative_tokens + 1): {self.uniform_decode_query_len}. "
-                f"Please modify the cudagraph_capture_sizes variable to be integer multiple of {self.uniform_decode_query_len}, "
-                f"while ensuring the maximum cudagraph_capture_sizes does not exceed max_num_seqs * (num_speculative_tokens + 1): {max_num_tokens}. "
-                "For example, with MTP=2 and max_num_seqs=16, we recommend setting cudagraph_capture_sizes to [48]."
-            )
         self._capture_aclgraphs(
             compilation_cases=compilation_cases_decode,
             aclgraph_runtime_mode=CUDAGraphMode.FULL,
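
The deleted ValueError enforced by hand the invariant that the utils.py changes above now establish automatically. A minimal sketch of the invariant it checked, reusing the illustrative sizes from the first sketch:

    # Invariant formerly enforced by the deleted check: every decode capture size
    # is a multiple of uniform_decode_query_len (= num_speculative_tokens + 1).
    # The sizes below are the illustrative values from the first sketch.
    uniform_decode_query_len = 3
    decode_cudagraph_batch_sizes = [3, 9, 15, 21, 24, 27, 36, 48]
    assert all(x % uniform_decode_query_len == 0
               for x in decode_cudagraph_batch_sizes)
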
