15 | 15 | # See the License for the specific language governing permissions and |
16 | 16 | # limitations under the License. |
17 | 17 |
18 | | -from typing import List, Optional, Tuple |
| 18 | +from typing import Optional |
19 | 19 |
20 | 20 | import torch |
21 | | -from vllm.model_executor.layers.linear import ColumnParallelLinear |
22 | 21 |
23 | 22 |
24 | 23 | # Implementation of vanilla chunked prefill, should be removed after the kernel is ready for |
@@ -133,177 +132,3 @@ def vanilla_chunked_prefill( |
133 | 132 | head_dim]).to(output.dtype)) |
134 | 133 | output.copy_(attn_output) |
135 | 134 | return attn_output |
136 | | - |
137 | | - |
138 | | -def vanilla_chunked_prefill_mla( |
139 | | - output: torch.Tensor, # (num_tokens, num_heads, v_head_dim) |
140 | | - query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim) |
141 | | -    kv_cache: Tuple[
142 | | -        torch.Tensor, torch.Tensor],  # [nope, rope] (num_blocks, block_size, latent_kv)
143 | | - block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq) |
144 | | - query_lens: torch.Tensor, # (batch_size) |
145 | | - context_lens: torch.Tensor, # (batch_size) |
146 | | -    kv_b_proj: ColumnParallelLinear,  # latent_kv -> num_heads * (nope_dim + v_head_dim)
147 | | - max_query_len: int, |
148 | | - max_context_len: int, |
149 | | - nope_dim: int, |
150 | | - rope_dim: int, |
151 | | - v_head_dim: int, |
152 | | - scale: float, |
153 | | - alibi_slopes: Optional[torch.Tensor], |
154 | | -        causal: bool = True) -> torch.Tensor:
155 | | - batch_size = block_tables.size(0) |
156 | | - assert len(kv_cache) > 1 |
157 | | - assert query_lens.size(0) == batch_size |
158 | | - num_heads = query.size(1) |
159 | | - nope_cache = kv_cache[0] |
160 | | - rope_cache = kv_cache[1] |
161 | | - block_size = nope_cache.size(1) |
162 | | - latent_kv_dim = nope_cache.size(-1) |
163 | | - max_num_blocks_per_seq = block_tables.size(1) |
164 | | - batch_size = query_lens.size(0) |
165 | | - nope_cache = nope_cache.squeeze() |
166 | | -    # gather the cached latent KV and rope keys per sequence via block_tables
167 | | -    # cache_kv_c: [batch_size, max_context_len, latent_kv]
168 | | -    # cache_k_pe: [batch_size, max_context_len, rope_dim]
169 | | - cache_kv_c = nope_cache[block_tables].view( |
170 | | - batch_size, max_num_blocks_per_seq * block_size, |
171 | | - latent_kv_dim)[:, :max_context_len, :] |
172 | | - cache_k_pe = rope_cache[block_tables].view( |
173 | | - batch_size, max_num_blocks_per_seq * block_size, |
174 | | - rope_dim)[:, :max_context_len, :] |
175 | | -    # up-project kv_c to get k_nope and value
176 | | - # k_nope: [batch_size, max_context_len, num_heads, nope_dim] |
177 | | - # value: [batch_size, max_context_len, num_heads, v_head_dim] |
178 | | - k_nope, value = kv_b_proj(cache_kv_c)[0].view( |
179 | | - batch_size, max_context_len, num_heads, |
180 | | - nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1) |
181 | | -    # key: [batch_size, max_context_len, num_heads, rope_dim + nope_dim]
182 | | - key = torch.cat( |
183 | | - [k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)], |
184 | | - dim=-1) |
185 | | - |
186 | | - context_lens = context_lens.view(-1, 1).to("npu") |
187 | | - query_lens = query_lens.view(-1, 1).to("npu") |
188 | | - seq_diff = context_lens - query_lens |
189 | | - |
190 | | - q_idx_mask = (torch.arange(0, max_query_len, |
191 | | - device="npu").view(1, -1).repeat(batch_size, 1)) |
192 | | - kv_c_idx_mask = (torch.arange(0, max_context_len, |
193 | | - device="npu").view(1, |
194 | | - -1).repeat(batch_size, 1)) |
195 | | - kv_c_mask = kv_c_idx_mask < context_lens |
196 | | - q_mask = q_idx_mask < query_lens |
197 | | - |
198 | | - # calculate idx for causal mask of query [batch, max_seqlen_q] |
199 | | - causal_mask_idx = (q_idx_mask + seq_diff)[q_mask] |
200 | | - |
201 | | - # generate causal mask [batch, max_seqlen_q, max_seqlen_k] |
202 | | - tril_mask = torch.tril( |
203 | | - torch.ones(max_context_len, max_context_len, device="npu")) |
204 | | - tril_mask[tril_mask == 0] = float("-inf") |
205 | | - tril_mask[tril_mask == 1] = 0 |
206 | | - causal_mask = tril_mask[causal_mask_idx] |
207 | | - causal_mask_padding = torch.empty( |
208 | | - [batch_size, max_query_len, max_context_len], |
209 | | - device="npu").fill_(float("-inf")) |
210 | | - causal_mask_padding[q_mask] = causal_mask |
211 | | - # to [batch, num_heads, max_seqlen_q, max_seqlen_k] |
212 | | - causal_mask_padding = causal_mask_padding.unsqueeze(1) |
213 | | - |
214 | | - pad_q = torch.zeros( |
215 | | - [batch_size, max_query_len, num_heads, rope_dim + nope_dim], |
216 | | - device="npu", |
217 | | - dtype=query.dtype, |
218 | | - ) |
219 | | - pad_k = torch.zeros( |
220 | | - [batch_size, max_context_len, num_heads, rope_dim + nope_dim], |
221 | | - device="npu", |
222 | | - dtype=key.dtype, |
223 | | - ) |
224 | | - pad_v = torch.zeros( |
225 | | - [batch_size, max_context_len, num_heads, v_head_dim], |
226 | | - device="npu", |
227 | | - dtype=value.dtype, |
228 | | - ) |
229 | | - num_query = torch.sum(q_mask).item() |
230 | | - num_add_query = num_query - query.size(0) |
231 | | -    # with MTP, q_mask may cover more slots than there are query tokens, so zero-pad
232 | | - if num_add_query > 0: |
233 | | - add_query_size = query.size() |
234 | | - add_query_size = list(add_query_size) |
235 | | - add_query_size[0] = num_add_query |
236 | | - pad_tensor = torch.zeros(add_query_size, |
237 | | - dtype=query.dtype, |
238 | | - device=query.device) |
239 | | - query = torch.cat([query, pad_tensor], dim=0) |
240 | | - pad_q[q_mask] = query |
241 | | - pad_k[kv_c_mask] = key[kv_c_mask] |
242 | | - pad_v[kv_c_mask] = value[kv_c_mask] |
243 | | - |
244 | | - pad_q = pad_q.permute(0, 2, 1, 3) |
245 | | - pad_k = pad_k.permute(0, 2, 1, 3) |
246 | | - pad_v = pad_v.permute(0, 2, 1, 3) |
247 | | - attn_mask = torch.empty([batch_size, 1, 1, max_context_len], |
248 | | - device="npu").fill_(float("-inf")) |
249 | | - attn_mask[:, :, :, :max_context_len].masked_fill_( |
250 | | - kv_c_mask[:, None, None, :], 0) |
251 | | -    # attn_weights: [batch, num_heads, max_query_len, max_context_len]
252 | | - attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k) |
253 | | - attn_weights *= scale |
254 | | - attn_mask = attn_mask.float() |
255 | | - attn_weights = attn_weights + attn_mask |
256 | | - if causal: |
257 | | - attn_weights = attn_weights + causal_mask_padding |
258 | | - |
259 | | - attn_weights = torch.softmax(attn_weights, dim=-1) |
260 | | - attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float()) |
261 | | - attn_output = attn_output.permute(0, 2, 1, 3) |
262 | | - |
263 | | - attn_output = (attn_output[q_mask].view([-1, num_heads, |
264 | | - v_head_dim]).to(output.dtype)) |
265 | | - attn_output = attn_output.view_as(output) |
266 | | - output.copy_(attn_output) |
267 | | - return attn_output |
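The up-projection step in the removed `vanilla_chunked_prefill_mla` is the heart of MLA prefill: the cache stores one compressed latent per token, and `kv_b_proj` expands it into per-head `k_nope` and `value` on the fly (vLLM's `ColumnParallelLinear` returns an `(output, bias)` tuple, hence the `[0]` in the code above). A minimal sketch of that step, using a plain `nn.Linear` as a single-device stand-in and hypothetical sizes:

```python
import torch

# Hypothetical sizes, for illustration only.
batch, ctx_len, num_heads = 2, 16, 4
latent_kv, nope_dim, v_head_dim = 512, 128, 128

# Stand-in for kv_b_proj; a plain Linear computes the same projection
# that ColumnParallelLinear shards across tensor-parallel ranks.
kv_b_proj = torch.nn.Linear(latent_kv, num_heads * (nope_dim + v_head_dim),
                            bias=False)

cache_kv_c = torch.randn(batch, ctx_len, latent_kv)

# Expand the latent KV, then split each head's slice into the
# position-independent key part and the value.
k_nope, value = kv_b_proj(cache_kv_c).view(
    batch, ctx_len, num_heads,
    nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1)

assert k_nope.shape == (batch, ctx_len, num_heads, nope_dim)
assert value.shape == (batch, ctx_len, num_heads, v_head_dim)
```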
268 | | - |
269 | | - |
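The mask logic in the function above has one subtlety worth spelling out: during chunked prefill, query position `i` of a chunk actually sits at global position `i + (context_len - query_len)`, so the row selected from the triangular mask must be shifted by that difference. A standalone sketch of the same construction, with made-up lengths:

```python
import torch

# Made-up lengths, for illustration only.
batch_size, max_query_len, max_context_len = 2, 3, 6
query_lens = torch.tensor([[3], [2]])
context_lens = torch.tensor([[5], [6]])
seq_diff = context_lens - query_lens  # global offset of each query chunk

q_idx_mask = torch.arange(max_query_len).view(1, -1).repeat(batch_size, 1)
q_mask = q_idx_mask < query_lens  # valid (non-padding) query slots

# Lower-triangular additive mask: row i lets a token attend to keys 0..i.
tril_mask = torch.tril(torch.ones(max_context_len, max_context_len))
tril_mask[tril_mask == 0] = float("-inf")
tril_mask[tril_mask == 1] = 0

# Pick, for every valid query, the mask row of its *global* position.
causal_mask = tril_mask[(q_idx_mask + seq_diff)[q_mask]]

# Scatter back into a padded [batch, max_query_len, max_context_len] mask.
causal_mask_padding = torch.full(
    (batch_size, max_query_len, max_context_len), float("-inf"))
causal_mask_padding[q_mask] = causal_mask
```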
270 | | -def vanilla_decode_mla( |
271 | | - query: torch.Tensor, # [num_tokens, num_heads, latent_dim + rope_dim] |
272 | | -    key_cache: torch.Tensor,
273 | | -    # [num_blocks, block_size, num_kv_heads, latent_dim + rope_dim]
274 | | - num_kv_heads: int, |
275 | | - num_heads: int, |
276 | | - scale: float, |
277 | | - block_table: torch.Tensor, # [batch_size, max_block_size] |
278 | | - context_lens: List[int], |
279 | | - mla_vhead_size: int, |
280 | | - rope_dim: int, |
281 | | - output: torch.Tensor): |
282 | | - batch_size = block_table.size()[0] |
283 | | - max_block_size = block_table.size()[1] |
284 | | - reduce_dim = key_cache.size()[-1] |
285 | | - block_size = key_cache.size()[1] |
286 | | - latent_dim = reduce_dim - rope_dim |
287 | | - kv_c_and_pe = key_cache[block_table].view( |
288 | | - [batch_size, max_block_size * block_size, num_kv_heads, reduce_dim]) |
289 | | - max_context_len = max(context_lens) |
290 | | - context_lens = torch.tensor(context_lens, device="npu").view(batch_size, 1) |
291 | | -    # [batch_size, max_context_len, num_heads, latent_dim + rope_dim]
292 | | -    # since DeepSeek uses a single kv head, we use expand (a view) here for perf
293 | | -    kv_c_and_pe = kv_c_and_pe[:, :max_context_len, :, :].expand(
294 | | -        -1, -1, num_heads, -1)
295 | | - kv_c = kv_c_and_pe[..., :latent_dim] |
296 | | - kv_idx_mask = (torch.arange(0, max_context_len, |
297 | | - device="npu").view(1, |
298 | | - -1).repeat(batch_size, 1)) |
299 | | - # [batch_size, max_context_len] |
300 | | - kv_idx_mask = kv_idx_mask < context_lens |
301 | | - query = query.unsqueeze(1) |
302 | | - attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe) |
303 | | - attn_weights *= scale |
304 | | -    attn_weights.masked_fill_(~kv_idx_mask[:, None, None, :], float("-inf"))
305 | | - attn_weights = torch.softmax(attn_weights, dim=-1) |
306 | | - attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights, |
307 | | - kv_c.float()).view(-1, num_heads, latent_dim) |
308 | | - output.copy_(attn_output) |
309 | | - return output |
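Decode-side MLA reuses the same paged-cache trick: fancy-indexing `key_cache` with the block table gathers every sequence's blocks in one shot, and because attention is computed directly against the compressed latents, the output also lives in the latent space. A self-contained sketch under hypothetical sizes (single KV head, one query token per sequence, as in DeepSeek decode; the scale is a placeholder):

```python
import torch

# Hypothetical cache geometry, for illustration only.
num_blocks, block_size, num_kv_heads = 8, 4, 1
latent_dim, rope_dim = 512, 64
reduce_dim = latent_dim + rope_dim
batch_size, num_heads = 2, 4
scale = reduce_dim ** -0.5  # placeholder softmax scale

key_cache = torch.randn(num_blocks, block_size, num_kv_heads, reduce_dim)
block_table = torch.randint(0, num_blocks, (batch_size, 3))
context_lens = torch.tensor([10, 7])
max_context_len = int(context_lens.max())

# Indexing with the block table gathers each sequence's blocks; the view
# flattens them into one contiguous context per sequence.
kv_c_and_pe = key_cache[block_table].view(
    batch_size, -1, num_kv_heads, reduce_dim)[:, :max_context_len]
kv_c_and_pe = kv_c_and_pe.expand(-1, -1, num_heads, -1)  # 1 kv head -> all heads
kv_c = kv_c_and_pe[..., :latent_dim]  # values are the latents themselves

# One decode token per sequence.
query = torch.randn(batch_size, 1, num_heads, reduce_dim)
attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe) * scale

# Mask slots beyond each sequence's real context length.
kv_idx = torch.arange(max_context_len).view(1, -1) < context_lens.view(-1, 1)
attn_weights.masked_fill_(~kv_idx[:, None, None, :], float("-inf"))

attn_output = torch.einsum("bhqk,bkhd->bqhd",
                           attn_weights.softmax(dim=-1), kv_c)
print(attn_output.shape)  # torch.Size([2, 1, 4, 512])
```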