
Commit 04878a1

whx-sjtu authored and chencangtao committed
[Ops][Triton] Add a triton kernel supporting partial rope. (vllm-project#4413)
### What this PR does / why we need it?
This PR adds a Triton rope kernel which supports scenarios where `rope_dim != head_dim`. This saves the split op before rope and the concat op after rope. Profiling shows an improvement.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
I will add the related UT after CI is integrated with Triton.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: whx-sjtu <[email protected]>
1 parent 2d9e496 commit 04878a1
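For context on the "split op before rope and the concat op after rope" mentioned above: with a partial rotary embedding (rope_dim != head_dim), the conventional approach slices each head into a rotary part and a pass-through part, rotates only the leading rope_dim features, and concatenates the untouched tail back on. The sketch below is a minimal illustration of that unfused pattern, not code from this PR; the unfused_partial_rope helper and its shapes are assumptions for illustration only.

import torch


def unfused_partial_rope(q: torch.Tensor, cos: torch.Tensor,
                         sin: torch.Tensor, rope_dim: int) -> torch.Tensor:
    """Illustrative only. q: [num_tokens, num_heads, head_dim]; cos/sin: [num_tokens, rope_dim]."""
    # Split op before rope: only the leading rope_dim features get rotated.
    q_rot, q_pass = q[..., :rope_dim], q[..., rope_dim:]
    # Neox-style rotate-half on the rotary slice.
    x1, x2 = q_rot[..., :rope_dim // 2], q_rot[..., rope_dim // 2:]
    rotated = torch.cat((-x2, x1), dim=-1)
    # Broadcast cos/sin across heads and apply the rotation.
    q_rot = q_rot * cos.unsqueeze(-2) + rotated * sin.unsqueeze(-2)
    # Concat op after rope: reattach the untouched tail.
    return torch.cat((q_rot, q_pass), dim=-1)

A fused kernel can instead read the full [num_tokens, num_heads, head_dim] tensor and rotate only the first rope_dim features in place, which is exactly the split/concat traffic the PR description says it saves.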

File tree

6 files changed: +423 −22 lines


tests/e2e/nightly/ops/triton/__init__.py

Whitespace-only changes.
New test file: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
import gc

import pytest
import torch

from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.bfloat16, torch.float16]
HEAD_SIZES = [64, 128]
ROTARY_DIMS = [32, 64]
NUM_Q_HEADS = [64]
NUM_K_HEADS = [1]
NUM_TOKENS = [1, 4, 8, 16, 1024]
SEEDS = [0]
DEVICES = [f"npu:{0}"]
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


def _rope_pytorch_native(
        query, key, cos, sin, rope_dim,
        is_neox_style) -> tuple[torch.Tensor, torch.Tensor | None]:
    """PyTorch-native implementation equivalent to forward()."""
    assert key is not None
    orig_dtype = query.dtype
    query_rot = query[..., :rope_dim].to(torch.float32)
    key_rot = key[..., :rope_dim].to(torch.float32)
    head_size = query.shape[-1]
    if rope_dim < head_size:
        query_pass = query[..., rope_dim:]
        key_pass = key[..., rope_dim:]

    if is_neox_style:
        cos = cos.repeat(1, 2).unsqueeze(-2).to(torch.float32)
        sin = sin.repeat(1, 2).unsqueeze(-2).to(torch.float32)
    else:
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2).to(torch.float32)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2).to(torch.float32)

    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
    query_rot = query_rot * cos + rotate_fn(query_rot) * sin
    key_rot = key_rot * cos + rotate_fn(key_rot) * sin

    if rope_dim < head_size:
        query = torch.cat((query_rot.to(orig_dtype), query_pass), dim=-1)
        key = torch.cat((key_rot.to(orig_dtype), key_pass), dim=-1)
    else:
        query = query_rot.to(orig_dtype)
        key = key_rot.to(orig_dtype)
    return query, key


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads", NUM_Q_HEADS)
@pytest.mark.parametrize("num_k_heads", NUM_K_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_triton_kernel(
    is_neox_style: bool,
    num_tokens: int,
    num_q_heads: int,
    num_k_heads: int,
    head_size: int,
    rotary_dim: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()
    if rotary_dim == -1:
        rotary_dim = head_size
    sin = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
    cos = torch.randn(num_tokens, rotary_dim // 2, dtype=dtype, device=device)
    q_trt = torch.randn(num_tokens,
                        num_q_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    k_trt = torch.randn(num_tokens,
                        num_k_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    q_gold = torch.randn(num_tokens,
                         num_q_heads,
                         head_size,
                         dtype=dtype,
                         device=device)
    k_gold = torch.randn(num_tokens,
                         num_k_heads,
                         head_size,
                         dtype=dtype,
                         device=device)
    q_trt.copy_(q_gold)
    k_trt.copy_(k_gold)
    q_trt, k_trt = rope_forward_triton(q_trt,
                                       k_trt,
                                       cos,
                                       sin,
                                       rope_dim=rotary_dim,
                                       is_neox_style=is_neox_style)
    q_gold, k_gold = _rope_pytorch_native(q_gold,
                                          k_gold,
                                          cos,
                                          sin,
                                          rope_dim=rotary_dim,
                                          is_neox_style=is_neox_style)
    # Compare the results.
    torch.testing.assert_close(q_trt.view(q_gold.size()),
                               q_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    torch.testing.assert_close(k_trt.view(k_gold.size()),
                               k_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
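One detail worth calling out in the reference implementation above: the half-width cos/sin tables (shape [num_tokens, rotary_dim // 2]) are expanded differently for the two layouts. A small standalone illustration of that expansion (the values are arbitrary):

import torch

cos = torch.tensor([[0.1, 0.2]])  # [num_tokens=1, rotary_dim // 2 = 2]

# Neox style: the two halves of the rotary slice reuse the table back-to-back.
print(cos.repeat(1, 2))                  # tensor([[0.1000, 0.2000, 0.1000, 0.2000]])

# GPT-J style: adjacent even/odd feature pairs share the same table entry.
print(cos.repeat_interleave(2, dim=-1))  # tensor([[0.1000, 0.1000, 0.2000, 0.2000]])

This matches the rotate_neox (rotate-half) and rotate_gptj (rotate-pairs) helpers that the reference multiplies against.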

vllm_ascend/attention/sfa_v1.py

Lines changed: 39 additions & 22 deletions
@@ -9,13 +9,15 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.triton_utils import HAS_TRITON
 from vllm.v1.attention.backends.utils import AttentionCGSupport

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          wait_for_kv_layer_from_connector)
+from vllm_ascend.ops.triton.rope import rope_forward_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                is_enable_nz)
@@ -492,35 +494,50 @@ def indexer_select(
         cos = attn_metadata.cos
         sin = attn_metadata.sin

-        cos_q, sin_q = cos, sin
-        cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
-        sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
-
         # q process in new stream
         q, _ = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
-        q = q.view(-1, self.n_head, self.head_dim)  # [b,s,64,128]
-        q_pe, q_nope = torch.split(
-            q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
-            dim=-1)  # [b,s,64,64+64]
-
-        q_pe = q_pe.unsqueeze(2)
-        q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
-        q_pe = q_pe.squeeze(2)
-        q = torch.cat([q_pe, q_nope], dim=-1)  # [b*s,64,128]
+        q = q.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]

         k_proj, _ = self.wk(x)  # [b,s,7168] @ [7168,128] = [b,s,128]
         k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             k_proj, need_gather_q_kv)
         k = self.k_norm(k_proj).unsqueeze(1)
-        k_pe, k_nope = torch.split(
-            k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
-            dim=-1)  # [b,s,64+64]
-
-        k_pe = k_pe.unsqueeze(2)
-        k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
-        k_pe = k_pe.squeeze(2)
-
-        k = torch.cat([k_pe, k_nope], dim=-1)  # [b*s,128]
+        k = k.view(-1, 1, self.head_dim)
+
+        if HAS_TRITON:
+            cos = cos.view(-1, self.qk_rope_head_dim)
+            sin = sin.view(-1, self.qk_rope_head_dim)
+            q, k = rope_forward_triton(q,
+                                       k,
+                                       cos,
+                                       sin,
+                                       rope_dim=self.qk_rope_head_dim,
+                                       is_neox_style=True)
+        else:
+            cos_q, sin_q = cos, sin
+            cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
+            sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
+
+            q_pe, q_nope = torch.split(
+                q,
+                [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
+                dim=-1)  # [b,s,64,64+64]
+
+            q_pe = q_pe.unsqueeze(2)
+            q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
+            q_pe = q_pe.squeeze(2)
+            q = torch.cat([q_pe, q_nope], dim=-1)  # [b*s,64,128]
+
+            k_pe, k_nope = torch.split(
+                k,
+                [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
+                dim=-1)  # [b,s,64+64]
+
+            k_pe = k_pe.unsqueeze(2)
+            k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
+            k_pe = k_pe.squeeze(2)
+
+            k = torch.cat([k_pe, k_nope], dim=-1)  # [b*s,128]

         if kv_cache is not None:
             torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
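To make the new HAS_TRITON branch above easier to follow, here is a hedged, illustrative sketch of the tensor shapes it hands to rope_forward_triton, using the dimensions from the inline comments (64 query heads, head_dim 128, qk_rope_head_dim 64). The dummy tensors and the "npu" device string are assumptions for illustration, not the real activations, and the sketch presupposes an Ascend environment where HAS_TRITON is true.

import torch
import torch_npu  # noqa: F401  # assumption: registers the "npu" device

from vllm_ascend.ops.triton.rope import rope_forward_triton

# Dummy shapes only; values taken from the inline comments in the diff above.
n_toks, n_head, head_dim, rope_dim = 8, 64, 128, 64
q = torch.randn(n_toks, n_head, head_dim, device="npu")  # mirrors q.view(-1, self.n_head, self.head_dim)
k = torch.randn(n_toks, 1, head_dim, device="npu")       # mirrors k.view(-1, 1, self.head_dim)
cos = torch.randn(n_toks, rope_dim, device="npu")        # mirrors cos.view(-1, self.qk_rope_head_dim)
sin = torch.randn(n_toks, rope_dim, device="npu")
q, k = rope_forward_triton(q, k, cos, sin,
                           rope_dim=rope_dim, is_neox_style=True)

Only the first rope_dim features of each head are rotated; the remaining head_dim - rope_dim features pass through, so the branch needs no split before or concat after the call, unlike the npu_interleave_rope fallback below it.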
