Commit dd622aa
[Feature] Support npugraph_ex backend (#4700)
### What this PR does / why we need it?

We introduced the npugraph_ex backend through vLLM's adaptor dispatch mechanism to accelerate aclgraph. This solution is based on torch.compile and uses torchair to optimize the fx.graph. The performance gains come mainly from the static kernel. We conducted tests on Qwen3-30B and achieved over 5% performance improvement.

### Does this PR introduce _any_ user-facing change?

Yes, we add a new switch named "enable_npugraph_ex" in additional_config, which defaults to False. We also add an example showing how to register a custom replacement pass.

### More information about this PR

This feature depends on the CANN and torch_npu releases planned for Q4. We tested it on a package that has not been publicly released yet and verified that the functionality works. The feature is still experimental at the moment; setting the config to True will directly raise an error. Merging into the main branch starts with some preliminary commits to facilitate subsequent development and testing of the feature, and to avoid submitting an excessively large PR at once.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------

Signed-off-by: chencangtao <[email protected]>
Signed-off-by: ChenCangtao <[email protected]>
Co-authored-by: chencangtao <[email protected]>
Co-authored-by: panchao-hub <[email protected]>
Co-authored-by: wbigat <[email protected]>
Co-authored-by: Mengqing Cao <[email protected]>
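For reviewers who want to try the new switch, the following minimal sketch mirrors the unit test added in this commit: enable_npugraph_ex is read from additional_config and, in this preliminary commit, immediately raises NotImplementedError. The "refresh" key forces re-initialization of the ascend config, as in the test; the import paths follow the test file below.

from vllm.config import VllmConfig
from vllm_ascend.ascend_config import init_ascend_config

vllm_config = VllmConfig()
vllm_config.additional_config = {
    "enable_npugraph_ex": True,  # experimental switch added by this PR
    "refresh": True,  # force re-initialization of the ascend config
}

try:
    init_ascend_config(vllm_config)
except NotImplementedError:
    # Expected with this commit: the feature is gated until the Q4
    # CANN/torch_npu releases land, so keep the default (False) for now.
    pass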
1 parent d7db679 commit dd622aa

File tree

7 files changed: +235 −15 lines changed


tests/ut/test_ascend_config.py

Lines changed: 11 additions & 0 deletions
@@ -57,10 +57,21 @@ def test_init_ascend_config_with_additional_config(self):
         ascend_config = init_ascend_config(test_vllm_config)
         self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
         self.assertTrue(ascend_config.multistream_overlap_shared_expert)
+        self.assertFalse(ascend_config.enable_npugraph_ex)

         ascend_compilation_config = ascend_config.ascend_compilation_config
         self.assertFalse(ascend_compilation_config.enable_quantization_fusion)

+    @_clean_up_ascend_config
+    def test_init_ascend_config_enable_npugraph_ex(self):
+        with self.assertRaises(NotImplementedError):
+            test_vllm_config = VllmConfig()
+            test_vllm_config.additional_config = {
+                "enable_npugraph_ex": True,
+                "refresh": True,
+            }
+            init_ascend_config(test_vllm_config)
+
     @_clean_up_ascend_config
     def test_get_ascend_config(self):
         test_vllm_config = VllmConfig()

vllm_ascend/ascend_config.py

Lines changed: 6 additions & 0 deletions
@@ -169,6 +169,12 @@ def __init__(self, vllm_config):
             get_flashcomm2_oproj_tp_size_and_validate_config
         self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
             self, vllm_config)
+        self.enable_npugraph_ex = additional_config.get(
+            "enable_npugraph_ex", False)
+        if self.enable_npugraph_ex:
+            raise NotImplementedError(
+                "This feature is still in the experiment and will be supported soon."
+            )
         kv_cfg = vllm_config.kv_transfer_config
         if kv_cfg is not None and not getattr(kv_cfg, "_engine_id_patched",
                                               False):

vllm_ascend/compilation/compiler_interface.py

Lines changed: 79 additions & 15 deletions
@@ -18,6 +18,7 @@
 import functools
 from typing import Any, Callable, Optional

+import torch
 import torch.fx as fx
 from torch._dynamo.backends.common import aot_autograd
 from torch._inductor.compile_fx import (graph_returns_tuple,
@@ -26,6 +27,8 @@
 from torch.fx import GraphModule
 from vllm.compilation.compiler_interface import CompilerInterface

+from vllm_ascend.ascend_config import get_ascend_config
+

 def compile_fx(graph: GraphModule, example_inputs: list,
                inner_compile: Callable, decompositions: dict) -> Callable:
@@ -39,6 +42,75 @@ def compile_fx(graph: GraphModule, example_inputs: list,
     return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)


+def fusion_pass_compile(
+    graph: fx.GraphModule,
+    example_inputs: list[Any],
+    compiler_config: dict[str, Any],
+    runtime_shape: Optional[int] = None,
+    key: Optional[str] = None,
+) -> tuple[Optional[Callable], Optional[Any]]:
+
+    def compile_inner(graph, example_inputs):
+        current_pass_manager = compiler_config["graph_fusion_manager"]
+        graph = current_pass_manager(graph, runtime_shape)
+        return graph
+
+    decompositions = select_decomp_table()
+
+    compiled_fn = compile_fx(
+        graph=graph,
+        example_inputs=example_inputs,
+        inner_compile=compile_inner,
+        decompositions=decompositions,
+    )
+
+    return compiled_fn, None
+
+
+def npugraph_ex_compile(
+    graph: fx.GraphModule,
+    example_inputs: list[Any],
+    compiler_config: dict[str, Any],
+    runtime_shape: Optional[int] = None,
+    key: Optional[str] = None,
+) -> tuple[Optional[Callable], Optional[Any]]:
+    # When currently using the FULL_DECODE_ONLY mode,
+    # the piecewise compilation level slicing process
+    # in vllm is also encountered.
+    # This process causes the output to no longer be
+    # wrapped as a tuple when the fx graph has a single
+    # output, but torch.compile has a mandatory check.
+    fx_graph = graph.graph
+    if not graph_returns_tuple(graph):
+        output_node = fx_graph.output_node()
+        with fx_graph.inserting_before(output_node):
+            return_value = output_node.args[0]
+            tuple_node = fx_graph.create_node("call_function",
+                                              tuple,
+                                              args=([return_value], ))
+        output_node.args = (tuple_node, )
+        fx_graph.recompile()
+
+    import torchair
+
+    # TODO: use a better way to lazy register replacement, instead of import one by one
+    # As an example, we directly import here to register replacement.
+    import vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant  # noqa
+
+    torch.npu.set_compile_mode(jit_compile=False)
+    config = torchair.CompilerConfig()
+    # use aclgraph mode, avoid the transformation from fx graph to Ascend IR.
+    config.mode = "reduce-overhead"
+    # execute FX graph in eager mode before graph mode to optimize FX graph.
+    config.debug.run_eagerly = True
+    # static kernel switch, suitable for static shapes or scenes with less shape changes.
+    config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
+
+    npugraph_ex = torchair.get_npu_backend(compiler_config=config)
+    compile_graph = npugraph_ex(graph, example_inputs)
+    return compile_graph, None
+
+
 class AscendCompiler(CompilerInterface):
     """
     AscendCompiler is a custom compiler interface for the Ascend platform.
@@ -56,18 +128,10 @@ def compile(
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:

-        def compile_inner(graph, example_inputs):
-            current_pass_manager = compiler_config["graph_fusion_manager"]
-            graph = current_pass_manager(graph, runtime_shape)
-            return graph
-
-        decompositions = select_decomp_table()
-
-        compiled_fn = compile_fx(
-            graph=graph,
-            example_inputs=example_inputs,
-            inner_compile=compile_inner,
-            decompositions=decompositions,
-        )
-
-        return compiled_fn, None
+        ascend_config = get_ascend_config()
+        if ascend_config.enable_npugraph_ex:
+            return npugraph_ex_compile(graph, example_inputs, compiler_config,
+                                       runtime_shape, key)
+        else:
+            return fusion_pass_compile(graph, example_inputs, compiler_config,
+                                       runtime_shape, key)
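To make the tuple-wrapping comment in npugraph_ex_compile concrete, here is a small CPU-only sketch of the same normalization, runnable without an NPU. The toy module and tensor shapes are illustrative assumptions; only public torch.fx APIs and torch._inductor's graph_returns_tuple helper (also used by the diff above) are involved, and the output node is located by scanning the graph rather than via Graph.output_node().

import torch
import torch.fx as fx
from torch._inductor.compile_fx import graph_returns_tuple


class Toy(torch.nn.Module):

    def forward(self, x):
        # A single, non-tuple output, like the sliced graphs described above.
        return torch.relu(x)


gm = fx.symbolic_trace(Toy())
assert not graph_returns_tuple(gm)

g = gm.graph
output_node = next(n for n in g.nodes if n.op == "output")
with g.inserting_before(output_node):
    return_value = output_node.args[0]
    tuple_node = g.create_node("call_function", tuple, args=([return_value], ))
output_node.args = (tuple_node, )
gm.recompile()

out = gm(torch.randn(2, 3))
assert isinstance(out, tuple) and len(out) == 1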

vllm_ascend/compilation/npugraph_ex_passes/__init__.py

Whitespace-only changes.
vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import functools
+import sys
+
+import torch
+from torch._inductor.pattern_matcher import Match
+from vllm.logger import logger
+
+
+@functools.lru_cache(None)
+# The replacement registered here will be actually executed after AOT.
+def _register_replacement(epsilon):
+    if 'torch_npu' not in sys.modules:
+        logger.info(
+            'The AddRMSNormQuant fusion will only be enabled in a torch npu env.'
+            'When there is no torch_npu in the env, skip fusion.')
+        return
+
+    def _extra_stream_scope_check(match: Match) -> bool:
+        """
+        Checks if all nodes in the same stream.
+        """
+        non_default_streams = set()
+        has_default = False
+
+        for node in match.nodes:
+            if node.op == "call_function":
+                current_stream = node.meta.get("stream_label")
+                if current_stream is None:
+                    has_default = True
+                else:
+                    non_default_streams.add(current_stream)
+        if len(non_default_streams) > 1:
+            logger.debug(
+                f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
+                f"Multiple streams found: {non_default_streams}. "
+                f"Fusion is not supported for cross-stream operations."
+            )
+            return False
+
+        if has_default and len(non_default_streams) > 0:
+            logger.debug(
+                f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
+                f"Multiple streams found: {non_default_streams}. "
+                f"Fusion is not supported for cross-stream operations.")
+            return False
+
+        return True
+
+    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                offset: torch.Tensor):
+        """
+        Pattern for AddRMSNormQuant fusion.
+        """
+        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
+                                                rms_norm_weight, epsilon)
+        out0 = output[0]
+        out1 = output[2]
+        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
+                                                      torch.qint8, -1, False)
+        return quantized_output, out1
+
+    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                    offset: torch.Tensor):
+        """
+        Replacement for the AddRMSNormQuant fusion.
+        """
+        output = torch.ops.npu.npu_add_rms_norm_quant(
+            rms_norm_input,
+            residual,
+            rms_norm_weight,
+            # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
+            1. / scale,
+            offset,
+            epsilon=epsilon)
+        quantized_output = output[0]
+        out1 = output[2]
+        return quantized_output, out1
+
+    def get_inputs():
+        """
+        Generate example inputs for the AddRMSNormQuant fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu")
+        residual = torch.randn(2, 4, device="npu")
+        rms_norm_weight = torch.randn(4, device="npu")
+        scale = torch.tensor([1.0], device="npu")
+        offset = torch.tensor([0.0], device="npu")
+        return [rms_norm_input, residual, rms_norm_weight, scale, offset]
+
+    import torchair
+
+    torchair.register_replacement(search_fn=pattern,
+                                  replace_fn=replacement,
+                                  example_inputs=get_inputs(),
+                                  extra_check=_extra_stream_scope_check)
+
+
+# register converter for pass
+common_epsilons = [1e-5, 1e-6]
+for eps in common_epsilons:
+    logger.info(
+        f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}"
+    )
+    _register_replacement(eps)
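As the PR description notes, add_rms_norm_quant.py also serves as a template for registering custom replacement passes. Below is a hedged sketch of what a downstream pass module could look like; the toy mul/add pattern, the helper name, and the omission of extra_check are illustrative assumptions — only the torchair.register_replacement keyword interface (search_fn, replace_fn, example_inputs, extra_check) is taken from the file above.

import torch


def _register_fma_replacement():
    # torchair is imported lazily, mirroring add_rms_norm_quant.py above.
    import torchair

    def pattern(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        # Graph shape to search for: elementwise multiply followed by add.
        return x + y * z

    def replacement(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        # Equivalent fused rewrite.
        return torch.addcmul(x, y, z)

    def get_inputs():
        # Example inputs the pattern matcher traces the two functions with.
        return [torch.randn(2, 4), torch.randn(2, 4), torch.randn(2, 4)]

    torchair.register_replacement(search_fn=pattern,
                                  replace_fn=replacement,
                                  example_inputs=get_inputs())


_register_fma_replacement()
# To have npugraph_ex pick the pass up, import its module from
# npugraph_ex_compile the same way add_rms_norm_quant is imported
# (see the TODO about lazy registration in compiler_interface.py).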

vllm_ascend/platform.py

Lines changed: 4 additions & 0 deletions
@@ -231,6 +231,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

         if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             compilation_config.mode = CompilationMode.NONE
+            ascend_config.enable_npugraph_ex = False
         elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
             logger.info(
                 "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
@@ -241,12 +242,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             compilation_config.use_inductor = False
             compilation_config.splitting_ops.extend(["vllm::mla_forward"])
             update_aclgraph_sizes(vllm_config)
+            ascend_config.enable_npugraph_ex = False
         elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
                 compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
             logger.info(
                 "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
                 "using only ACL Graph mode")
             compilation_config.use_inductor = False
+            compilation_config.splitting_ops = []
             warning_message = """\033[91m
             **********************************************************************************
             * WARNING: You have enabled the *full graph* feature.
@@ -266,6 +269,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 compilation_config.cudagraph_mode)
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
             compilation_config.mode = CompilationMode.NONE
+            ascend_config.enable_npugraph_ex = False

         # TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1
         # Then, we will have to discuss the error handling strategy and user experience
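Note the interplay the hunks above enforce: the platform hook force-disables enable_npugraph_ex for the NONE and PIECEWISE graph modes, so the switch can only take effect together with FULL or FULL_DECODE_ONLY ACL graph capture. A hedged sketch of such a configuration is below; the model name and the exact compilation_config/additional_config plumbing are assumptions, and with this commit the flag still raises NotImplementedError.

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",  # illustrative model, the PR tested Qwen3-30B
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
    additional_config={"enable_npugraph_ex": True},
)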

vllm_ascend/worker/model_runner_v1.py

Lines changed: 12 additions & 0 deletions
@@ -2575,6 +2575,12 @@ def execute_model(
             self.debugger.stop()
             self.debugger.step()
             return pool_output
+        # Sometimes, after the model is compiled through the AOT backend,
+        # the model output may become a list containing only one Tensor object.
+        if isinstance(hidden_states, list) and \
+           len(hidden_states) == 1 and \
+           isinstance(hidden_states[0], torch.Tensor):
+            hidden_states = hidden_states[0]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
         if broadcast_pp_output:
@@ -3296,6 +3302,12 @@ def profile_run(self) -> None:
                                        dtype=np.int32)
         logit_indices = np.cumsum(num_scheduled_tokens) - 1
         # TODO: need to rum a dummy sampler for generate task
+        # Sometimes, after the model is compiled through the AOT backend,
+        # the model output may become a list containing only one Tensor object.
+        if isinstance(hidden_states, list) and \
+           len(hidden_states) == 1 and \
+           isinstance(hidden_states[0], torch.Tensor):
+            hidden_states = hidden_states[0]
         hidden_states = hidden_states[logit_indices]
         output = self.model.compute_logits(hidden_states)
