@@ -16,67 +16,11 @@
 #
 from copy import deepcopy
 from typing import Any, Callable, List, Optional, Sequence
-
-import pytest
-import torch
 import torch.fx as fx
-import torch.nn as nn
-import torch_npu
-import vllm.config
 from torch._inductor.decomposition import select_decomp_table
 from vllm.compilation.fx_utils import OpOverload
-from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
-
+from vllm.config import get_current_vllm_config
 from vllm_ascend.compilation.compiler_interface import compile_fx
-from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
-    AddRMSNormQuantFusionPass
-
-
-class TestModel(nn.Module):
-    """
-    A minimal test model that simulates the pattern:
-    AddRMSNorm → Quantization
-    """
-
-    def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.eps = eps
-        self.rms_norm_weight = nn.Parameter(
-            torch.randn(hidden_size, device=device))
-        self.quant_scale = torch.tensor([1.0], device=device)
-        self.quant_offset = torch.tensor([0.0], device=device)
-
-    def forward(self, x):
-        """
-        Forward pass:
-        1. Perform npu_add_rms_norm
-        2. Quantize the normalized output to int8
-        Returns both quantized output and updated residual.
-        """
-        residual = torch.zeros_like(x)
-
-        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
-            x, residual, self.rms_norm_weight, self.eps)
-
-        quantized_output = torch_npu.npu_quantize(norm_output,
-                                                  self.quant_scale,
-                                                  self.quant_offset,
-                                                  torch.qint8, -1, False)
-
-        return quantized_output, new_residual
-
-    def ops_in_model_before(self) -> List[OpOverload]:
-        """Return the list of expected operators BEFORE fusion."""
-        return [
-            torch.ops.npu.npu_add_rms_norm.default,
-            torch.ops.npu.npu_quantize.default
-        ]
-
-    def ops_in_model_after(self) -> List[OpOverload]:
-        """Return the list of expected operators AFTER successful fusion."""
-        return [torch.ops.npu.npu_add_rms_norm_quant.default]
-
 
 class TestBackend:
     """
@@ -85,14 +29,12 @@ class TestBackend:
     records the FX graph before and after the transformation.
     """
 
-    def __init__(self):
+    def __init__(self, custom_passes: Optional[List[Any]] = None):
         vllm_config = get_current_vllm_config()
         compile_config = vllm_config.compilation_config
-        self.custom_passes = [
-            AddRMSNormQuantFusionPass(vllm_config=vllm_config)
-        ]
         self.inductor_config = compile_config.inductor_compile_config
         self.inductor_config["graph_fusion_manager"] = self.post_pass
+        self.custom_passes = custom_passes
 
         # Placeholders to store FX graphs for verification
         self.graph_pre_pass = None
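
With this change, `TestBackend` no longer hard-codes `AddRMSNormQuantFusionPass`; the caller injects whatever pass list it wants to exercise. A minimal usage sketch, mirroring the wiring from the removed in-tree test (it assumes the caller imports `VllmConfig`, `ModelConfig`, and the fusion pass itself, and `model` stands in for any NPU module under test):

    vllm_config = VllmConfig(model_config=ModelConfig(dtype=torch.float16))
    with vllm.config.set_current_vllm_config(vllm_config):
        # The pass is now constructed by the caller, not inside TestBackend.
        backend = TestBackend(
            custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
        # As in the removed test, a TestBackend instance is passed directly
        # as a torch.compile backend.
        model_fused = torch.compile(model, backend=backend)
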
@@ -179,41 +121,4 @@ def check_after_ops(self, ops: Sequence[OpOverload]):
         for op in ops:
             num_post = len(self.find_nodes_by_target(self.graph_post_pass, op))
             print(f"Op {op}: post={num_post}")
-            assert num_post > 0, f"Op {op} not found in post-pass graph"
-
-
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("hidden_size", [64])
-@pytest.mark.parametrize("num_tokens", [257])
-@pytest.mark.parametrize("eps", [1e-5, 1e-6])
-def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
-                              num_tokens: int, eps: float):
-    """
-    End-to-end test for AddRMSNorm+Quantize fusion.
-    Compares operator presence/absence before and after graph transformation.
-    """
-    torch.set_default_dtype(dtype)
-    torch.manual_seed(1)
-
-    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
-
-    with vllm.config.set_current_vllm_config(vllm_config):
-        backend = TestBackend()
-        model = TestModel(hidden_size, eps, device="npu")
-        model = model.to("npu")
-
-        x = torch.rand(num_tokens,
-                       hidden_size,
-                       device="npu",
-                       dtype=dtype,
-                       requires_grad=False)
-
-        result_unfused = model(x)
-        print("Unfused result:", [t.shape for t in result_unfused])
-        model_fused = torch.compile(model, backend=backend)
-        result_fused = model_fused(x)
-        print("Fused result:", [t.shape for t in result_fused])
-
-        print("=== Checking operator fusion ===")
-        backend.check_before_ops(model.ops_in_model_before())
-        backend.check_after_ops(model.ops_in_model_after())
+            assert num_post > 0, f"Op {op} not found in post-pass graph"
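
A downstream test can still drive the before/after verification exactly as the removed in-tree test did. Sketch, assuming the compiled model has run at least once (so both FX graphs are recorded) and that the model exposes the expected-op helpers that `TestModel` provided:

    model_fused(x)  # populates graph_pre_pass / graph_post_pass
    backend.check_before_ops(model.ops_in_model_before())
    backend.check_after_ops(model.ops_in_model_after())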