
Commit 2d30141

add triton utils and move to triton dir
Signed-off-by: whx-sjtu <[email protected]>
1 parent 2cfb6b6 commit 2d30141

File tree

5 files changed: +226 -176 lines changed


vllm_ascend/attention/sfa_v1.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.ops.rotary_embedding import rope_forward_triton
+from vllm_ascend.ops.triton.rope import rope_forward_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                is_enable_nz)
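
For orientation, nothing about the call changes here, only the import path: rope_forward_triton now lives under the new vllm_ascend/ops/triton package. A minimal call-site sketch (the q, k, cos, sin tensors are hypothetical placeholders; the signature matches the function shown in the diff below):

    from vllm_ascend.ops.triton.rope import rope_forward_triton

    # q: (num_tokens, num_q_heads, head_dim), k: (num_tokens, num_kv_heads, head_dim)
    # cos/sin: (num_tokens, rope_dim // 2); with rope_dim=-1 the rope width is
    # inferred as cos.shape[-1] * 2. The rotation is applied in place and the
    # (possibly re-made contiguous) tensors are returned.
    q, k = rope_forward_triton(q, k, cos, sin, rope_dim=-1, is_neox_style=True)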

vllm_ascend/ops/rotary_embedding.py

Lines changed: 0 additions & 175 deletions
@@ -25,187 +25,12 @@
     DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
 from vllm.platforms import CpuArchEnum
-from vllm.triton_utils import tl, triton
 
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
                                get_ascend_device_type)
 
 
-@triton.jit
-def _triton_rope(
-    q_ptr,
-    q_row_stride,
-    k_ptr,
-    k_row_stride,
-    cos,
-    cos_row_stride,
-    sin,
-    sin_row_stride,
-    num_tokens,
-    n_qh: tl.constexpr,
-    n_kh: tl.constexpr,
-    hd: tl.constexpr,
-    rope_dim: tl.constexpr,
-    pad_n_qh: tl.constexpr,
-    pad_n_kh: tl.constexpr,
-    pad_rope_dim: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-    IS_NEOX_STYLE: tl.constexpr,
-):
-    """
-    This triton kernel applies rotary embedding on q and k.
-    It supports rope_dim != head_dim scenario.
-    It supports both neox style and non-neox style rope computation.
-
-    Input tensor layout assumptions:
-
-    q size: (num_tokens, num_q_heads, head_dim)
-    q stride: (num_q_heads * head_dim, head_dim, 1)
-    k size: (num_tokens, num_kv_heads, head_dim)
-    k stride: (num_kv_heads * head_dim, head_dim, 1)
-    cos/sin size: (num_tokens, rope_dim/2)
-    cos/sin stride: (rope_dim/2, 1)
-
-    Different compute pattern of IS_NEOX_STYLE:
-
-    if IS_NEOX_STYLE:
-        x1, x2 = torch.chunk(x, 2, dim=-1)
-    else:
-        x1 = x[..., ::2]
-        x2 = x[..., 1::2]
-    o1 = x1 * cos - x2 * sin
-    o2 = x2 * cos + x1 * sin
-    if IS_NEOX_STYLE:
-        return torch.cat((o1, o2), dim=-1)
-    else:
-        return torch.stack((o1, o2), dim=-1).flatten(-2)
-    """
-    pid = tl.program_id(0).to(tl.int64)
-    row_block_size = tl.num_programs(0)
-
-    for row_idx in tl.range(pid, num_tokens, row_block_size):
-        q_start_ptr = q_ptr + row_idx * q_row_stride
-        k_start_ptr = k_ptr + row_idx * k_row_stride
-
-        # ####################################################################
-        # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
-        # m of this program instance
-        # ####################################################################
-        cos_start_ptr = cos + row_idx * cos_row_stride
-        sin_start_ptr = sin + row_idx * sin_row_stride
-
-        cos_offsets = tl.arange(0, pad_rope_dim // 2)
-        cos_mask = cos_offsets < (rope_dim // 2)
-        cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
-        sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32)
-
-        # ####################################################################
-        # Load the left and right half of q and k for the current
-        # program instance (i.e. for the current token) separately
-        # ####################################################################
-        # left half of the head
-        if IS_NEOX_STYLE:
-            first_half_q_offsets = tl.arange(0,
-                                             pad_n_qh)[:, None] * hd + tl.arange(
-                                                 0, pad_rope_dim // 2)[None, :]
-            first_half_k_offsets = tl.arange(0,
-                                             pad_n_kh)[:, None] * hd + tl.arange(
-                                                 0, pad_rope_dim // 2)[None, :]
-        else:
-            first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + (
-                2 * tl.arange(0, pad_rope_dim // 2)[None, :])
-            first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + (
-                2 * tl.arange(0, pad_rope_dim // 2)[None, :])
-
-        first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
-            0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
-        first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
-            0, pad_rope_dim // 2)[None, :] < (rope_dim // 2))
-        q_tile_1 = tl.load(q_start_ptr + first_half_q_offsets,
-                           mask=first_q_mask,
-                           other=0).to(sin_row.dtype)
-        k_tile_1 = tl.load(k_start_ptr + first_half_k_offsets,
-                           mask=first_k_mask,
-                           other=0).to(sin_row.dtype)
-
-        # right half of the head
-        if IS_NEOX_STYLE:
-            second_half_q_offsets = first_half_q_offsets + (rope_dim // 2)
-            second_half_k_offsets = first_half_k_offsets + (rope_dim // 2)
-        else:
-            second_half_q_offsets = first_half_q_offsets + 1
-            second_half_k_offsets = first_half_k_offsets + 1
-        second_q_mask = first_q_mask
-        second_k_mask = first_k_mask
-        q_tile_2 = tl.load(q_start_ptr + second_half_q_offsets,
-                           mask=second_q_mask,
-                           other=0).to(sin_row.dtype)
-        k_tile_2 = tl.load(k_start_ptr + second_half_k_offsets,
-                           mask=second_k_mask,
-                           other=0).to(sin_row.dtype)
-
-        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
-        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
-        tl.store(q_start_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
-        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
-        tl.store(q_start_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
-
-        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
-        tl.store(k_start_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
-        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
-        tl.store(k_start_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
-
-
-def rope_forward_triton(q,
-                        k,
-                        cos,
-                        sin,
-                        rope_dim: int = -1,
-                        is_neox_style: bool = True):
-    if not q.is_contiguous():
-        q = q.contiguous()
-    if not k.is_contiguous():
-        k = k.contiguous()
-
-    num_tokens, n_q_head, head_dim = q.shape
-    n_kv_head = k.shape[1]
-    cos = cos.view(num_tokens, -1)
-    sin = sin.view(num_tokens, -1)
-    if rope_dim == -1:
-        # If rope_dim is not specified, we assume that input cos/sin is not
-        # duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2
-        rope_dim = cos.shape[-1] * 2
-    assert rope_dim <= head_dim
-    pad_rope_dim = triton.next_power_of_2(rope_dim)
-    pad_n_q_head = triton.next_power_of_2(n_q_head)
-    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
-    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
-    n_row = min(num_tokens, NUM_VECTORCORE)
-
-    _triton_rope[(n_row, )](
-        q,
-        q.stride(0),
-        k,
-        k.stride(0),
-        cos,
-        cos.stride(0),
-        sin,
-        sin.stride(0),
-        num_tokens,
-        n_q_head,
-        n_kv_head,
-        head_dim,
-        rope_dim,
-        pad_n_q_head,
-        pad_n_kv_head,
-        pad_rope_dim,
-        BLOCK_SIZE=BLOCK_SIZE,
-        IS_NEOX_STYLE=is_neox_style,
-    )
-    return q, k
-
-
 def _custom_rotary_embedding_enabled(query, neox_style, head_size):
     return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
     )
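
The removed kernel's docstring spells out the rope math it implements. As a cross-check while reviewing the move, here is a minimal eager PyTorch sketch of the same computation (a reviewer's reference, not code from this commit); it assumes the docstring's layouts, with cos/sin holding rope_dim // 2 values per token and only the first rope_dim channels of each head rotated:

    import torch

    def rope_reference(q, k, cos, sin, rope_dim=-1, is_neox_style=True):
        # q: (num_tokens, n_q_heads, head_dim); k: (num_tokens, n_kv_heads, head_dim)
        cos = cos.view(q.shape[0], -1)
        sin = sin.view(q.shape[0], -1)
        if rope_dim == -1:
            rope_dim = cos.shape[-1] * 2  # cos/sin are not duplicated to rope_dim

        def rotate(x):
            rot = x[..., :rope_dim].to(torch.float32)  # channels past rope_dim stay put
            c = cos[:, None, :].to(torch.float32)  # broadcast across the head axis
            s = sin[:, None, :].to(torch.float32)
            if is_neox_style:
                x1, x2 = torch.chunk(rot, 2, dim=-1)  # halves of the rotary slice
            else:
                x1, x2 = rot[..., ::2], rot[..., 1::2]  # interleaved pairs
            o1 = x1 * c - x2 * s
            o2 = x2 * c + x1 * s
            if is_neox_style:
                out = torch.cat((o1, o2), dim=-1)
            else:
                out = torch.stack((o1, o2), dim=-1).flatten(-2)
            x = x.clone()
            x[..., :rope_dim] = out.to(x.dtype)
            return x

        return rotate(q), rotate(k)

Unlike the Triton version, this sketch returns new tensors rather than writing q and k back in place, and it processes the whole batch at once instead of striding tokens across vector cores via tl.range.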
