Commit 6be321b
remove useless code (#3685)
### What this PR does / why we need it?

`vanilla_chunked_prefill_mla` and `vanilla_decode_mla` are unused, so remove them.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: zzzzwwjj <[email protected]>
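
A quick way to confirm that `vanilla_chunked_prefill_mla` and `vanilla_decode_mla` have no remaining call sites is a repository-wide reference scan. The sketch below is a minimal, hypothetical check and is not part of the PR; the `vllm_ascend` path and the two symbol names come from this diff, everything else is illustrative:

```python
# Hypothetical reference scan: list every line in the package that still
# mentions one of the removed helpers, skipping their own `def` lines.
# Any remaining hit would mean the removal is unsafe.
import pathlib

REMOVED = ("vanilla_chunked_prefill_mla", "vanilla_decode_mla")


def find_references(root: str = "vllm_ascend") -> list[tuple[str, int, str]]:
    hits = []
    for path in pathlib.Path(root).rglob("*.py"):
        for lineno, line in enumerate(path.read_text().splitlines(), start=1):
            if any(name in line for name in REMOVED) \
                    and not line.lstrip().startswith("def "):
                hits.append((str(path), lineno, line.strip()))
    return hits


if __name__ == "__main__":
    for path, lineno, line in find_references():
        print(f"{path}:{lineno}: {line}")
```

An empty result, aside from the definitions themselves before this commit, is what justifies deleting both functions outright.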
1 parent cd58a64 commit 6be321b

vllm_ascend/ops/attention.py

Lines changed: 1 addition & 176 deletions
@@ -15,10 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple
+from typing import Optional
 
 import torch
-from vllm.model_executor.layers.linear import ColumnParallelLinear
 
 
 # Implementation of vanilla chunked prefill, should be removed after the kernel is ready for
@@ -133,177 +132,3 @@ def vanilla_chunked_prefill(
                                              head_dim]).to(output.dtype))
     output.copy_(attn_output)
     return attn_output
-
-
-def vanilla_chunked_prefill_mla(
-        output: torch.Tensor,  # (num_tokens, num_heads, v_head_dim)
-        query: torch.Tensor,  # (num_tokens, num_heads, nope_dim + rope_dim)
-        kv_cache: Tuple[
-            torch.Tensor],  # [nope, rope] (num_blocks, block_size, latent_kv)
-        block_tables: torch.Tensor,  # (batch_size, max_num_blocks_per_seq)
-        query_lens: torch.Tensor,  # (batch_size)
-        context_lens: torch.Tensor,  # (batch_size)
-        kv_b_proj: ColumnParallelLinear,  # ()
-        max_query_len: int,
-        max_context_len: int,
-        nope_dim: int,
-        rope_dim: int,
-        v_head_dim: int,
-        scale: float,
-        alibi_slopes: Optional[torch.Tensor],
-        causal: bool = True) -> None:
-    batch_size = block_tables.size(0)
-    assert len(kv_cache) > 1
-    assert query_lens.size(0) == batch_size
-    num_heads = query.size(1)
-    nope_cache = kv_cache[0]
-    rope_cache = kv_cache[1]
-    block_size = nope_cache.size(1)
-    latent_kv_dim = nope_cache.size(-1)
-    max_num_blocks_per_seq = block_tables.size(1)
-    batch_size = query_lens.size(0)
-    nope_cache = nope_cache.squeeze()
-    # select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] and get kv_c and k_pe
-    # cached_kv_c: [batch_size, max_context_len, latent_kv]
-    # cached_k_pe: [batch_size, max_context_len, rope_dim]
-    cache_kv_c = nope_cache[block_tables].view(
-        batch_size, max_num_blocks_per_seq * block_size,
-        latent_kv_dim)[:, :max_context_len, :]
-    cache_k_pe = rope_cache[block_tables].view(
-        batch_size, max_num_blocks_per_seq * block_size,
-        rope_dim)[:, :max_context_len, :]
-    # get k_rope and v
-    # k_nope: [batch_size, max_context_len, num_heads, nope_dim]
-    # value: [batch_size, max_context_len, num_heads, v_head_dim]
-    k_nope, value = kv_b_proj(cache_kv_c)[0].view(
-        batch_size, max_context_len, num_heads,
-        nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1)
-    # key: [batch_size, max_context_len, num_hads, rope_dim + nope_dim]
-    key = torch.cat(
-        [k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)],
-        dim=-1)
-
-    context_lens = context_lens.view(-1, 1).to("npu")
-    query_lens = query_lens.view(-1, 1).to("npu")
-    seq_diff = context_lens - query_lens
-
-    q_idx_mask = (torch.arange(0, max_query_len,
-                               device="npu").view(1, -1).repeat(batch_size, 1))
-    kv_c_idx_mask = (torch.arange(0, max_context_len,
-                                  device="npu").view(1,
-                                                     -1).repeat(batch_size, 1))
-    kv_c_mask = kv_c_idx_mask < context_lens
-    q_mask = q_idx_mask < query_lens
-
-    # calculate idx for causal mask of query [batch, max_seqlen_q]
-    causal_mask_idx = (q_idx_mask + seq_diff)[q_mask]
-
-    # generate causal mask [batch, max_seqlen_q, max_seqlen_k]
-    tril_mask = torch.tril(
-        torch.ones(max_context_len, max_context_len, device="npu"))
-    tril_mask[tril_mask == 0] = float("-inf")
-    tril_mask[tril_mask == 1] = 0
-    causal_mask = tril_mask[causal_mask_idx]
-    causal_mask_padding = torch.empty(
-        [batch_size, max_query_len, max_context_len],
-        device="npu").fill_(float("-inf"))
-    causal_mask_padding[q_mask] = causal_mask
-    # to [batch, num_heads, max_seqlen_q, max_seqlen_k]
-    causal_mask_padding = causal_mask_padding.unsqueeze(1)
-
-    pad_q = torch.zeros(
-        [batch_size, max_query_len, num_heads, rope_dim + nope_dim],
-        device="npu",
-        dtype=query.dtype,
-    )
-    pad_k = torch.zeros(
-        [batch_size, max_context_len, num_heads, rope_dim + nope_dim],
-        device="npu",
-        dtype=key.dtype,
-    )
-    pad_v = torch.zeros(
-        [batch_size, max_context_len, num_heads, v_head_dim],
-        device="npu",
-        dtype=value.dtype,
-    )
-    num_query = torch.sum(q_mask).item()
-    num_add_query = num_query - query.size(0)
-    # mtp will come in
-    if num_add_query > 0:
-        add_query_size = query.size()
-        add_query_size = list(add_query_size)
-        add_query_size[0] = num_add_query
-        pad_tensor = torch.zeros(add_query_size,
-                                 dtype=query.dtype,
-                                 device=query.device)
-        query = torch.cat([query, pad_tensor], dim=0)
-    pad_q[q_mask] = query
-    pad_k[kv_c_mask] = key[kv_c_mask]
-    pad_v[kv_c_mask] = value[kv_c_mask]
-
-    pad_q = pad_q.permute(0, 2, 1, 3)
-    pad_k = pad_k.permute(0, 2, 1, 3)
-    pad_v = pad_v.permute(0, 2, 1, 3)
-    attn_mask = torch.empty([batch_size, 1, 1, max_context_len],
-                            device="npu").fill_(float("-inf"))
-    attn_mask[:, :, :, :max_context_len].masked_fill_(
-        kv_c_mask[:, None, None, :], 0)
-    # [b, h, f, t]
-    attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
-    attn_weights *= scale
-    attn_mask = attn_mask.float()
-    attn_weights = attn_weights + attn_mask
-    if causal:
-        attn_weights = attn_weights + causal_mask_padding
-
-    attn_weights = torch.softmax(attn_weights, dim=-1)
-    attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
-    attn_output = attn_output.permute(0, 2, 1, 3)
-
-    attn_output = (attn_output[q_mask].view([-1, num_heads,
-                                             v_head_dim]).to(output.dtype))
-    attn_output = attn_output.view_as(output)
-    output.copy_(attn_output)
-    return attn_output
-
-
-def vanilla_decode_mla(
-        query: torch.Tensor,  # [num_tokens, num_heads, latent_dim + rope_dim]
-        key_cache: torch.
-        Tensor,  # [num_blocks, block_size, num_kv_heads, latent_dim + rope_dim]
-        num_kv_heads: int,
-        num_heads: int,
-        scale: float,
-        block_table: torch.Tensor,  # [batch_size, max_block_size]
-        context_lens: List[int],
-        mla_vhead_size: int,
-        rope_dim: int,
-        output: torch.Tensor):
-    batch_size = block_table.size()[0]
-    max_block_size = block_table.size()[1]
-    reduce_dim = key_cache.size()[-1]
-    block_size = key_cache.size()[1]
-    latent_dim = reduce_dim - rope_dim
-    kv_c_and_pe = key_cache[block_table].view(
-        [batch_size, max_block_size * block_size, num_kv_heads, reduce_dim])
-    max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, device="npu").view(batch_size, 1)
-    # [batch_size, max_context_len, num_kv_heads, latent_dim + rope_dim]
-    # since the kv head is 1 in deepseek, we use expand here for perf
-    kv_c_and_pe = kv_c_and_pe[:, :max_context_len, :, :].expand(
-        -1, -1, num_heads, 1)
-    kv_c = kv_c_and_pe[..., :latent_dim]
-    kv_idx_mask = (torch.arange(0, max_context_len,
-                                device="npu").view(1,
-                                                   -1).repeat(batch_size, 1))
-    # [batch_size, max_context_len]
-    kv_idx_mask = kv_idx_mask < context_lens
-    query = query.unsqueeze(1)
-    attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe)
-    attn_weights *= scale
-    attn_weights = attn_weights + kv_idx_mask[:, -1, -1, :].float()
-    attn_weights = torch.softmax(attn_weights, dim=-1)
-    attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights,
-                               kv_c.float()).view(-1, num_heads, latent_dim)
-    output.copy_(attn_output)
-    return output
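
For orientation, the removed helpers were naive reference implementations of MLA attention over the latent KV cache gathered through the block table. The following is a compact standalone sketch of that decode-side computation on plain CPU tensors, assuming a single KV head as in the DeepSeek case; it is illustrative only and is not the removed NPU code:

```python
import torch


def naive_decode_mla(query, key_cache, block_table, context_lens, rope_dim,
                     scale):
    # query:       [batch, num_heads, latent_dim + rope_dim]
    # key_cache:   [num_blocks, block_size, latent_dim + rope_dim]
    # block_table: [batch, max_blocks_per_seq]
    batch, num_heads, reduce_dim = query.shape
    latent_dim = reduce_dim - rope_dim
    max_context_len = int(max(context_lens))

    # Gather each sequence's cached latent+rope vectors via its block table.
    kv = key_cache[block_table].view(batch, -1,
                                     reduce_dim)[:, :max_context_len]
    valid = torch.arange(max_context_len).view(1, -1) < torch.tensor(
        context_lens).view(-1, 1)

    # Score against the full latent+rope vector, mask padding, softmax.
    scores = torch.einsum("bhd,bkd->bhk", query.float(), kv.float()) * scale
    scores = scores.masked_fill(~valid.unsqueeze(1), float("-inf"))
    weights = torch.softmax(scores, dim=-1)

    # The output stays in the latent space, as in the removed helper.
    return torch.einsum("bhk,bkd->bhd", weights, kv[..., :latent_dim].float())


if __name__ == "__main__":
    key_cache = torch.randn(8, 4, 576)  # 512 latent dims + 64 rope dims
    block_table = torch.tensor([[0, 1], [2, 3]])
    out = naive_decode_mla(torch.randn(2, 16, 576), key_cache, block_table,
                           context_lens=[5, 7], rope_dim=64, scale=0.1)
    print(out.shape)  # torch.Size([2, 16, 512])
```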
