Skip to content

Commit 5bdef8e

Browse files
committed
[Fusion] normalize fusion naming and enable e2e test
Signed-off-by: wxsIcey <[email protected]>
1 parent 2b819bb commit 5bdef8e

File tree

7 files changed

+17
-18
lines changed

7 files changed

+17
-18
lines changed

.github/workflows/_e2e_test.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ jobs:
104104
pytest -sv tests/e2e/singlecard/test_vlm.py
105105
pytest -sv tests/e2e/singlecard/multi-modal/test_internvl.py
106106
pytest -sv tests/e2e/singlecard/test_xlite.py
107+
pytest -sv tests/e2e/singlecard/test_norm_quant_fusion.py
107108
108109
# ------------------------------------ v1 spec decode test ------------------------------------ #
109110
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

tests/e2e/singlecard/test_quant_fusion.py renamed to tests/e2e/singlecard/test_norm_quant_fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
2929

3030
from vllm_ascend.compilation.compiler_interface import compile_fx
31-
from vllm_ascend.compilation.passes.quant_fusion_pass import \
31+
from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
3232
AddRMSNormQuantFusionPass
3333

3434

tests/ut/test_ascend_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_init_ascend_config_without_additional_config(self):
5757
self.assertFalse(torchair_graph_config.enable_kv_nz)
5858

5959
ascend_compilation_config = ascend_config.ascend_compilation_config
60-
self.assertTrue(ascend_compilation_config.enable_quantization_fusion)
60+
self.assertTrue(ascend_compilation_config.fuse_norm_quant)
6161

6262
@_clean_up_ascend_config
6363
def test_init_ascend_config_with_additional_config(self):
@@ -74,7 +74,7 @@ def test_init_ascend_config_with_additional_config(self):
7474
"enable_kv_nz": True
7575
},
7676
"ascend_compilation_config": {
77-
"enable_quantization_fusion": False,
77+
"fuse_norm_quant": False,
7878
},
7979
"multistream_overlap_shared_expert": True,
8080
"expert_map_path": "test_expert_map_path",
@@ -94,7 +94,7 @@ def test_init_ascend_config_with_additional_config(self):
9494
self.assertTrue(torchair_graph_config.enable_frozen_parameter)
9595
self.assertTrue(torchair_graph_config.enable_kv_nz)
9696
ascend_compilation_config = ascend_config.ascend_compilation_config
97-
self.assertFalse(ascend_compilation_config.enable_quantization_fusion)
97+
self.assertFalse(ascend_compilation_config.fuse_norm_quant)
9898

9999
@_clean_up_ascend_config
100100
def test_init_ascend_config_with_refresh(self):

vllm_ascend/ascend_config.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,19 +199,18 @@ class AscendCompilationConfig:
199199
deployed on Ascend platforms.
200200
"""
201201

202-
def __init__(self, enable_quantization_fusion: bool = True, **kwargs):
202+
def __init__(self, fuse_norm_quant: bool = True, **kwargs):
203203
"""
204204
Initialize the configuration.
205205
206206
Args:
207-
enable_quantization_fusion (bool): Whether to enable quantization fusion optimization.
208-
When set to True, the system will optimize quantization-related operations,
209-
reducing the number of quantization/dequantization nodes.
207+
fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
208+
When set to True, the system will optimize norm and quant operations.
210209
Default: True
211210
212211
**kwargs: Additional optional parameters for forward compatibility and configuration extension.
213212
"""
214-
self.enable_quantization_fusion = enable_quantization_fusion
213+
self.fuse_norm_quant = fuse_norm_quant
215214
# Add more compilation related configs here as needed
216215

217216

@@ -406,9 +405,9 @@ def check_ascend_config(vllm_config, enforce_eager):
406405
"it has been disabled automatically.")
407406
# aclgraph case
408407
else:
409-
if ascend_config.ascend_compilation_config.enable_quantization_fusion:
408+
if ascend_config.ascend_compilation_config.fuse_norm_quant:
410409
logger.info(
411-
"Quantization fusion enabled! op fusion on quantization are expected. "
410+
"Norm and Quant fusion enabled! op fusion on norm and quant are expected. "
412411
)
413412

414413
if vllm_config.model_config:

vllm_ascend/compilation/graph_fusion_pass_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def configure(self, config: VllmConfig):
4646
# By default, we enable the graph fusion and quantization fusion pass.
4747
self.ascend_compilation_config: dict = config.additional_config.get(
4848
"ascend_compilation_config", {})
49-
if self.ascend_compilation_config.get("enable_quantization_fusion",
50-
True):
51-
from .passes.quant_fusion_pass import AddRMSNormQuantFusionPass
49+
if self.ascend_compilation_config.get("fuse_norm_quant", True):
50+
from .passes.norm_quant_fusion_pass import \
51+
AddRMSNormQuantFusionPass
5252
self.passes.append(AddRMSNormQuantFusionPass(config))
5353
# Add more passes here as needed

vllm_ascend/compilation/passes/quant_fusion_pass.py renamed to vllm_ascend/compilation/passes/norm_quant_fusion_pass.py

File renamed without changes.

vllm_ascend/platform.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,7 @@ def get_compile_backend(self) -> str:
9191
Get the custom compile backend. Previously, we used EagerAdaptor by default.
9292
To use graph fusion operations, we defined our own backend compiler.
9393
"""
94-
from vllm_ascend.compilation.compiler_interface import AscendCompiler
95-
return AscendCompiler.__module__ + "." + AscendCompiler.__name__
94+
return "vllm_ascend.compilation.compiler_interface.AscendCompiler"
9695

9796
@classmethod
9897
def pre_register_and_update(cls,
@@ -248,8 +247,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
248247
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
249248
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
250249

251-
from vllm_ascend.compilation.compiler_interface import AscendCompiler
252-
compilation_config.oot_compiler = AscendCompiler.__module__ + "." + AscendCompiler.__name__
250+
# get custom compile backend for graph fusion
251+
compilation_config.oot_compiler = cls.get_compile_backend()
253252

254253
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
255254
compilation_config.mode = CompilationMode.NONE

0 commit comments

Comments
 (0)