
Commit 8fdb689

yiz-liu and wangxiyuan authored
[BugFix] Refactor ACL graph size adjustment for speculative decoding (#4640)
### What this PR does / why we need it?

Move the logic for adjusting ACL graph capture sizes for speculative decoding from the generic utility module into a dedicated method within the compilation configuration. This change improves code organization and encapsulation by making the compilation configuration responsible for managing its own state. The model runner now triggers this adjustment directly, providing the necessary context.

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

None.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: Yizhou Liu <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
1 parent 688b133 · commit 8fdb689
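The body of the new compilation-config method is not part of this diff; only its call site in `model_runner_v1.py` is visible. Based on the call signature and on the alignment requirement enforced by the validation code removed from `_capture_model` below, here is a minimal, hypothetical sketch of what such a method could look like — the class name, rounding strategy, and the unused `tp_size` parameter are all assumptions, not vLLM's actual implementation:

```python
# Hypothetical sketch: only the call signature
# adjust_cudagraph_sizes_for_spec_decode(uniform_decode_query_len, tp_size)
# is visible in this diff; everything else here is assumed.
from dataclasses import dataclass
from typing import Optional


@dataclass
class CompilationConfigSketch:
    cudagraph_capture_sizes: Optional[list[int]] = None

    def adjust_cudagraph_sizes_for_spec_decode(
            self, uniform_decode_query_len: int, tp_size: int) -> None:
        """Round every capture size up to a multiple of
        uniform_decode_query_len (= num_speculative_tokens + 1), so
        full-graph decode batches align with the draft model's layout."""
        if not self.cudagraph_capture_sizes:
            return
        q = uniform_decode_query_len
        # tp_size could impose further alignment; unused in this sketch.
        self.cudagraph_capture_sizes = sorted(
            {((s + q - 1) // q) * q for s in self.cudagraph_capture_sizes})


cfg = CompilationConfigSketch(cudagraph_capture_sizes=[8, 16, 32])
cfg.adjust_cudagraph_sizes_for_spec_decode(uniform_decode_query_len=3,
                                           tp_size=1)
print(cfg.cudagraph_capture_sizes)  # [9, 18, 33]
```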

File tree

2 files changed: +12 −31 lines

vllm_ascend/utils.py

0 additions & 20 deletions

```diff
@@ -571,26 +571,6 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         vllm_config.model_config.architectures[0], num_hidden_layers,
         len(original_sizes))
 
-    # default or defined cudagraph_capture_sizes may not consider num_speculative_tokens>1 scenario
-    # the maximum size cudagraph_capture_sizes[0] should be greater or equal than
-    # (num_speculative_tokens+1)*max_num_seqs, otherwise draft model will run in eager mode
-    if vllm_config.speculative_config is not None and \
-            vllm_config.speculative_config.num_speculative_tokens > 1:
-        num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
-        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
-        original_sizes, compilation_config.cudagraph_capture_sizes = \
-            compilation_config.cudagraph_capture_sizes, None
-        assert len(original_sizes) > 0
-        if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
-            enlarged_sizes = [(num_speculative_tokens + 1) * size
-                              for size in original_sizes]
-            update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
-            logger.info(
-                "Adjusted ACL graphs: %s → %s for speculative decoding",
-                original_sizes, enlarged_sizes)
-        else:
-            compilation_config.cudagraph_capture_sizes = original_sizes
 
 # TODO(wxy): Move to ops module
 def dispose_tensor(x: torch.Tensor):
```
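The deleted block encodes a simple sizing rule: with `num_speculative_tokens = k`, each decode step verifies up to `k + 1` tokens per sequence, so the largest capture size (`cudagraph_capture_sizes[0]`, since the list is kept in descending order) must be at least `(k + 1) * max_num_seqs`, or the draft model falls back to eager mode. A standalone sketch of the same arithmetic, with illustrative names:

```python
def enlarge_for_spec_decode(sizes: list[int], num_speculative_tokens: int,
                            max_num_seqs: int) -> list[int]:
    """Mirror of the removed utils.py logic: if the largest capture size
    cannot cover a full decode batch, scale every size by (k + 1)."""
    assert sizes, "capture sizes must be non-empty"
    k = num_speculative_tokens
    if sizes[0] < (k + 1) * max_num_seqs:  # sizes[0] is the maximum
        return [(k + 1) * s for s in sizes]
    return sizes


# With k=2 draft tokens and max_num_seqs=16, a maximum size of 32 is too
# small (a full decode batch needs 48), so [32, 16, 8] becomes [96, 48, 24].
print(enlarge_for_spec_decode([32, 16, 8], 2, 16))  # [96, 48, 24]
```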

vllm_ascend/worker/model_runner_v1.py

12 additions & 11 deletions

```diff
@@ -4027,6 +4027,16 @@ def initialize_aclgraph_capture(self) -> None:
                 "; please try cudagraph_mode=PIECEWISE, "
                 "and make sure compilation level is piecewise")
 
+        if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
+                and aclgraph_mode.separate_routine()
+                and self.uniform_decode_query_len > 1):
+            self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
+                self.uniform_decode_query_len,
+                self.parallel_config.tensor_parallel_size)
+            capture_sizes = self.compilation_config.cudagraph_capture_sizes
+            self.aclgraph_batch_sizes = (capture_sizes
+                                         if capture_sizes is not None else [])
+
         self.aclgraph_dispatcher.initialize_cudagraph_keys(
             self.compilation_config.cudagraph_mode,
             self.uniform_decode_query_len)
@@ -4122,17 +4132,8 @@ def _capture_model(self):
                 x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
                 and x >= self.uniform_decode_query_len
             ]
-            compilation_cases_decode = sorted(decode_cudagraph_batch_sizes)
-            # TODO: refactor this when vLLM supports mtp>1
-            if not all(x % self.uniform_decode_query_len == 0
-                       for x in decode_cudagraph_batch_sizes):
-                raise ValueError(
-                    "In the MTP fullgraph scenario, each graph size must be an integer multiple of "
-                    f"(num_speculative_tokens + 1): {self.uniform_decode_query_len}. "
-                    f"Please modify the cudagraph_capture_sizes variable to be integer multiple of {self.uniform_decode_query_len}, "
-                    f"while ensuring the maximum cudagraph_capture_sizes does not exceed max_num_seqs * (num_speculative_tokens + 1): {max_num_tokens}. "
-                    "For example, with MTP=2 and max_num_seqs=16, we recommend setting cudagraph_capture_sizes to [48]."
-                )
+            compilation_cases_decode = list(
+                reversed(decode_cudagraph_batch_sizes))
             self._capture_aclgraphs(
                 compilation_cases=compilation_cases_decode,
                 aclgraph_runtime_mode=CUDAGraphMode.FULL,
```
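With the multiple-of check gone (the sizes now arrive pre-aligned from `adjust_cudagraph_sizes_for_spec_decode`), the decode compilation cases reduce to the filtered batch sizes in descending order. A self-contained sketch of that selection step, assuming `aclgraph_batch_sizes` is kept in ascending order:

```python
def decode_capture_cases(aclgraph_batch_sizes: list[int],
                         max_num_tokens: int,
                         uniform_decode_query_len: int) -> list[int]:
    """Select decode-graph batch sizes, largest first, mirroring the new
    _capture_model logic (the ascending input order is an assumption)."""
    eligible = [
        x for x in aclgraph_batch_sizes
        if uniform_decode_query_len <= x <= max_num_tokens
    ]
    return list(reversed(eligible))


# e.g. sizes [8, 16, 24, 48] with max 48 and query length 3 -> [48, 24, 16, 8]
print(decode_capture_cases([8, 16, 24, 48], 48, 3))
```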
