Commit 165259d
refactor: Defer and centralize ACL graph parameter initialization

Moves the setting of ACL graph parameters (`set_graph_params` and `set_mtp_graph_params`) from the initial setup of the model runner and proposer to a later point in the initialization lifecycle. This change ensures that the graph parameters are configured using the definitive `aclgraph_batch_sizes`, which are only determined just before profiling preparation. It also centralizes this configuration logic within the `NPUModelRunner`.

Signed-off-by: Yizhou Liu <[email protected]>
1 parent: 310e960
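In effect, the two calls move out of `load_model` (in both the model runner and the MTP proposer) and into `NPUModelRunner.initialize_aclgraph_capture`, as the diffs below show. A runnable toy sketch of the new order; everything here is a stand-in, not the real vllm-ascend classes, and only the call order and the set-once guard mirror the commit:

# Toy stand-ins; only the call order and the guard behavior mirror the commit.
from typing import Optional

_graph_params: Optional[list[int]] = None


def set_graph_params(aclgraph_capture_sizes: list[int]) -> None:
    global _graph_params
    if _graph_params is not None:
        raise ValueError("Graph parameters have already been set!")
    _graph_params = aclgraph_capture_sizes


class Runner:
    def load_model(self) -> None:
        # Before the commit, set_graph_params() ran here, using the
        # preliminary cudagraph_capture_sizes. After the commit it does not.
        pass

    def initialize_aclgraph_capture(self) -> None:
        # The definitive sizes are only known at this point.
        self.aclgraph_batch_sizes = [1, 2, 4, 8]  # hypothetical values
        set_graph_params(self.aclgraph_batch_sizes)


runner = Runner()
runner.load_model()
runner.initialize_aclgraph_capture()
print(_graph_params)  # [1, 2, 4, 8]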

3 files changed: +9 -6 lines changed

vllm_ascend/compilation/acl_graph.py (2 additions & 2 deletions)

@@ -427,7 +427,7 @@ class GraphParams:
 _graph_params: Optional[GraphParams] = None
 
 
-def set_graph_params(aclgraph_capture_sizes: set[int]):
+def set_graph_params(aclgraph_capture_sizes: list[int]):
     global _graph_params
     if _graph_params is not None:
         raise ValueError("Graph parameters have already been set!")
@@ -456,7 +456,7 @@ def get_graph_params():
 _mtp_graph_params: Optional[GraphParams] = None
 
 
-def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
+def set_mtp_graph_params(aclgraph_capture_sizes: list[int]):
     global _mtp_graph_params
     if _mtp_graph_params is not None:
         raise ValueError("MTPGraph parameters have already been set!")
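Both setters follow the same set-once contract, and the second hunk header shows the matching `get_graph_params()` accessor. A minimal usage sketch of that contract; the sizes are hypothetical, while the error message comes directly from the diff:

from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                               set_graph_params)

set_graph_params([1, 2, 4, 8])         # first call wins; hypothetical sizes
assert get_graph_params() is not None  # later readers fetch the same object

try:
    set_graph_params([1, 2, 4, 8])     # second call is rejected
except ValueError as err:
    print(err)  # -> Graph parameters have already been set!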

vllm_ascend/spec_decode/mtp_proposer.py (0 additions & 3 deletions)

@@ -31,7 +31,6 @@
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
-                                               set_mtp_graph_params,
                                                update_mla_attn_params)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
@@ -213,8 +212,6 @@ def load_model(self, model) -> None:
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
         ):
             self.update_stream: torch.npu.Stream = torch.npu.Stream()
-            set_mtp_graph_params(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes)
             self.model = ACLGraphWrapper(self.model,
                                          self.vllm_config,
                                          runtime_mode=CUDAGraphMode.FULL)

vllm_ascend/worker/model_runner_v1.py (7 additions & 1 deletion)

@@ -123,6 +123,7 @@
 # yapf: disable
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
+                                               set_mtp_graph_params,
                                                update_attn_dcp_pcp_params,
                                                update_attn_params,
                                                update_mla_attn_dcp_pcp_params,
@@ -3368,7 +3369,6 @@ def load_model(self) -> None:
         # wrap the model with full graph wrapper if needed.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             self.update_stream: torch.npu.Stream = torch.npu.Stream()
-            set_graph_params(self.compilation_config.cudagraph_capture_sizes)
             self.model = ACLGraphWrapper(self.model,
                                          self.vllm_config,
                                          runtime_mode=CUDAGraphMode.FULL)
@@ -4087,6 +4087,12 @@ def initialize_aclgraph_capture(self) -> None:
         self.aclgraph_batch_sizes = (capture_sizes
                                      if capture_sizes is not None else [])
 
+        # NOTE: Since aclgraph_batch_sizes cannot be determined until here,
+        # we set the graph params right before initializing the keys.
+        set_graph_params(self.aclgraph_batch_sizes)
+        if self.speculative_config:
+            set_mtp_graph_params(self.aclgraph_batch_sizes)
+
         self.aclgraph_dispatcher.initialize_cudagraph_keys(
             self.compilation_config.cudagraph_mode,
             self.uniform_decode_query_len)
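The deferral also turns the set-once guard into an ordering check: if `load_model` still populated the params, `initialize_aclgraph_capture` would now raise. A hypothetical test-style sketch of the intended order; it assumes `get_graph_params()` returns `None` until the setter runs, consistent with the module default `_graph_params: Optional[GraphParams] = None` shown in the first diff:

# Hypothetical ordering check; constructing a real NPUModelRunner is omitted.
runner.load_model()
assert get_graph_params() is None        # no longer populated at load time

runner.initialize_aclgraph_capture()
assert get_graph_params() is not None    # set with the definitive batch sizes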
