|
| 1 | +# |
| 2 | +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
| 3 | +# Copyright 2023 The vLLM team. |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | +from copy import deepcopy |
| 18 | +from typing import Any, Callable, List, Optional, Sequence |
| 19 | + |
| 20 | +import pytest |
| 21 | +import torch |
| 22 | +import torch.fx as fx |
| 23 | +import torch.nn as nn |
| 24 | +import torch_npu |
| 25 | +import vllm.config |
| 26 | +from torch._inductor.decomposition import select_decomp_table |
| 27 | +from vllm.compilation.fx_utils import OpOverload |
| 28 | +from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config |
| 29 | + |
| 30 | +from vllm_ascend.compilation.compiler_interface import compile_fx |
| 31 | +from vllm_ascend.compilation.passes.quant_fusion_pass import \ |
| 32 | + AddRMSNormQuantFusionPass |
| 33 | + |
| 34 | + |
| 35 | +class TestModel(nn.Module): |
| 36 | + """ |
| 37 | + A minimal test model that simulates the pattern: |
| 38 | + AddRMSNorm → Quantization |
| 39 | + """ |
| 40 | + |
| 41 | + def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"): |
| 42 | + super().__init__() |
| 43 | + self.hidden_size = hidden_size |
| 44 | + self.eps = eps |
| 45 | + self.rms_norm_weight = nn.Parameter( |
| 46 | + torch.randn(hidden_size, device=device)) |
| 47 | + self.quant_scale = torch.tensor([1.0], device=device) |
| 48 | + self.quant_offset = torch.tensor([0.0], device=device) |
| 49 | + |
| 50 | + def forward(self, x): |
| 51 | + """ |
| 52 | + Forward pass: |
| 53 | + 1. Perform npu_add_rms_norm |
| 54 | + 2. Quantize the normalized output to int8 |
| 55 | + Returns both quantized output and updated residual. |
| 56 | + """ |
| 57 | + residual = torch.zeros_like(x) |
| 58 | + |
| 59 | + norm_output, _, new_residual = torch_npu.npu_add_rms_norm( |
| 60 | + x, residual, self.rms_norm_weight, self.eps) |
| 61 | + |
| 62 | + quantized_output = torch_npu.npu_quantize(norm_output, |
| 63 | + self.quant_scale, |
| 64 | + self.quant_offset, |
| 65 | + torch.qint8, -1, False) |
| 66 | + |
| 67 | + return quantized_output, new_residual |
| 68 | + |
| 69 | + def ops_in_model_before(self) -> List[OpOverload]: |
| 70 | + """Return the list of expected operators BEFORE fusion.""" |
| 71 | + return [ |
| 72 | + torch.ops.npu.npu_add_rms_norm.default, |
| 73 | + torch.ops.npu.npu_quantize.default |
| 74 | + ] |
| 75 | + |
| 76 | + def ops_in_model_after(self) -> List[OpOverload]: |
| 77 | + """Return the list of expected operators AFTER successful fusion.""" |
| 78 | + return [torch.ops.npu.npu_add_rms_norm_quant.default] |
| 79 | + |
| 80 | + |
| 81 | +class TestBackend: |
| 82 | + """ |
| 83 | + A custom compilation backend for testing operator fusion passes. |
| 84 | + It applies the AddRMSNormQuantFusionPass during graph compilation and |
| 85 | + records the FX graph before and after the transformation. |
| 86 | + """ |
| 87 | + |
| 88 | + def __init__(self): |
| 89 | + vllm_config = get_current_vllm_config() |
| 90 | + compile_config = vllm_config.compilation_config |
| 91 | + self.custom_passes = [ |
| 92 | + AddRMSNormQuantFusionPass(vllm_config=vllm_config) |
| 93 | + ] |
| 94 | + self.inductor_config = compile_config.inductor_compile_config |
| 95 | + self.inductor_config["graph_fusion_manager"] = self.post_pass |
| 96 | + |
| 97 | + # Placeholders to store FX graphs for verification |
| 98 | + self.graph_pre_pass = None |
| 99 | + self.graph_post_pass = None |
| 100 | + |
| 101 | + def post_pass(self, |
| 102 | + graph: fx.Graph, |
| 103 | + runtime_shape: int | None = None) -> fx.Graph: |
| 104 | + """ |
| 105 | + Apply custom graph transformation passes. |
| 106 | + """ |
| 107 | + self.graph_pre_pass = deepcopy(graph) |
| 108 | + for pass_ in self.custom_passes: |
| 109 | + pass_(graph) |
| 110 | + self.graph_post_pass = deepcopy(graph) |
| 111 | + return graph |
| 112 | + |
| 113 | + def compile( |
| 114 | + self, |
| 115 | + graph: fx.GraphModule, |
| 116 | + example_inputs: list[Any], |
| 117 | + compiler_config: dict[str, Any], |
| 118 | + runtime_shape: Optional[int] = None, |
| 119 | + key: Optional[str] = None |
| 120 | + ) -> tuple[Optional[Callable], Optional[Any]]: |
| 121 | + """ |
| 122 | + Compile the FX graph using vLLM's Ascend compiler interface. |
| 123 | + Wraps the post-pass logic into the inner_compile callback. |
| 124 | + """ |
| 125 | + |
| 126 | + def compile_inner(graph, example_inputs): |
| 127 | + current_pass_manager = compiler_config["graph_fusion_manager"] |
| 128 | + return current_pass_manager(graph, runtime_shape) |
| 129 | + |
| 130 | + decompositions = select_decomp_table() |
| 131 | + compiled_fn = compile_fx( |
| 132 | + graph=graph, |
| 133 | + example_inputs=example_inputs, |
| 134 | + inner_compile=compile_inner, |
| 135 | + decompositions=decompositions, |
| 136 | + ) |
| 137 | + return compiled_fn, None |
| 138 | + |
| 139 | + def __call__(self, gm: fx.GraphModule, example_inputs: List[Any]): |
| 140 | + """ |
| 141 | + Make the backend callable by torch.compile(). |
| 142 | + Returns a compiled executable function. |
| 143 | + """ |
| 144 | + compiled_fn, _ = self.compile( |
| 145 | + gm, |
| 146 | + example_inputs, |
| 147 | + compiler_config={"graph_fusion_manager": self.post_pass}, |
| 148 | + runtime_shape=None, |
| 149 | + key=None, |
| 150 | + ) |
| 151 | + return compiled_fn |
| 152 | + |
| 153 | + def find_nodes_by_target(self, graph: fx.GraphModule, |
| 154 | + target: OpOverload) -> List[fx.Node]: |
| 155 | + """Helper to find all FX nodes that call a specific operator.""" |
| 156 | + return [ |
| 157 | + node for node in graph.graph.nodes |
| 158 | + if hasattr(node, 'target') and node.target == target |
| 159 | + ] |
| 160 | + |
| 161 | + def check_before_ops(self, |
| 162 | + ops: Sequence[OpOverload], |
| 163 | + fully_replaced: bool = True): |
| 164 | + """ |
| 165 | + Verify that the original (unfused) operators exist before the pass |
| 166 | + and are fully removed afterward (if fully_replaced=True). |
| 167 | + """ |
| 168 | + for op in ops: |
| 169 | + num_pre = len(self.find_nodes_by_target(self.graph_pre_pass, op)) |
| 170 | + num_post = len(self.find_nodes_by_target(self.graph_post_pass, op)) |
| 171 | + print(f"Op {op}: pre={num_pre}, post={num_post}") |
| 172 | + |
| 173 | + assert num_pre > 0, f"Op {op} not found in pre-pass graph" |
| 174 | + if fully_replaced: |
| 175 | + assert num_post == 0, f"Unexpected op {op} in post-pass graph: {num_post} nodes remain" |
| 176 | + |
| 177 | + def check_after_ops(self, ops: Sequence[OpOverload]): |
| 178 | + """Verify that the fused operator appears in the transformed graph.""" |
| 179 | + for op in ops: |
| 180 | + num_post = len(self.find_nodes_by_target(self.graph_post_pass, op)) |
| 181 | + print(f"Op {op}: post={num_post}") |
| 182 | + assert num_post > 0, f"Op {op} not found in post-pass graph" |
| 183 | + |
| 184 | + |
| 185 | +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) |
| 186 | +@pytest.mark.parametrize("hidden_size", [64]) |
| 187 | +@pytest.mark.parametrize("num_tokens", [257]) |
| 188 | +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) |
| 189 | +def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int, |
| 190 | + num_tokens: int, eps: float): |
| 191 | + """ |
| 192 | + End-to-end test for AddRMSNorm+Quantize fusion. |
| 193 | + Compares: Operator presence/absence before and after graph transformation |
| 194 | + """ |
| 195 | + torch.set_default_dtype(dtype) |
| 196 | + torch.manual_seed(1) |
| 197 | + |
| 198 | + vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype)) |
| 199 | + |
| 200 | + with vllm.config.set_current_vllm_config(vllm_config): |
| 201 | + backend = TestBackend() |
| 202 | + model = TestModel(hidden_size, eps, device="npu") |
| 203 | + model = model.to("npu") |
| 204 | + |
| 205 | + x = torch.rand(num_tokens, |
| 206 | + hidden_size, |
| 207 | + device="npu", |
| 208 | + dtype=dtype, |
| 209 | + requires_grad=False) |
| 210 | + |
| 211 | + result_unfused = model(x) |
| 212 | + print("Unfused result:", [t.shape for t in result_unfused]) |
| 213 | + model_fused = torch.compile(model, backend=backend) |
| 214 | + result_fused = model_fused(x) |
| 215 | + print("Fused result:", [t.shape for t in result_fused]) |
| 216 | + |
| 217 | + print("=== Checking operator fusion ===") |
| 218 | + backend.check_before_ops(model.ops_in_model_before()) |
| 219 | + backend.check_after_ops(model.ops_in_model_after()) |
0 commit comments