
Commit 7661d55

tiny fix

Signed-off-by: wxsIcey <[email protected]>
1 parent: c19152f

File tree: 3 files changed, +12 -5 lines

tests/e2e/singlecard/compile/backend.py
tests/e2e/singlecard/compile/test_norm_quant_fusion.py
vllm_ascend/compilation/graph_fusion_pass_manager.py


tests/e2e/singlecard/compile/backend.py

Lines changed: 4 additions & 1 deletion
@@ -16,12 +16,15 @@
 #
 from copy import deepcopy
 from typing import Any, Callable, List, Optional, Sequence
+
 import torch.fx as fx
 from torch._inductor.decomposition import select_decomp_table
 from vllm.compilation.fx_utils import OpOverload
 from vllm.config import get_current_vllm_config
+
 from vllm_ascend.compilation.compiler_interface import compile_fx

+
 class TestBackend:
     """
     A custom compilation backend for testing operator fusion passes.
@@ -121,4 +124,4 @@ def check_after_ops(self, ops: Sequence[OpOverload]):
         for op in ops:
             num_post = len(self.find_nodes_by_target(self.graph_post_pass, op))
             print(f"Op {op}: post={num_post}")
-            assert num_post > 0, f"Op {op} not found in post-pass graph"
+            assert num_post > 0, f"Op {op} not found in post-pass graph"
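
For orientation: check_after_ops walks the FX graph captured after the custom passes run and asserts that every expected op is present. A minimal sketch of how a test drives it, assuming TestBackend is callable as a torch.compile backend (as the test file below uses it); model, x, vllm_config, and fused_op are placeholders, not names from this commit:

import torch

# Sketch only; mirrors the usage in test_norm_quant_fusion.py below.
backend = TestBackend(
    custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
compiled = torch.compile(model, backend=backend)
compiled(x)  # run once so the pre-/post-pass graphs are captured

# fused_op stands in for whatever OpOverload the fusion pass should insert.
backend.check_after_ops([fused_op])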

tests/e2e/singlecard/compile/test_norm_quant_fusion.py

Lines changed: 6 additions & 3 deletions
@@ -21,10 +21,12 @@
 import torch.nn as nn
 import torch_npu
 import vllm.config
-from vllm.config import VllmConfig, ModelConfig
 from vllm.compilation.fx_utils import OpOverload
+from vllm.config import ModelConfig, VllmConfig
+
 from tests.e2e.singlecard.compile.backend import TestBackend
-from vllm_ascend.compilation.passes.norm_quant_fusion_pass import AddRMSNormQuantFusionPass
+from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
+    AddRMSNormQuantFusionPass


 class TestModel(nn.Module):
@@ -89,7 +91,8 @@ def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
     vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))

     with vllm.config.set_current_vllm_config(vllm_config):
-        backend = TestBackend(custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
+        backend = TestBackend(
+            custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
         model = TestModel(hidden_size, eps, device="npu")
         model = model.to("npu")
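
Reassembled from the hunk above, the import block now reads as follows; the blank line reflects isort's separation of third-party and local sections:

import torch.nn as nn
import torch_npu
import vllm.config
from vllm.compilation.fx_utils import OpOverload
from vllm.config import ModelConfig, VllmConfig

from tests.e2e.singlecard.compile.backend import TestBackend
from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
    AddRMSNormQuantFusionPass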

vllm_ascend/compilation/graph_fusion_pass_manager.py

Lines changed: 2 additions & 1 deletion
@@ -47,6 +47,7 @@ def configure(self, config: VllmConfig):
         self.ascend_compilation_config: dict = config.additional_config.get(
             "ascend_compilation_config", {})
         if self.ascend_compilation_config.get("fuse_norm_quant", True):
-            from .passes.norm_quant_fusion_pass import AddRMSNormQuantFusionPass
+            from .passes.norm_quant_fusion_pass import \
+                AddRMSNormQuantFusionPass
             self.passes.append(AddRMSNormQuantFusionPass(config))
         # Add more passes here as needed
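
Since configure reads fuse_norm_quant from ascend_compilation_config with a default of True, the pass can be switched off through additional_config. A hedged sketch of the relevant structure; how additional_config reaches the engine (CLI flag, engine argument, etc.) is outside this commit:

# Key names match the code above; the surrounding plumbing is assumed.
additional_config = {
    "ascend_compilation_config": {
        # Defaults to True; set to False to skip AddRMSNormQuantFusionPass.
        "fuse_norm_quant": False,
    }
}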
