Commit 28cba2d

zhoux77899 authored and Angazenn committed
[v0.11.0][Feat] Prefetching Attention QKV Linear Weight With AddRmsNormQuant Custom Op (vllm-project#3649)
### What this PR does / why we need it?

- `qkv_proj.weight` prefetching has been implemented with the `Quant` op; when `AddRmsNormQuant` is enabled (vllm-project#3465), `qkv_proj.weight` prefetching does not take effect.
- This PR implements `qkv_proj.weight` prefetching with `AddRmsNormQuant`, which has been merged on the `main` branch (vllm-project#3517).

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

Tested on `Qwen3-235B-A22B-W8A8` (screenshot: https://github.com/user-attachments/assets/0bc28082-0287-4d5c-b8f6-f907c3134d36)

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: zhoux77899 <[email protected]>
1 parent f645a5b commit 28cba2d
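
In essence, the change brackets the fused add-rms-norm-quant kernel with a prefetch start/stop pair, so `qkv_proj.weight` is already being pulled toward cache while the norm and quantization run and is resident by the time the QKV linear consumes it. Below is a minimal sketch of that bracketing pattern, with a hypothetical `prefetcher` object standing in for the real `weight_prefetch_method`; only the two hook names and their keyword arguments are taken from the diff, while the wrapper function and the plain residual add are illustrative:

```python
def add_rms_norm_quant_with_prefetch(x, residual, layer, prefetcher=None):
    """Illustrative bracketing only; not the vllm-ascend implementation."""
    # 1. Start prefetching the attention projection weight as soon as the
    #    norm input is available, keyed on the owning layer's class name.
    if prefetcher is not None:
        prefetcher.maybe_prefetch_attn_weight_preprocess(
            layer_cls_name=layer.__class__.__name__,
            weight=layer.weight,
            start_flag=x,
        )

    # 2. The fused AddRmsNormQuant kernel runs here in the real code
    #    (torch_npu.npu_add_rms_norm_quant); a plain residual add stands in
    #    for it so this sketch stays hardware-agnostic.
    x = x + residual

    # 3. Close the prefetch window once the kernel output exists, so the
    #    subsequent QKV linear finds its weight already prefetched.
    if prefetcher is not None:
        prefetcher.maybe_prefetch_attn_weight_postprocess(
            layer_cls_name=layer.__class__.__name__,
            stop_flag=x,
        )
    return x, residual
```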

File tree

2 files changed (+29, -6 lines)

tests/ut/ops/test_layernorm.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -117,6 +117,7 @@ def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
         mock_forward_context.layer_idx = 0
         mock_forward_context.num_hidden_layers = num_hidden_layers
         mock_forward_context.fusion_linear = "gate_up_dense"
+        mock_forward_context.weight_prefetch_method = None

         # Ensure fusion and layer_idx increment are handled correctly
         x = torch.randn(4, 8, dtype=torch.float16)
@@ -125,28 +126,28 @@ def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):

         x_out, residual_out = layer.forward_oot(x, residual)

-        assert mock_get_forward_context.call_count == 1
+        assert mock_get_forward_context.call_count == 2
         assert mock_forward_context.fusion_linear == "qkv_dense"
         assert mock_forward_context.layer_idx == 1

         x_out, residual_out = layer.forward_oot(x, residual)

-        assert mock_get_forward_context.call_count == 2
+        assert mock_get_forward_context.call_count == 4
         assert mock_forward_context.fusion_linear == "gate_up_dense"
         assert mock_forward_context.layer_idx == 1

         if torch_npu_check:
             mock_forward_context.fusion_linear = "gate_moe"
         x_out, residual_out = layer.forward_oot(x, residual)

-        assert mock_get_forward_context.call_count == 3
+        assert mock_get_forward_context.call_count == 6
         fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense"
         assert mock_forward_context.fusion_linear == fusion_linear_expected
         assert mock_forward_context.layer_idx == 2

         x_out, residual_out = layer.forward_oot(x, residual)

-        assert mock_get_forward_context.call_count == 4
+        assert mock_get_forward_context.call_count == 7
         fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense"
         assert mock_forward_context.fusion_linear == fusion_linear_expected
         assert mock_forward_context.layer_idx == 2
@@ -156,13 +157,13 @@ def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
         # last layer returned directly
         x_out, residual_out = layer.forward_oot(x, residual)

-        assert mock_get_forward_context.call_count == 5
+        assert mock_get_forward_context.call_count == 8
         assert mock_forward_context.fusion_linear == "qkv_moe"
         assert mock_forward_context.layer_idx == 3

         x_out, residual_out = layer.forward_oot(x, residual)

-        assert mock_get_forward_context.call_count == 6
+        assert mock_get_forward_context.call_count == 9
         assert mock_forward_context.fusion_linear == "qkv_moe"
         assert mock_forward_context.layer_idx == 3
```
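
A note on the expected `call_count` values above: each `forward_oot` call that goes through the quant-fusion path now appears to trigger a second `get_forward_context()` lookup (the new `try` block in `_addrmsnorm_forward_oot` that fetches `weight_prefetch_method`), so those calls count twice, while the calls outside the fusion window keep a single lookup; hence 2, 4, 6 followed by 7, 8, 9 instead of 1 through 6. Setting `mock_forward_context.weight_prefetch_method = None` keeps the new prefetch branch inert in this test; a follow-up test that wanted to exercise the branch could swap in a mock, roughly as sketched here (an illustrative fragment reusing the fixtures defined earlier in this test; not part of this commit):

```python
from unittest import mock

# Replace the None stub with a MagicMock so the prefetch hooks themselves
# can be asserted on after a forward pass.
mock_forward_context.weight_prefetch_method = mock.MagicMock()

x_out, residual_out = layer.forward_oot(x, residual)

prefetcher = mock_forward_context.weight_prefetch_method
prefetcher.maybe_prefetch_attn_weight_preprocess.assert_called()
prefetcher.maybe_prefetch_attn_weight_postprocess.assert_called()
```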

vllm_ascend/ops/layernorm.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -38,6 +38,21 @@ def _addrmsnorm_forward_oot(

     torch_npu_check = version_check()
     if layer is not None and not is_310p():
+        layer_cls_name = layer.__class__.__name__
+        try:
+            weight_prefetch_method = get_forward_context(
+            ).weight_prefetch_method
+        except AssertionError:
+            weight_prefetch_method = None
+
+        # prefetch qkvo_proj.weight preprocess
+        if weight_prefetch_method:
+            weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
+                layer_cls_name=layer_cls_name,
+                weight=layer.weight,
+                start_flag=x,
+            )
+        # add_rms_norm_quant
         if torch_npu_check:
             x, _, residual = torch_npu.npu_add_rms_norm_quant(
                 x,
@@ -55,6 +70,13 @@
                 layer.aclnn_input_scale,
                 layer.aclnn_input_offset,
                 epsilon=self.variance_epsilon)
+        # prefetch qkvo_proj.weight postprocess
+        if weight_prefetch_method:
+            weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
+                layer_cls_name=layer_cls_name,
+                stop_flag=x,
+            )
+
     else:
         if is_310p():
             orig_dtype = residual.dtype
```
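
The diff only shows the consumer side of the prefetch hooks. For readers tracing the new call sites, a shape-compatible stand-in with the same two method names and keyword arguments may help; this is a hypothetical stub (the real `weight_prefetch_method` in vllm-ascend is NPU-specific and more involved), useful mainly for local experiments or unit tests:

```python
import torch


class NoOpWeightPrefetch:
    """Hypothetical stand-in for weight_prefetch_method; it matches only the
    call sites used in _addrmsnorm_forward_oot above."""

    def maybe_prefetch_attn_weight_preprocess(self, *, layer_cls_name: str,
                                              weight: torch.Tensor,
                                              start_flag: torch.Tensor) -> None:
        # The real implementation would start an asynchronous weight prefetch
        # here, gated on the attention layer's class name; the stub only
        # records that the prefetch window was opened.
        self.last_opened = (layer_cls_name, tuple(weight.shape))

    def maybe_prefetch_attn_weight_postprocess(self, *, layer_cls_name: str,
                                               stop_flag: torch.Tensor) -> None:
        # The real implementation would close the prefetch window here.
        self.last_closed = layer_cls_name
```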

0 commit comments