
Commit 2502e02

Author: weijinqian_v1
Merge remote-tracking branch 'upstream/main' into refactor_attention_main_second

Signed-off-by: weijinqian_v1 <[email protected]>
2 parents fa6b91b + da84eb2

18 files changed: +601 −1217 lines

.github/workflows/vllm_ascend_test_pr_light.yaml

Lines changed: 6 additions & 4 deletions
@@ -135,10 +135,12 @@ jobs:
           pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \
             --ignore tests/ut/torchair/models/test_torchair_deepseek_mtp.py \
             --ignore tests/ut/torchair/models/test_torchair_deepseek_v2.py \
-            --ignore tests/ut/models/test_qwen2_vl.py \
-            --ignore tests/ut/models/test_qwen2_5_vl.py \
-            --ignore tests/ut/models/test_qwen2_5_vl_without_padding.py \
-            --ignore tests/ut/model_loder
+            --ignore tests/ut/model_loader/netloader/test_netloader_elastic.py \
+            --ignore tests/ut/kv_connector/test_remote_prefill_lifecycle.py \
+            --ignore tests/ut/kv_connector/test_remote_decode_lifecycle.py \
+            --ignore tests/ut/kv_connector/test_llmdatadist_connector.py \
+            --ignore tests/ut/ops/test_linear.py \
+            --ignore tests/ut/core/test_scheduler_dynamic_batch.py
 
       - name: Upload coverage to Codecov
         # only upload coverage when commits merged
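
For local debugging, the same subset can be reproduced outside CI. A minimal
sketch via pytest.main, assuming it is run from the repository root (the flag
list simply mirrors the workflow step above):

    # Hypothetical local reproduction of the CI invocation above.
    import pytest

    IGNORED = [
        "tests/ut/torchair/models/test_torchair_deepseek_mtp.py",
        "tests/ut/torchair/models/test_torchair_deepseek_v2.py",
        "tests/ut/model_loader/netloader/test_netloader_elastic.py",
        "tests/ut/kv_connector/test_remote_prefill_lifecycle.py",
        "tests/ut/kv_connector/test_remote_decode_lifecycle.py",
        "tests/ut/kv_connector/test_llmdatadist_connector.py",
        "tests/ut/ops/test_linear.py",
        "tests/ut/core/test_scheduler_dynamic_batch.py",
    ]
    args = ["-sv", "--cov", "--cov-report=xml:unittests-coverage.xml", "tests/ut"]
    args += [f"--ignore={path}" for path in IGNORED]
    pytest.main(args)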
New file
Lines changed: 219 additions & 0 deletions

@@ -0,0 +1,219 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from copy import deepcopy
from typing import Any, Callable, List, Optional, Sequence

import pytest
import torch
import torch.fx as fx
import torch.nn as nn
import torch_npu
import vllm.config
from torch._inductor.decomposition import select_decomp_table
from vllm.compilation.fx_utils import OpOverload
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config

from vllm_ascend.compilation.compiler_interface import compile_fx
from vllm_ascend.compilation.passes.quant_fusion_pass import \
    AddRMSNormQuantFusionPass


class TestModel(nn.Module):
    """
    A minimal test model that simulates the pattern:
        AddRMSNorm → Quantization
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.rms_norm_weight = nn.Parameter(
            torch.randn(hidden_size, device=device))
        self.quant_scale = torch.tensor([1.0], device=device)
        self.quant_offset = torch.tensor([0.0], device=device)

    def forward(self, x):
        """
        Forward pass:
        1. Perform npu_add_rms_norm
        2. Quantize the normalized output to int8
        Returns both the quantized output and the updated residual.
        """
        residual = torch.zeros_like(x)

        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
            x, residual, self.rms_norm_weight, self.eps)

        quantized_output = torch_npu.npu_quantize(norm_output,
                                                  self.quant_scale,
                                                  self.quant_offset,
                                                  torch.qint8, -1, False)

        return quantized_output, new_residual

    def ops_in_model_before(self) -> List[OpOverload]:
        """Return the list of expected operators BEFORE fusion."""
        return [
            torch.ops.npu.npu_add_rms_norm.default,
            torch.ops.npu.npu_quantize.default
        ]

    def ops_in_model_after(self) -> List[OpOverload]:
        """Return the list of expected operators AFTER successful fusion."""
        return [torch.ops.npu.npu_add_rms_norm_quant.default]


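# For clarity, a reference eager sketch of the semantics the fusion pass
# targets. This is an assumption based on the textbook Add-RMSNorm and
# per-tensor int8 quantization definitions, not a bit-exact model of the
# NPU kernels used above:
def add_rms_norm_quant_reference(x, residual, weight, eps, scale, offset):
    added = x + residual  # the "Add" half: residual update
    variance = added.float().pow(2).mean(-1, keepdim=True)
    normed = (added.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    # symmetric round-to-nearest int8 quantization with a per-tensor scale
    quantized = torch.clamp(torch.round(normed / scale + offset), -128, 127)
    return quantized.to(torch.int8), added

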
class TestBackend:
    """
    A custom compilation backend for testing operator fusion passes.
    It applies the AddRMSNormQuantFusionPass during graph compilation and
    records the FX graph before and after the transformation.
    """

    def __init__(self):
        vllm_config = get_current_vllm_config()
        compile_config = vllm_config.compilation_config
        self.custom_passes = [
            AddRMSNormQuantFusionPass(vllm_config=vllm_config)
        ]
        self.inductor_config = compile_config.inductor_compile_config
        self.inductor_config["graph_fusion_manager"] = self.post_pass

        # Placeholders to store FX graphs for verification
        self.graph_pre_pass = None
        self.graph_post_pass = None

    def post_pass(self,
                  graph: fx.Graph,
                  runtime_shape: int | None = None) -> fx.Graph:
        """
        Apply custom graph transformation passes.
        """
        self.graph_pre_pass = deepcopy(graph)
        for pass_ in self.custom_passes:
            pass_(graph)
        self.graph_post_pass = deepcopy(graph)
        return graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None
    ) -> tuple[Optional[Callable], Optional[Any]]:
        """
        Compile the FX graph using vLLM's Ascend compiler interface.
        Wraps the post-pass logic into the inner_compile callback.
        """

        def compile_inner(graph, example_inputs):
            current_pass_manager = compiler_config["graph_fusion_manager"]
            return current_pass_manager(graph, runtime_shape)

        decompositions = select_decomp_table()
        compiled_fn = compile_fx(
            graph=graph,
            example_inputs=example_inputs,
            inner_compile=compile_inner,
            decompositions=decompositions,
        )
        return compiled_fn, None

    def __call__(self, gm: fx.GraphModule, example_inputs: List[Any]):
        """
        Make the backend callable by torch.compile().
        Returns a compiled executable function.
        """
        compiled_fn, _ = self.compile(
            gm,
            example_inputs,
            compiler_config={"graph_fusion_manager": self.post_pass},
            runtime_shape=None,
            key=None,
        )
        return compiled_fn

    def find_nodes_by_target(self, graph: fx.GraphModule,
                             target: OpOverload) -> List[fx.Node]:
        """Helper to find all FX nodes that call a specific operator."""
        return [
            node for node in graph.graph.nodes
            if hasattr(node, 'target') and node.target == target
        ]

    def check_before_ops(self,
                         ops: Sequence[OpOverload],
                         fully_replaced: bool = True):
        """
        Verify that the original (unfused) operators exist before the pass
        and are fully removed afterward (if fully_replaced=True).
        """
        for op in ops:
            num_pre = len(self.find_nodes_by_target(self.graph_pre_pass, op))
            num_post = len(self.find_nodes_by_target(self.graph_post_pass, op))
            print(f"Op {op}: pre={num_pre}, post={num_post}")

            assert num_pre > 0, f"Op {op} not found in pre-pass graph"
            if fully_replaced:
                assert num_post == 0, (
                    f"Unexpected op {op} in post-pass graph: "
                    f"{num_post} nodes remain")

    def check_after_ops(self, ops: Sequence[OpOverload]):
        """Verify that the fused operator appears in the transformed graph."""
        for op in ops:
            num_post = len(self.find_nodes_by_target(self.graph_post_pass, op))
            print(f"Op {op}: post={num_post}")
            assert num_post > 0, f"Op {op} not found in post-pass graph"


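# A minimal CPU-only illustration of the node-matching idea behind
# find_nodes_by_target, sketched on a plain traced function. It assumes
# nothing about the NPU ops under test and only shows how FX nodes are
# matched by their call target:
def _sketch_node_matching():
    gm = fx.symbolic_trace(lambda t: torch.relu(t) + 1)
    relu_nodes = [n for n in gm.graph.nodes if n.target == torch.relu]
    assert len(relu_nodes) == 1

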
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
                              num_tokens: int, eps: float):
    """
    End-to-end test for the AddRMSNorm + Quantize fusion: checks operator
    presence/absence before and after the graph transformation.
    """
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)

    vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))

    with vllm.config.set_current_vllm_config(vllm_config):
        backend = TestBackend()
        model = TestModel(hidden_size, eps, device="npu")
        model = model.to("npu")

        x = torch.rand(num_tokens,
                       hidden_size,
                       device="npu",
                       dtype=dtype,
                       requires_grad=False)

        result_unfused = model(x)
        print("Unfused result:", [t.shape for t in result_unfused])
        model_fused = torch.compile(model, backend=backend)
        result_fused = model_fused(x)
        print("Fused result:", [t.shape for t in result_fused])

        print("=== Checking operator fusion ===")
        backend.check_before_ops(model.ops_in_model_before())
        backend.check_after_ops(model.ops_in_model_after())
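
Beyond checking operator presence, a natural follow-up is verifying that the
fused and unfused runs agree numerically. A hedged sketch (it assumes the
fused graph preserves TestModel's output pair of an int8 tensor plus a
floating-point residual; assert_results_close is a hypothetical helper, not
part of this commit):

    def assert_results_close(result_unfused, result_fused):
        for ref, out in zip(result_unfused, result_fused):
            if ref.dtype == torch.int8:
                # allow a 1-step rounding difference on the quantized output
                assert (ref.int() - out.int()).abs().max() <= 1
            else:
                # loose tolerance for the fp16/bf16 residual
                torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)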

tests/ut/core/test_schedule_config.py

Lines changed: 0 additions & 134 deletions
This file was deleted.
