Commit 5445fd7
Author: weijinqian_v1

[Refactor] add fia_v3 attention & remove other attention operators.

Signed-off-by: weijinqian_v1 <[email protected]>

1 parent a919aef · commit 5445fd7

File tree: 1 file changed, +7 -10 lines changed

vllm_ascend/attention/attention_v1.py (7 additions, 10 deletions)
@@ -505,9 +505,9 @@ def forward(
     ) -> torch.Tensor:
         """Forward pass with Ascend attention.
         Args:
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
             kv_cache: shape = [key_cache, value_cache]
                 key_cache = [num_blocks, block_size,
                              num_kv_heads, head_size]
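
Note on the docstring fix above: the old shapes described flattened [batch_size, seq_len, hidden] projections, but forward actually receives token-major per-head tensors. A minimal sketch of how the two layouts relate (sizes below are hypothetical, not taken from the commit):

import torch

# Hypothetical sizes for illustration only.
num_tokens, num_heads, head_size = 8, 4, 16

# Flattened per-token projection, as the old docstring described it.
query_flat = torch.randn(num_tokens, num_heads * head_size)

# Token-major per-head layout, as the corrected docstring describes it;
# this mirrors the `query.view(-1, self.num_heads, self.head_size)` call
# later in forward.
query = query_flat.view(-1, num_heads, head_size)
assert query.shape == (num_tokens, num_heads, head_size)
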
@@ -543,14 +543,14 @@ def forward(
             return output.view(num_tokens, self.hidden_size)
 
         if attn_metadata is None:
-            return output.view(num_tokens, self.hidden_size).fill_(0)
+            return output.fill_(0)
 
         if hasattr(layer, 'quant_method') and use_kv_cache_int8:
             output = layer.quant_method.apply(layer, query, key, value,
                                               kv_cache, attn_metadata,
                                               self.attn_type, self.scale,
                                               output)
-            return output.view(num_tokens, self.hidden_size)
+            return output
 
         # View q k v to BSH.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -560,11 +560,9 @@ def forward(
         value = value.contiguous()
 
         if self.attn_type == AttentionType.ENCODER_ONLY:
-            ori_output = output
             output = self._forward_encode(query, key, value, attn_metadata,
                                           output)
-            ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
-            return ori_output.view(num_tokens, self.hidden_size)
+            return output
 
         if len(kv_cache) > 1:
             if self.key_cache is None:
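
The ENCODER_ONLY branch previously kept an ori_output alias, copied the result back into it slice by slice, and flattened it before returning. Returning output directly is equivalent as long as _forward_encode writes into the buffer it is handed, which the sketch below assumes (forward_encode_sketch is a hypothetical stand-in, not the real kernel):

import torch

def forward_encode_sketch(query: torch.Tensor,
                          output: torch.Tensor) -> torch.Tensor:
    # Stand-in for self._forward_encode: write the result into the
    # caller-provided buffer and return that same buffer.
    output.copy_(query)  # placeholder for the real attention computation
    return output

num_tokens, num_heads, head_size = 8, 4, 16
q = torch.randn(num_tokens, num_heads, head_size)
out = torch.empty_like(q)
result = forward_encode_sketch(q, out)
# Same storage, so the old copy-back into ori_output was redundant.
assert result.data_ptr() == out.data_ptr()
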
@@ -583,8 +581,7 @@ def forward(
         else:
             output = self._forward_prefill(query, key, value, attn_metadata,
                                            output)
-
-        return output.view(num_tokens, self.hidden_size)
+        return output
 
 
 def unified_ascend_attention_with_output(
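
Taken together, the rewritten return paths hand back the preallocated output buffer without a trailing view, so any flattening back to [num_tokens, hidden_size] is left to the caller. A caller-side sketch, assuming the buffer is allocated in the per-head layout (shapes hypothetical):

import torch

num_tokens, num_heads, head_size = 8, 4, 16
hidden_size = num_heads * head_size

# What the refactored forward would return under this assumption:
# the output buffer exactly as it was allocated.
output = torch.randn(num_tokens, num_heads, head_size)

# A caller that still needs the flat layout can view it explicitly;
# view() shares storage, so this costs no data movement.
flat = output.view(num_tokens, hidden_size)
assert flat.shape == (num_tokens, hidden_size)
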
