|
25 | 25 | DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, |
26 | 26 | YaRNScalingRotaryEmbedding) |
27 | 27 | from vllm.platforms import CpuArchEnum |
28 | | -from vllm.triton_utils import tl, triton |
29 | 28 |
|
30 | 29 | from vllm_ascend.platform import NPUPlatform |
31 | 30 | from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, |
32 | 31 | get_ascend_device_type) |
33 | 32 |
|
34 | 33 |
|
@triton.jit
def _triton_rope(
    q_ptr,
    q_row_stride,
    k_ptr,
    k_row_stride,
    cos,
    cos_row_stride,
    sin,
    sin_row_stride,
    num_tokens,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    rope_dim: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_rope_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    IS_NEOX_STYLE: tl.constexpr,
):
    """
    Apply rotary embedding to q and k in place.

    It supports the rope_dim != head_dim scenario, and both neox style and
    non-neox style rope computation.

    Each program instance processes token rows ``pid``,
    ``pid + num_programs``, ``pid + 2 * num_programs``, ... until
    ``num_tokens`` is reached, and writes the rotated values back through
    ``q_ptr`` / ``k_ptr``.

    Input tensor layout assumptions:

    q size: (num_tokens, num_q_heads, head_dim)
    q stride: (num_q_heads * head_dim, head_dim, 1)
    k size: (num_tokens, num_kv_heads, head_dim)
    k stride: (num_kv_heads * head_dim, head_dim, 1)
    cos/sin size: (num_tokens, rope_dim/2)
    cos/sin stride: (rope_dim/2, 1)

    The ``pad_*`` arguments are the sizes actually passed to ``tl.arange``;
    lanes beyond ``n_qh`` / ``n_kh`` / ``rope_dim // 2`` are masked out on
    every load and store.

    ``BLOCK_SIZE`` is not referenced in the kernel body; it is kept so that
    existing launch sites keep working.

    Different compute pattern of IS_NEOX_STYLE:

        if IS_NEOX_STYLE:
            x1, x2 = torch.chunk(x, 2, dim=-1)
        else:
            x1 = x[..., ::2]
            x2 = x[..., 1::2]
        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin
        if IS_NEOX_STYLE:
            return torch.cat((o1, o2), dim=-1)
        else:
            return torch.stack((o1, o2), dim=-1).flatten(-2)
    """
    pid = tl.program_id(0).to(tl.int64)
    row_block_size = tl.num_programs(0)

    # ####################################################################
    # Everything below depends only on compile-time constants, not on the
    # token row, so it is computed once instead of once per loop iteration.
    # ####################################################################
    half_offsets = tl.arange(0, pad_rope_dim // 2)
    half_mask = half_offsets < (rope_dim // 2)
    q_heads = tl.arange(0, pad_n_qh)
    k_heads = tl.arange(0, pad_n_kh)

    if IS_NEOX_STYLE:
        # x1 is the first rope_dim/2 elements of each head, x2 the next
        # rope_dim/2 elements.
        first_half_q_offsets = q_heads[:, None] * hd + half_offsets[None, :]
        first_half_k_offsets = k_heads[:, None] * hd + half_offsets[None, :]
        second_half_q_offsets = first_half_q_offsets + (rope_dim // 2)
        second_half_k_offsets = first_half_k_offsets + (rope_dim // 2)
    else:
        # x1 is the even-indexed elements of each head, x2 the odd-indexed.
        first_half_q_offsets = q_heads[:, None] * hd + (
            2 * half_offsets[None, :])
        first_half_k_offsets = k_heads[:, None] * hd + (
            2 * half_offsets[None, :])
        second_half_q_offsets = first_half_q_offsets + 1
        second_half_k_offsets = first_half_k_offsets + 1

    # The same mask guards both halves: a lane is valid when its head index
    # and its position inside the half are both real (non-padding).
    q_mask = (q_heads[:, None] < n_qh) & half_mask[None, :]
    k_mask = (k_heads[:, None] < n_kh) & half_mask[None, :]

    for row_idx in tl.range(pid, num_tokens, row_block_size):
        q_start_ptr = q_ptr + row_idx * q_row_stride
        k_start_ptr = k_ptr + row_idx * k_row_stride

        # ################################################################
        # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token
        # position m handled in this iteration; the math is done in fp32
        # ################################################################
        cos_start_ptr = cos + row_idx * cos_row_stride
        sin_start_ptr = sin + row_idx * sin_row_stride
        cos_row = tl.load(cos_start_ptr + half_offsets,
                          mask=half_mask,
                          other=0).to(tl.float32)
        sin_row = tl.load(sin_start_ptr + half_offsets,
                          mask=half_mask,
                          other=0).to(tl.float32)

        # ################################################################
        # Load both halves of q and k for the current token before any
        # store, so the in-place writes below never feed back into inputs
        # ################################################################
        q_tile_1 = tl.load(q_start_ptr + first_half_q_offsets,
                           mask=q_mask,
                           other=0).to(sin_row.dtype)
        k_tile_1 = tl.load(k_start_ptr + first_half_k_offsets,
                           mask=k_mask,
                           other=0).to(sin_row.dtype)
        q_tile_2 = tl.load(q_start_ptr + second_half_q_offsets,
                           mask=q_mask,
                           other=0).to(sin_row.dtype)
        k_tile_2 = tl.load(k_start_ptr + second_half_k_offsets,
                           mask=k_mask,
                           other=0).to(sin_row.dtype)

        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_start_ptr + first_half_q_offsets, new_q_tile_1, mask=q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_start_ptr + second_half_q_offsets, new_q_tile_2, mask=q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_start_ptr + first_half_k_offsets, new_k_tile_1, mask=k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_start_ptr + second_half_k_offsets, new_k_tile_2, mask=k_mask)
158 | | - |
159 | | - |
def rope_forward_triton(q,
                        k,
                        cos,
                        sin,
                        rope_dim: int = -1,
                        is_neox_style: bool = True):
    """Launch the ``_triton_rope`` kernel on q and k and return them.

    Args:
        q: 3-D tensor unpacked as (num_tokens, num_q_heads, head_dim).
        k: tensor whose dim 1 is the number of kv heads.
        cos: cosine table; viewed as (num_tokens, -1) before launch.
        sin: sine table; viewed as (num_tokens, -1) before launch.
        rope_dim: number of leading elements of each head to rotate. With
            the default ``-1`` it is derived as ``cos.shape[-1] * 2``,
            i.e. cos/sin are assumed NOT to be duplicated to rope_dim.
        is_neox_style: forwarded to the kernel as ``IS_NEOX_STYLE``.

    Returns:
        The ``(q, k)`` pair the kernel wrote into. If an input was not
        contiguous, a contiguous copy is rotated and returned instead, so
        callers must use the returned tensors rather than rely on their
        arguments being updated.

    Raises:
        AssertionError: if ``rope_dim`` exceeds ``head_dim``.
    """
    # The kernel addresses q/k with a single row stride and assumes the
    # heads of one token are densely packed, hence the contiguity.
    if not q.is_contiguous():
        q = q.contiguous()
    if not k.is_contiguous():
        k = k.contiguous()

    num_tokens, n_q_head, head_dim = q.shape
    if num_tokens == 0:
        # Nothing to rotate. Returning here also avoids `view(0, -1)`,
        # which is ambiguous for an empty tensor, and a zero-sized grid.
        return q, k

    n_kv_head = k.shape[1]
    cos = cos.view(num_tokens, -1)
    sin = sin.view(num_tokens, -1)
    if rope_dim == -1:
        # If rope_dim is not specified, we assume that input cos/sin is not
        # duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2
        rope_dim = cos.shape[-1] * 2
    assert rope_dim <= head_dim, (
        f"rope_dim ({rope_dim}) must not exceed head_dim ({head_dim})")

    # tl.arange needs power-of-two extents; the kernel masks the padding.
    pad_rope_dim = triton.next_power_of_2(rope_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
    # Launch at most NUM_VECTORCORE programs; each one loops over the token
    # rows assigned to it.
    # NOTE(review): NUM_VECTORCORE is not defined in the visible part of the
    # file - confirm it is provided at module level.
    n_row = min(num_tokens, NUM_VECTORCORE)

    _triton_rope[(n_row, )](
        q,
        q.stride(0),
        k,
        k.stride(0),
        cos,
        cos.stride(0),
        sin,
        sin.stride(0),
        num_tokens,
        n_q_head,
        n_kv_head,
        head_dim,
        rope_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_rope_dim,
        BLOCK_SIZE=BLOCK_SIZE,
        IS_NEOX_STYLE=is_neox_style,
    )
    return q, k
207 | | - |
208 | | - |
209 | 34 | def _custom_rotary_embedding_enabled(query, neox_style, head_size): |
210 | 35 | return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op( |
211 | 36 | ) |
|
0 commit comments