Commit c19152f

refactor compile e2e test
Signed-off-by: wxsIcey <[email protected]>
1 parent 5bdef8e commit c19152f

File tree

3 files changed: +115 -101 lines changed


tests/e2e/singlecard/test_norm_quant_fusion.py renamed to tests/e2e/singlecard/compile/backend.py

Lines changed: 4 additions & 99 deletions
@@ -16,67 +16,11 @@
 #
 from copy import deepcopy
 from typing import Any, Callable, List, Optional, Sequence
-
-import pytest
-import torch
 import torch.fx as fx
-import torch.nn as nn
-import torch_npu
-import vllm.config
 from torch._inductor.decomposition import select_decomp_table
 from vllm.compilation.fx_utils import OpOverload
-from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
-
+from vllm.config import get_current_vllm_config
 from vllm_ascend.compilation.compiler_interface import compile_fx
-from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
-    AddRMSNormQuantFusionPass
-
-
-class TestModel(nn.Module):
-    """
-    A minimal test model that simulates the pattern:
-    AddRMSNorm → Quantization
-    """
-
-    def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.eps = eps
-        self.rms_norm_weight = nn.Parameter(
-            torch.randn(hidden_size, device=device))
-        self.quant_scale = torch.tensor([1.0], device=device)
-        self.quant_offset = torch.tensor([0.0], device=device)
-
-    def forward(self, x):
-        """
-        Forward pass:
-        1. Perform npu_add_rms_norm
-        2. Quantize the normalized output to int8
-        Returns both quantized output and updated residual.
-        """
-        residual = torch.zeros_like(x)
-
-        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
-            x, residual, self.rms_norm_weight, self.eps)
-
-        quantized_output = torch_npu.npu_quantize(norm_output,
-                                                  self.quant_scale,
-                                                  self.quant_offset,
-                                                  torch.qint8, -1, False)
-
-        return quantized_output, new_residual
-
-    def ops_in_model_before(self) -> List[OpOverload]:
-        """Return the list of expected operators BEFORE fusion."""
-        return [
-            torch.ops.npu.npu_add_rms_norm.default,
-            torch.ops.npu.npu_quantize.default
-        ]
-
-    def ops_in_model_after(self) -> List[OpOverload]:
-        """Return the list of expected operators AFTER successful fusion."""
-        return [torch.ops.npu.npu_add_rms_norm_quant.default]
-
 
 class TestBackend:
     """
@@ -85,14 +29,12 @@ class TestBackend:
     records the FX graph before and after the transformation.
     """
 
-    def __init__(self):
+    def __init__(self, custom_passes: Optional[List[Any]] = None):
         vllm_config = get_current_vllm_config()
         compile_config = vllm_config.compilation_config
-        self.custom_passes = [
-            AddRMSNormQuantFusionPass(vllm_config=vllm_config)
-        ]
         self.inductor_config = compile_config.inductor_compile_config
         self.inductor_config["graph_fusion_manager"] = self.post_pass
+        self.custom_passes = custom_passes
 
         # Placeholders to store FX graphs for verification
         self.graph_pre_pass = None
@@ -179,41 +121,4 @@ def check_after_ops(self, ops: Sequence[OpOverload]):
         for op in ops:
             num_post = len(self.find_nodes_by_target(self.graph_post_pass, op))
             print(f"Op {op}: post={num_post}")
-            assert num_post > 0, f"Op {op} not found in post-pass graph"
-
-
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("hidden_size", [64])
-@pytest.mark.parametrize("num_tokens", [257])
-@pytest.mark.parametrize("eps", [1e-5, 1e-6])
-def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
-                              num_tokens: int, eps: float):
-    """
-    End-to-end test for AddRMSNorm+Quantize fusion.
-    Compares: Operator presence/absence before and after graph transformation
-    """
-    torch.set_default_dtype(dtype)
-    torch.manual_seed(1)
-
-    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
-
-    with vllm.config.set_current_vllm_config(vllm_config):
-        backend = TestBackend()
-        model = TestModel(hidden_size, eps, device="npu")
-        model = model.to("npu")
-
-        x = torch.rand(num_tokens,
-                       hidden_size,
-                       device="npu",
-                       dtype=dtype,
-                       requires_grad=False)
-
-        result_unfused = model(x)
-        print("Unfused result:", [t.shape for t in result_unfused])
-        model_fused = torch.compile(model, backend=backend)
-        result_fused = model_fused(x)
-        print("Fused result:", [t.shape for t in result_fused])
-
-        print("=== Checking operator fusion ===")
-        backend.check_before_ops(model.ops_in_model_before())
-        backend.check_after_ops(model.ops_in_model_after())
+            assert num_post > 0, f"Op {op} not found in post-pass graph"
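
Note: after this refactor, backend.py provides a reusable TestBackend that no longer hard-codes AddRMSNormQuantFusionPass; each test injects its own pass list. A minimal reuse sketch follows (the helper below is illustrative, not part of the commit, and assumes VllmConfig() is constructible with defaults):

import torch
import vllm.config
from vllm.config import VllmConfig

from tests.e2e.singlecard.compile.backend import TestBackend

def compile_with_passes(model: torch.nn.Module, custom_passes: list):
    """Compile `model` with TestBackend driving the given custom passes.

    Illustrative helper only; a real test would build VllmConfig to match
    its model, as test_rmsnorm_quant_fusion does below.
    """
    vllm_config = VllmConfig()
    with vllm.config.set_current_vllm_config(vllm_config):
        # TestBackend reads the current vllm config, so it must be
        # constructed inside the set_current_vllm_config context.
        backend = TestBackend(custom_passes=custom_passes)
        compiled = torch.compile(model, backend=backend)
        return compiled, backend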
tests/e2e/singlecard/compile/test_norm_quant_fusion.py (new file)

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import List
+
+import pytest
+import torch
+import torch.nn as nn
+import torch_npu
+import vllm.config
+from vllm.config import VllmConfig, ModelConfig
+from vllm.compilation.fx_utils import OpOverload
+from tests.e2e.singlecard.compile.backend import TestBackend
+from vllm_ascend.compilation.passes.norm_quant_fusion_pass import AddRMSNormQuantFusionPass
+
+
+class TestModel(nn.Module):
+    """
+    A minimal test model that simulates the pattern:
+    AddRMSNorm → Quantization
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.rms_norm_weight = nn.Parameter(
+            torch.randn(hidden_size, device=device))
+        self.quant_scale = torch.tensor([1.0], device=device)
+        self.quant_offset = torch.tensor([0.0], device=device)
+
+    def forward(self, x):
+        """
+        Forward pass:
+        1. Perform npu_add_rms_norm
+        2. Quantize the normalized output to int8
+        Returns both quantized output and updated residual.
+        """
+        residual = torch.zeros_like(x)
+
+        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
+            x, residual, self.rms_norm_weight, self.eps)
+
+        quantized_output = torch_npu.npu_quantize(norm_output,
+                                                  self.quant_scale,
+                                                  self.quant_offset,
+                                                  torch.qint8, -1, False)
+
+        return quantized_output, new_residual
+
+    def ops_in_model_before(self) -> List[OpOverload]:
+        """Return the list of expected operators BEFORE fusion."""
+        return [
+            torch.ops.npu.npu_add_rms_norm.default,
+            torch.ops.npu.npu_quantize.default
+        ]
+
+    def ops_in_model_after(self) -> List[OpOverload]:
+        """Return the list of expected operators AFTER successful fusion."""
+        return [torch.ops.npu.npu_add_rms_norm_quant.default]
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("hidden_size", [64])
+@pytest.mark.parametrize("num_tokens", [257])
+@pytest.mark.parametrize("eps", [1e-5, 1e-6])
+def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
+                              num_tokens: int, eps: float):
+    """
+    End-to-end test for AddRMSNorm+Quantize fusion.
+    Compares: Operator presence/absence before and after graph transformation
+    """
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(1)
+
+    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
+
+    with vllm.config.set_current_vllm_config(vllm_config):
+        backend = TestBackend(custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
+        model = TestModel(hidden_size, eps, device="npu")
+        model = model.to("npu")
+
+        x = torch.rand(num_tokens,
+                       hidden_size,
+                       device="npu",
+                       dtype=dtype,
+                       requires_grad=False)
+
+        result_unfused = model(x)
+        print("Unfused result:", [t.shape for t in result_unfused])
+        model_fused = torch.compile(model, backend=backend)
+        result_fused = model_fused(x)
+        print("Fused result:", [t.shape for t in result_fused])
+
+        print("=== Checking operator fusion ===")
+        backend.check_before_ops(model.ops_in_model_before())
+        backend.check_after_ops(model.ops_in_model_after())
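
Note: for reference, the computation this pass collapses into npu_add_rms_norm_quant (residual add, RMSNorm, int8 quantization) can be written in plain PyTorch roughly as below. This is a semantic sketch only; the exact rounding and scale/offset conventions of torch_npu.npu_quantize are assumptions, not taken from this commit.

import torch

def add_rms_norm_quant_reference(x: torch.Tensor, residual: torch.Tensor,
                                 weight: torch.Tensor, scale: torch.Tensor,
                                 offset: torch.Tensor, eps: float = 1e-6):
    # Residual add (the "Add" in AddRMSNorm); this is also the updated
    # residual that the fused op returns alongside the quantized output.
    hidden = x + residual
    # RMSNorm: scale by the reciprocal root-mean-square, then by the weight.
    rms = torch.rsqrt(hidden.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = hidden * rms * weight
    # Int8 quantization; the divide-by-scale, add-offset convention here is
    # an assumption about npu_quantize, used only to illustrate the pattern.
    quantized = torch.clamp(torch.round(normed / scale + offset),
                            -128, 127).to(torch.int8)
    return quantized, hidden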

vllm_ascend/compilation/graph_fusion_pass_manager.py

Lines changed: 1 addition & 2 deletions
@@ -47,7 +47,6 @@ def configure(self, config: VllmConfig):
         self.ascend_compilation_config: dict = config.additional_config.get(
             "ascend_compilation_config", {})
         if self.ascend_compilation_config.get("fuse_norm_quant", True):
-            from .passes.norm_quant_fusion_pass import \
-                AddRMSNormQuantFusionPass
+            from .passes.norm_quant_fusion_pass import AddRMSNormQuantFusionPass
             self.passes.append(AddRMSNormQuantFusionPass(config))
         # Add more passes here as needed
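
Note: the fuse_norm_quant key read above defaults to True, so the pass can be switched off from user configuration. A sketch of doing so through vLLM's additional_config pass-through (the model name and other LLM arguments are placeholders, not from this commit):

from vllm import LLM

# Sketch: disable AddRMSNormQuantFusionPass via the same keys that
# GraphFusionPassManager.configure() reads above.
llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model name
    additional_config={
        "ascend_compilation_config": {
            "fuse_norm_quant": False,  # skip registering the fusion pass
        }
    },
)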
