Skip to content

Commit 430d371

Browse files
committed
disable mtp graph when use async scheduling
Signed-off-by: Ronald1995 <[email protected]>
1 parent 33c1c56 commit 430d371

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def __init__(
162162
)
163163
self.use_sparse = hasattr(vllm_config.model_config.hf_config,
164164
"index_topk")
165+
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
165166

166167
def load_model(self, model) -> None:
167168
loader = get_model_loader(self.vllm_config.load_config)
@@ -711,7 +712,11 @@ def _propose(
711712
uniform_decode=False)
712713
aclgraph_runtime_mode, batch_descriptor = \
713714
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
714-
715+
if self.use_async_scheduling:
716+
# there is synchronize between mtp steps when enable aclgraph,
717+
# disable aclgraph when use async scheduling to avoid the
718+
# synchronize overhead.
719+
aclgraph_runtime_mode = CUDAGraphMode.NONE
715720
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
716721
) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
717722
graph_pad_size = num_input_tokens

0 commit comments

Comments
 (0)