Commit 3bd9c49

shen-shanshan, gcanlin, and tjtanaa authored
[CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic (#29873)
Signed-off-by: shen-shanshan <[email protected]>
Co-authored-by: gcanlin <[email protected]>
Co-authored-by: TJian <[email protected]>
1 parent ff21a0f commit 3bd9c49
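
In outline, the unification means each rotary-embedding layer owns one ApplyRotaryEmb instance and either calls its forward_native() explicitly (the pure-torch path) or goes through the auto-dispatched forward(), which the CustomOp machinery binds to a backend at construction time. The standalone toy below illustrates only that dispatch pattern; it is not vLLM code, and TinyCustomOp and its methods are made up for illustration.

import torch


class TinyCustomOp:
    """Toy stand-in for the CustomOp dispatch pattern (not vLLM's class)."""

    def __init__(self, enabled: bool) -> None:
        # The real CustomOp binds _forward_method once, at construction,
        # based on the platform and the custom_ops compilation config.
        self._forward_method = self.forward_device if enabled else self.forward_native

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Auto-dispatch: the path forward_cuda()/forward_hip() callers rely on.
        return self._forward_method(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-torch reference: the path forward_native() callers pin to.
        return 2 * x

    def forward_device(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for a fused CUDA/HIP kernel.
        return 2 * x


op = TinyCustomOp(enabled=True)
assert torch.equal(op.forward(torch.ones(4)), op.forward_native(torch.ones(4)))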

File tree

14 files changed, +553 -280 lines

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for ApplyRotaryEmb CustomOp dispatch behavior.

This test ensures that RotaryEmbedding classes correctly call the appropriate
ApplyRotaryEmb methods based on the calling context:

1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native()
2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch)
3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch)
"""

from dataclasses import dataclass

import pytest
import torch

from vllm.config import (
    CompilationConfig,
    VllmConfig,
    get_cached_compilation_config,
    set_current_vllm_config,
)
from vllm.platforms import current_platform

CUDA_DEVICES = ["cuda:0"]


@dataclass
class RotaryEmbeddingTestCase:
    """Test case configuration for RotaryEmbedding dispatch tests."""

    name: str
    rope_class: type
    rope_kwargs: dict
    method_name: str  # forward_native, forward_cuda, forward
    positions_shape: tuple  # (num_tokens,) or (3, num_tokens) or (4, num_tokens)
    expect_forward_native: bool  # Should call ApplyRotaryEmb.forward_native()
    expect_forward: bool  # Should call ApplyRotaryEmb.forward()


def get_test_cases() -> list[RotaryEmbeddingTestCase]:
    """Generate test cases for all RotaryEmbedding classes."""
    from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
        Ernie4_5_VLRotaryEmbedding,
    )
    from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
    from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding

    common_kwargs = {
        "head_size": 128,
        "rotary_dim": 128,
        "max_position_embeddings": 4096,
        "base": 10000,
        "is_neox_style": True,
        "dtype": torch.bfloat16,
    }

    return [
        # MRotaryEmbedding tests
        RotaryEmbeddingTestCase(
            name="MRotaryEmbedding.forward_native",
            rope_class=MRotaryEmbedding,
            rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
            method_name="forward_native",
            positions_shape=(3, 32),  # 2D for multimodal
            expect_forward_native=True,
            expect_forward=False,
        ),
        RotaryEmbeddingTestCase(
            name="MRotaryEmbedding.forward_cuda_1d",
            rope_class=MRotaryEmbedding,
            rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
            method_name="forward_cuda",
            positions_shape=(32,),  # 1D triggers apply_rotary_emb path
            expect_forward_native=False,
            expect_forward=True,
        ),
        # XDRotaryEmbedding tests
        RotaryEmbeddingTestCase(
            name="XDRotaryEmbedding.forward",
            rope_class=XDRotaryEmbedding,
            rope_kwargs={
                **common_kwargs,
                "scaling_alpha": 1.0,
                "xdrope_section": [16, 16, 16, 16],
            },
            method_name="forward",
            positions_shape=(4, 32),  # 4D for P/W/H/T
            expect_forward_native=False,
            expect_forward=True,
        ),
        # Ernie4_5_VLRotaryEmbedding tests
        RotaryEmbeddingTestCase(
            name="Ernie4_5_VLRotaryEmbedding.forward_native",
            rope_class=Ernie4_5_VLRotaryEmbedding,
            rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]},
            method_name="forward_native",
            positions_shape=(3, 32),  # 2D for multimodal
            expect_forward_native=True,
            expect_forward=False,
        ),
    ]


def run_dispatch_test(
    test_case: RotaryEmbeddingTestCase,
    device: str,
):
    """Run a dispatch test for a RotaryEmbedding class."""
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"])
    )
    get_cached_compilation_config.cache_clear()

    with set_current_vllm_config(vllm_config):
        rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device)

        apply_rotary_emb = rope.apply_rotary_emb

        # Verify custom op is enabled
        if test_case.expect_forward_native:
            assert (
                apply_rotary_emb._forward_method != apply_rotary_emb.forward_native
            ), "Test setup error: ApplyRotaryEmb custom op should be enabled"

        # Setup call tracking
        call_tracker = {"forward_native_called": False, "forward_called": False}
        original_forward_native = apply_rotary_emb.forward_native
        original_forward = apply_rotary_emb.forward

        def tracked_forward_native(*args, **kwargs):
            call_tracker["forward_native_called"] = True
            return original_forward_native(*args, **kwargs)

        def tracked_forward(*args, **kwargs):
            call_tracker["forward_called"] = True
            return original_forward(*args, **kwargs)

        apply_rotary_emb.forward_native = tracked_forward_native
        apply_rotary_emb.forward = tracked_forward

        try:
            num_tokens = test_case.positions_shape[-1]
            num_q_heads = 8
            num_kv_heads = 2
            head_size = test_case.rope_kwargs["head_size"]
            max_position = test_case.rope_kwargs["max_position_embeddings"]

            positions = torch.randint(
                0, max_position // 4, test_case.positions_shape, device=device
            )
            query = torch.randn(
                num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device
            )
            key = torch.randn(
                num_tokens,
                num_kv_heads * head_size,
                dtype=torch.bfloat16,
                device=device,
            )

            # Call the method under test
            method = getattr(rope, test_case.method_name)
            method(positions, query.clone(), key.clone())

            # Verify expectations
            if test_case.expect_forward_native:
                assert call_tracker["forward_native_called"], (
                    f"{test_case.name} should call ApplyRotaryEmb.forward_native()"
                )
                if not test_case.expect_forward:
                    assert not call_tracker["forward_called"], (
                        f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). "
                        "Bug: when +apply_rotary_emb is enabled, forward_native() "
                        "incorrectly dispatches to CUDA/HIP kernels."
                    )
            if test_case.expect_forward:
                assert call_tracker["forward_called"], (
                    f"{test_case.name} should call ApplyRotaryEmb.forward()"
                )
        finally:
            apply_rotary_emb.forward_native = original_forward_native
            apply_rotary_emb.forward = original_forward


@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
)
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_rotary_embedding_dispatch(
    test_case: RotaryEmbeddingTestCase,
    device: str,
):
    """
    Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method.

    - forward_native methods should call ApplyRotaryEmb.forward_native()
    - forward_cuda/forward methods should call ApplyRotaryEmb.forward()
    """
    run_dispatch_test(test_case, device)
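
For a quick manual check outside pytest, the same configuration plumbing the test relies on can be exercised directly. The snippet below is a sketch, not part of the commit: it assumes a CUDA-capable vLLM install and reuses only names that already appear in the test above.

import torch

from vllm.config import (
    CompilationConfig,
    VllmConfig,
    get_cached_compilation_config,
    set_current_vllm_config,
)
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

# Enable the new op explicitly, exactly as the test does.
cfg = VllmConfig(
    compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"])
)
get_cached_compilation_config.cache_clear()

with set_current_vllm_config(cfg):
    rope = MRotaryEmbedding(
        head_size=128,
        rotary_dim=128,
        max_position_embeddings=4096,
        base=10000,
        is_neox_style=True,
        dtype=torch.bfloat16,
        mrope_section=[16, 24, 24],
    ).to("cuda:0")

    # With the op enabled, auto-dispatch should not be pinned to the
    # pure-torch path.
    print(rope.apply_rotary_emb._forward_method == rope.apply_rotary_emb.forward_native)

With "+apply_rotary_emb" enabled, the comparison should print False: the bound _forward_method is no longer the pure-torch forward_native, which is the same precondition the test asserts before it starts tracking calls.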

vllm/model_executor/layers/rotary_embedding/base.py

Lines changed: 17 additions & 3 deletions
@@ -7,7 +7,7 @@
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.model_executor.custom_op import CustomOp

-from .common import apply_rotary_emb_torch
+from .common import ApplyRotaryEmb


 @CustomOp.register("rotary_embedding")
@@ -49,6 +49,10 @@ def __init__(
             rocm_aiter_ops.is_triton_rotary_embed_enabled()
         )

+        self.apply_rotary_emb = ApplyRotaryEmb(
+            is_neox_style=self.is_neox_style,
+        )
+
     def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
@@ -123,7 +127,12 @@ def forward_static(
         query = query.view(num_tokens, -1, head_size)
         query_rot = query[..., :rotary_dim]
         query_pass = query[..., rotary_dim:]
-        query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
+        query_rot = ApplyRotaryEmb.forward_static(
+            query_rot,
+            cos,
+            sin,
+            is_neox_style,
+        )
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

         # key may be None in some cases, e.g. cross-layer KV sharing
@@ -132,7 +141,12 @@ def forward_static(
         key = key.view(num_tokens, -1, head_size)
         key_rot = key[..., :rotary_dim]
         key_pass = key[..., rotary_dim:]
-        key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
+        key_rot = ApplyRotaryEmb.forward_static(
+            key_rot,
+            cos,
+            sin,
+            is_neox_style,
+        )
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
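
The ApplyRotaryEmb op itself is defined in the .common module, which is among the 14 changed files but is not shown in this view. Going only by how base.py uses it above (an is_neox_style constructor argument, a static forward_static(x, cos, sin, is_neox_style), and a forward_native path tracked by the new test), its interface is roughly the sketch below. The rotation math and the forward_native signature are assumptions for illustration, not the committed code, and the sketch registers under a different name to avoid clashing with the real "apply_rotary_emb" op.

import torch

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("apply_rotary_emb_sketch")  # the real op is "apply_rotary_emb"
class ApplyRotaryEmbSketch(CustomOp):
    """Hypothetical reconstruction of the interface implied by base.py."""

    def __init__(self, is_neox_style: bool = True) -> None:
        super().__init__()
        self.is_neox_style = is_neox_style

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_neox_style: bool,
    ) -> torch.Tensor:
        # Assumed NeoX-style rotate-half; cos/sin are assumed broadcastable
        # to the two halves of x along the last dimension.
        if not is_neox_style:
            raise NotImplementedError("sketch covers only the NeoX layout")
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

    def forward_native(
        self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        # Pure-torch reference path; CUDA/HIP overrides would live alongside it.
        return self.forward_static(x, cos, sin, self.is_neox_style)

Keeping forward_static as a plain staticmethod presumably lets RotaryEmbedding.forward_static above keep computing the rotation without holding an op instance, while the per-layer self.apply_rotary_emb instance gets the usual CustomOp enable/disable flag and platform dispatch.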
