Commit 18221c0

[Fusion] normalize fusion naming and enable e2e test (#4693)
### What this PR does / why we need it?
This PR standardizes the fusion naming, renaming `enable_quantization_fusion` to `fuse_norm_quant`, and enables e2e testing.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with the newly added and existing tests.
- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------

Signed-off-by: wxsIcey <[email protected]>
1 parent 07c7131 commit 18221c0

File tree: 8 files changed, +136 -113 lines changed


.github/workflows/_e2e_test.yaml

Lines changed: 1 addition & 0 deletions
@@ -103,6 +103,7 @@ jobs:
     pytest -sv tests/e2e/singlecard/test_vlm.py
     pytest -sv tests/e2e/singlecard/test_xlite.py
     pytest -sv tests/e2e/singlecard/pooling/
+    pytest -sv tests/e2e/singlecard/compile/test_norm_quant_fusion.py
 
   # ------------------------------------ v1 spec decode test ------------------------------------ #
     pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

tests/e2e/singlecard/compile/backend.py

Lines changed: 9 additions & 98 deletions
@@ -17,65 +17,12 @@
 from copy import deepcopy
 from typing import Any, Callable, List, Optional, Sequence
 
-import pytest
-import torch
 import torch.fx as fx
-import torch.nn as nn
-import torch_npu
-import vllm.config
 from torch._inductor.decomposition import select_decomp_table
 from vllm.compilation.fx_utils import OpOverload
-from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
+from vllm.config import get_current_vllm_config
 
 from vllm_ascend.compilation.compiler_interface import compile_fx
-from vllm_ascend.compilation.passes.quant_fusion_pass import \
-    AddRMSNormQuantFusionPass
-
-
-class TestModel(nn.Module):
-    """
-    A minimal test model that simulates the pattern:
-    AddRMSNorm → Quantization
-    """
-
-    def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.eps = eps
-        self.rms_norm_weight = nn.Parameter(
-            torch.randn(hidden_size, device=device))
-        self.quant_scale = torch.tensor([1.0], device=device)
-        self.quant_offset = torch.tensor([0.0], device=device)
-
-    def forward(self, x):
-        """
-        Forward pass:
-        1. Perform npu_add_rms_norm
-        2. Quantize the normalized output to int8
-        Returns both quantized output and updated residual.
-        """
-        residual = torch.zeros_like(x)
-
-        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
-            x, residual, self.rms_norm_weight, self.eps)
-
-        quantized_output = torch_npu.npu_quantize(norm_output,
-                                                  self.quant_scale,
-                                                  self.quant_offset,
-                                                  torch.qint8, -1, False)
-
-        return quantized_output, new_residual
-
-    def ops_in_model_before(self) -> List[OpOverload]:
-        """Return the list of expected operators BEFORE fusion."""
-        return [
-            torch.ops.npu.npu_add_rms_norm.default,
-            torch.ops.npu.npu_quantize.default
-        ]
-
-    def ops_in_model_after(self) -> List[OpOverload]:
-        """Return the list of expected operators AFTER successful fusion."""
-        return [torch.ops.npu.npu_add_rms_norm_quant.default]
 
 
 class TestBackend:
@@ -85,14 +32,12 @@ class TestBackend:
     records the FX graph before and after the transformation.
     """
 
-    def __init__(self):
+    def __init__(self, custom_passes: Optional[List[Any]] = None):
         vllm_config = get_current_vllm_config()
         compile_config = vllm_config.compilation_config
-        self.custom_passes = [
-            AddRMSNormQuantFusionPass(vllm_config=vllm_config)
-        ]
         self.inductor_config = compile_config.inductor_compile_config
         self.inductor_config["graph_fusion_manager"] = self.post_pass
+        self.custom_passes = custom_passes
 
         # Placeholders to store FX graphs for verification
         self.graph_pre_pass = None
@@ -105,8 +50,9 @@ def post_pass(self,
         Apply custom graph transformation passes.
         """
         self.graph_pre_pass = deepcopy(graph)
-        for pass_ in self.custom_passes:
-            pass_(graph)
+        if self.custom_passes is not None:
+            for pass_ in self.custom_passes:
+                pass_(graph)
         self.graph_post_pass = deepcopy(graph)
         return graph
 
@@ -136,11 +82,13 @@ def compile_inner(graph, example_inputs):
         )
         return compiled_fn, None
 
-    def __call__(self, gm: fx.GraphModule, example_inputs: List[Any]):
+    def __call__(self, gm: fx.GraphModule,
+                 example_inputs: Optional[List[Any]]):
         """
         Make the backend callable by torch.compile().
         Returns a compiled executable function.
         """
+        assert example_inputs is not None
         compiled_fn, _ = self.compile(
             gm,
             example_inputs,
@@ -180,40 +128,3 @@ def check_after_ops(self, ops: Sequence[OpOverload]):
             num_post = len(self.find_nodes_by_target(self.graph_post_pass, op))
             print(f"Op {op}: post={num_post}")
             assert num_post > 0, f"Op {op} not found in post-pass graph"
-
-
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("hidden_size", [64])
-@pytest.mark.parametrize("num_tokens", [257])
-@pytest.mark.parametrize("eps", [1e-5, 1e-6])
-def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
-                              num_tokens: int, eps: float):
-    """
-    End-to-end test for AddRMSNorm+Quantize fusion.
-    Compares: Operator presence/absence before and after graph transformation
-    """
-    torch.set_default_dtype(dtype)
-    torch.manual_seed(1)
-
-    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
-
-    with vllm.config.set_current_vllm_config(vllm_config):
-        backend = TestBackend()
-        model = TestModel(hidden_size, eps, device="npu")
-        model = model.to("npu")
-
-        x = torch.rand(num_tokens,
-                       hidden_size,
-                       device="npu",
-                       dtype=dtype,
-                       requires_grad=False)
-
-        result_unfused = model(x)
-        print("Unfused result:", [t.shape for t in result_unfused])
-        model_fused = torch.compile(model, backend=backend)
-        result_fused = model_fused(x)
-        print("Fused result:", [t.shape for t in result_fused])
-
-        print("=== Checking operator fusion ===")
-        backend.check_before_ops(model.ops_in_model_before())
-        backend.check_after_ops(model.ops_in_model_after())

tests/e2e/singlecard/compile/test_norm_quant_fusion.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import List
+
+import pytest
+import torch
+import torch.nn as nn
+import torch_npu
+import vllm.config
+from vllm.compilation.fx_utils import OpOverload
+from vllm.config import ModelConfig, VllmConfig
+
+from tests.e2e.singlecard.compile.backend import TestBackend
+from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
+    AddRMSNormQuantFusionPass
+
+
+class TestModel(nn.Module):
+    """
+    A minimal test model that simulates the pattern:
+    AddRMSNorm → Quantization
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.rms_norm_weight = nn.Parameter(
+            torch.randn(hidden_size, device=device))
+        self.quant_scale = torch.tensor([1.0], device=device)
+        self.quant_offset = torch.tensor([0.0], device=device)
+
+    def forward(self, x):
+        """
+        Forward pass:
+        1. Perform npu_add_rms_norm
+        2. Quantize the normalized output to int8
+        Returns both quantized output and updated residual.
+        """
+        residual = torch.zeros_like(x)
+
+        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
+            x, residual, self.rms_norm_weight, self.eps)
+
+        quantized_output = torch_npu.npu_quantize(norm_output,
+                                                  self.quant_scale,
+                                                  self.quant_offset,
+                                                  torch.qint8, -1, False)
+
+        return quantized_output, new_residual
+
+    def ops_in_model_before(self) -> List[OpOverload]:
+        """Return the list of expected operators BEFORE fusion."""
+        return [
+            torch.ops.npu.npu_add_rms_norm.default,
+            torch.ops.npu.npu_quantize.default
+        ]
+
+    def ops_in_model_after(self) -> List[OpOverload]:
+        """Return the list of expected operators AFTER successful fusion."""
+        return [torch.ops.npu.npu_add_rms_norm_quant.default]
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("hidden_size", [64])
+@pytest.mark.parametrize("num_tokens", [257])
+@pytest.mark.parametrize("eps", [1e-5, 1e-6])
+def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
+                              num_tokens: int, eps: float):
+    """
+    End-to-end test for AddRMSNorm+Quantize fusion.
+    Compares: Operator presence/absence before and after graph transformation
+    """
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(1)
+
+    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
+
+    with vllm.config.set_current_vllm_config(vllm_config):
+        backend = TestBackend(
+            custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
+        model = TestModel(hidden_size, eps, device="npu")
+        model = model.to("npu")
+
+        x = torch.rand(num_tokens,
+                       hidden_size,
+                       device="npu",
+                       dtype=dtype,
+                       requires_grad=False)
+
+        result_unfused = model(x)
+        print("Unfused result:", [t.shape for t in result_unfused])
+        model_fused = torch.compile(model, backend=backend)
+        result_fused = model_fused(x)
+        print("Fused result:", [t.shape for t in result_fused])
+
+        print("=== Checking operator fusion ===")
+        backend.check_before_ops(model.ops_in_model_before())
+        backend.check_after_ops(model.ops_in_model_after())

tests/ut/test_ascend_config.py

Lines changed: 3 additions & 3 deletions
@@ -41,14 +41,14 @@ def test_init_ascend_config_without_additional_config(self):
         self.assertFalse(ascend_config.multistream_overlap_shared_expert)
 
         ascend_compilation_config = ascend_config.ascend_compilation_config
-        self.assertTrue(ascend_compilation_config.enable_quantization_fusion)
+        self.assertTrue(ascend_compilation_config.fuse_norm_quant)
 
     @_clean_up_ascend_config
     def test_init_ascend_config_with_additional_config(self):
         test_vllm_config = VllmConfig()
         test_vllm_config.additional_config = {
             "ascend_compilation_config": {
-                "enable_quantization_fusion": False,
+                "fuse_norm_quant": False,
             },
             "multistream_overlap_shared_expert": True,
             "expert_map_path": "test_expert_map_path",
@@ -60,7 +60,7 @@ def test_init_ascend_config_with_additional_config(self):
         self.assertFalse(ascend_config.enable_npugraph_ex)
 
         ascend_compilation_config = ascend_config.ascend_compilation_config
-        self.assertFalse(ascend_compilation_config.enable_quantization_fusion)
+        self.assertFalse(ascend_compilation_config.fuse_norm_quant)
 
     @_clean_up_ascend_config
     def test_init_ascend_config_enable_npugraph_ex(self):

vllm_ascend/ascend_config.py

Lines changed: 4 additions & 5 deletions
@@ -190,19 +190,18 @@ class AscendCompilationConfig:
     deployed on Ascend platforms.
     """
 
-    def __init__(self, enable_quantization_fusion: bool = True, **kwargs):
+    def __init__(self, fuse_norm_quant: bool = True, **kwargs):
         """
         Initialize the configuration.
 
         Args:
-            enable_quantization_fusion (bool): Whether to enable quantization fusion optimization.
-                When set to True, the system will optimize quantization-related operations,
-                reducing the number of quantization/dequantization nodes.
+            fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
+                When set to True, the system will optimize norm and quant operations.
                 Default: True
 
             **kwargs: Additional optional parameters for forward compatibility and configuration extension.
         """
-        self.enable_quantization_fusion = enable_quantization_fusion
+        self.fuse_norm_quant = fuse_norm_quant
         # Add more compilation related configs here as needed
 
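The rename also shows up directly on the config object. A minimal sketch of the renamed flag (assuming `vllm_ascend` is installed; the defaults mirror the diff above):

# Sketch only: the renamed flag on AscendCompilationConfig.
from vllm_ascend.ascend_config import AscendCompilationConfig

cfg = AscendCompilationConfig()
assert cfg.fuse_norm_quant is True           # fusion enabled by default

cfg = AscendCompilationConfig(fuse_norm_quant=False)
assert cfg.fuse_norm_quant is False          # opts out of the norm+quant fusion pass
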
vllm_ascend/compilation/graph_fusion_pass_manager.py

Lines changed: 3 additions & 3 deletions
@@ -46,8 +46,8 @@ def configure(self, config: VllmConfig):
         # By default, we enable the graph fusion and quantization fusion pass.
         self.ascend_compilation_config: dict = config.additional_config.get(
             "ascend_compilation_config", {})
-        if self.ascend_compilation_config.get("enable_quantization_fusion",
-                                              True):
-            from .passes.quant_fusion_pass import AddRMSNormQuantFusionPass
+        if self.ascend_compilation_config.get("fuse_norm_quant", True):
+            from .passes.norm_quant_fusion_pass import \
+                AddRMSNormQuantFusionPass
             self.passes.append(AddRMSNormQuantFusionPass(config))
         # Add more passes here as needed
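As shown above, the pass manager reads the flag from `additional_config`. A minimal sketch of disabling the fusion from user code, assuming vLLM's `LLM` entry point forwards `additional_config` (the model name is a placeholder; the nested keys mirror tests/ut/test_ascend_config.py in this commit):

# Sketch only: turn the norm+quant fusion pass off at engine creation time.
from vllm import LLM

llm = LLM(
    model="your/model",  # placeholder model name
    additional_config={
        "ascend_compilation_config": {
            "fuse_norm_quant": False,  # default is True
        },
    },
)
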
vllm_ascend/compilation/passes/quant_fusion_pass.py renamed to vllm_ascend/compilation/passes/norm_quant_fusion_pass.py

File renamed without changes.

vllm_ascend/platform.py

Lines changed: 3 additions & 4 deletions
@@ -88,8 +88,7 @@ def get_compile_backend(self) -> str:
         Get the custom compile backend. Previously, we used EagerAdaptor by default.
         To use graph fusion operations, we defined our own backend compiler.
         """
-        from vllm_ascend.compilation.compiler_interface import AscendCompiler
-        return AscendCompiler.__module__ + "." + AscendCompiler.__name__
+        return "vllm_ascend.compilation.compiler_interface.AscendCompiler"
 
     @classmethod
     def pre_register_and_update(cls,
@@ -225,8 +224,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
             compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
 
-        from vllm_ascend.compilation.compiler_interface import AscendCompiler
-        compilation_config.oot_compiler = AscendCompiler.__module__ + "." + AscendCompiler.__name__
+        # get custom compile backend for graph fusion
+        compilation_config.oot_compiler = cls.get_compile_backend()
 
         if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             compilation_config.mode = CompilationMode.NONE
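`get_compile_backend` now returns the dotted path directly instead of importing AscendCompiler to build it, and `check_and_update_config` reuses that helper. A quick sanity check (a sketch, assuming `vllm_ascend` imports cleanly on the host) that the hard-coded string matches what the removed code computed:

# Sketch only: the literal string equals the old __module__ + "." + __name__ expression.
from vllm_ascend.compilation.compiler_interface import AscendCompiler

assert (AscendCompiler.__module__ + "." + AscendCompiler.__name__
        == "vllm_ascend.compilation.compiler_interface.AscendCompiler")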
