Commit c3acf25

Author: wangxiaoxin-sherie (committed)

support FULL_AND_PIECEWISE graph mode

Signed-off-by: wangxiaoxin-sherie <[email protected]>
1 parent 1f25d60 commit c3acf25

File tree: 5 files changed, +63 additions, -6 deletions

tests/e2e/multicard/test_full_graph_mode.py

Lines changed: 45 additions & 1 deletion

@@ -29,7 +29,7 @@
 from tests.e2e.model_utils import check_outputs_equal
 
 
-def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
     if 'HCCL_OP_EXPANSION_MODE' in os.environ:
         del os.environ['HCCL_OP_EXPANSION_MODE']
     prompts = [
@@ -70,3 +70,47 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
         name_0="vllm_eager_outputs",
         name_1="vllm_fullgraph_outputs",
     )
+
+
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_AND_PIECEWISE():
+    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
+        del os.environ['HCCL_OP_EXPANSION_MODE']
+    prompts = [
+        "Hello, my name is", "The president of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+    model = "Qwen/Qwen3-30B-A3B"
+    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            tensor_parallel_size=2,
+            enforce_eager=False,
+            compilation_config={"cudagraph_mode":
+                                "FULL_AND_PIECEWISE"}) as runner:
+        vllm_fullgraph_outputs = runner.model.generate(prompts,
+                                                       sampling_params)
+
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            enforce_eager=True,
+    ) as runner:
+        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
+
+    vllm_fullgraph_outputs_list = []
+    for output in vllm_fullgraph_outputs:
+        vllm_fullgraph_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    vllm_eager_outputs_list = []
+    for output in vllm_eager_outputs:
+        vllm_eager_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs_list,
+        outputs_1_lst=vllm_fullgraph_outputs_list,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_fullgraph_outputs",
+    )
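
For reference, the same mode can be enabled outside the test harness through vLLM's offline API. A minimal sketch, reusing the model and settings from the test above (this snippet is illustrative and not part of the commit):

    from vllm import LLM, SamplingParams

    # FULL_AND_PIECEWISE: full graph capture for pure-decode batches,
    # piecewise capture for batches that include prefill.
    llm = LLM(model="Qwen/Qwen3-30B-A3B",
              tensor_parallel_size=2,
              max_model_len=1024,
              enforce_eager=False,
              compilation_config={"cudagraph_mode": "FULL_AND_PIECEWISE"})

    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32, temperature=0.0))
    print(outputs[0].outputs[0].text)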

vllm_ascend/compilation/acl_graph.py

Lines changed: 1 addition & 1 deletion

@@ -332,7 +332,7 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
 def update_graph_params_workspaces(num_tokens: int, workspace: int):
     global _graph_params
     if _graph_params is not None:
-        _graph_params.workspaces[num_tokens] = workspace
+        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
 
 
 def get_graph_params():
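
Why the weak reference matters (my reading; the diff itself does not say): _graph_params lives for the lifetime of the process, so storing the workspace directly would keep every captured workspace buffer pinned in device memory. vLLM's weak_ref_tensors creates tensors that share storage without owning it; the toy sketch below uses Python's plain weakref to illustrate the same ownership point:

    import weakref

    class Workspace:
        """Stand-in for an NPU workspace tensor."""
        def __init__(self, nbytes: int):
            self.nbytes = nbytes

    registry = {}                    # long-lived, like _graph_params.workspaces
    ws = Workspace(1 << 20)

    registry[256] = ws               # strong ref: buffer can never be reclaimed
    registry[256] = weakref.ref(ws)  # weak ref: the allocator stays in control

    del ws
    print(registry[256]() is None)   # True: the workspace has been freed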

vllm_ascend/platform.py

Lines changed: 4 additions & 2 deletions

@@ -245,7 +245,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if vllm_version_is("0.11.0"):
             if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
                 compilation_config.level = CompilationLevel.NO_COMPILATION
-            elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+            elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or \
+                    compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
                 logger.info(
                     "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                     "using only ACL Graph mode")
@@ -285,7 +286,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         else:
             if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
                 compilation_config.mode = CompilationMode.NONE
-            elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+            elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or \
+                    compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
                 logger.info(
                     "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                     "using only ACL Graph mode")

vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 3 additions & 1 deletion

@@ -21,6 +21,7 @@
 import types
 from typing import Any, Optional
 
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -154,6 +155,7 @@ def _build_dummy_attn_metadata(
         num_reqs: int,
         num_tokens: int,
         max_query_len: int,
+        num_scheduled_tokens: np.ndarray[Any, Any],
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
     ) -> Optional[dict[str, Any]]:
@@ -162,7 +164,7 @@ def _build_dummy_attn_metadata(
         if with_prefill or self.enable_shared_expert_dp:
             attn_metadata = super()._build_dummy_attn_metadata(
                 with_prefill, num_reqs, num_tokens, max_query_len,
-                aclgraph_runtime_mode, force_attention)
+                num_scheduled_tokens, aclgraph_runtime_mode, force_attention)
         else:
             common_attn_metadata = TorchairCommonAttentionMetadata(
                 num_reqs=num_reqs,
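
The torchair change is pure plumbing: the override must mirror the base signature and forward arguments positionally in base-class order. A minimal sketch of that contract (class names abbreviated, not the real classes):

    import numpy as np
    from typing import Any, Optional

    class BaseRunner:
        def _build_dummy_attn_metadata(self, with_prefill: bool, num_reqs: int,
                                       num_tokens: int, max_query_len: int,
                                       num_scheduled_tokens: np.ndarray,
                                       aclgraph_runtime_mode: Optional[Any] = None,
                                       force_attention: bool = False) -> None:
            ...

    class TorchairRunner(BaseRunner):
        def _build_dummy_attn_metadata(self, with_prefill, num_reqs, num_tokens,
                                       max_query_len, num_scheduled_tokens,
                                       aclgraph_runtime_mode=None,
                                       force_attention=False):
            # A new positional parameter must be threaded through in
            # base-class order, or the delegation silently misbinds arguments.
            return super()._build_dummy_attn_metadata(
                with_prefill, num_reqs, num_tokens, max_query_len,
                num_scheduled_tokens, aclgraph_runtime_mode, force_attention)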

vllm_ascend/worker/model_runner_v1.py

Lines changed: 10 additions & 1 deletion

@@ -406,7 +406,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                               device=self.device)
 
         if self.vllm_config.model_config.use_mla and \
-            self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            (self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or \
+             self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE):
             rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
             self.cos = torch.ones(self.max_num_reqs *
                                   self.decode_token_per_req,
@@ -2486,6 +2487,7 @@ def _build_dummy_attn_metadata(
         num_reqs: int,
         num_tokens: int,
         max_query_len: int,
+        num_scheduled_tokens: np.ndarray[Any, Any],
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
     ) -> Optional[dict[str, Any]]:
@@ -2501,6 +2503,12 @@ def _build_dummy_attn_metadata(
         self.seq_lens_np[:num_reqs] = seq_lens
         self.seq_lens_np[num_reqs:] = 0
 
+        if num_scheduled_tokens is not None:
+            cu_num_tokens, arange = self._get_cumsum_and_arange(
+                num_scheduled_tokens)
+            self.query_start_loc_cpu[1:num_reqs +
+                                     1] = torch.from_numpy(cu_num_tokens)
+
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
 
@@ -2702,6 +2710,7 @@ def _dummy_run(
             max_query_len=max_query_len,
             aclgraph_runtime_mode=aclgraph_runtime_mode,
             force_attention=force_attention,
+            num_scheduled_tokens=num_scheduled_tokens,
         )
 
         need_dummy_logits = (not self.in_profile_run
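
The new query_start_loc bookkeeping is an exclusive prefix sum over per-request token counts. The cumulative half of what _get_cumsum_and_arange computes can be reproduced standalone (the helper is internal to the runner, so this is an illustration of the math, not its exact code):

    import numpy as np
    import torch

    # Tokens scheduled for three dummy requests in one step.
    num_scheduled_tokens = np.array([4, 1, 2], dtype=np.int32)

    cu_num_tokens = np.cumsum(num_scheduled_tokens)        # [4, 5, 7]
    query_start_loc = torch.zeros(len(num_scheduled_tokens) + 1,
                                  dtype=torch.int32)
    query_start_loc[1:] = torch.from_numpy(cu_num_tokens)

    print(query_start_loc)  # tensor([0, 4, 5, 7], dtype=torch.int32)

Entry i gives the start offset of request i's query tokens in the flattened token batch, which is exactly what the dummy attention metadata needs when full-graph capture replays a mixed batch.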
