Commit 0250679

weijinqian_v1 committed:
[Refactor] add fia_v3 attention & remove other attention operator.
Signed-off-by: weijinqian_v1 <[email protected]>
1 parent: 366c359

File tree: 1 file changed (+41 -41 lines)

vllm_ascend/attention/attention_v1.py

Lines changed: 41 additions & 41 deletions
@@ -281,18 +281,18 @@ def build_for_graph_capture(
 class AscendAttentionBackendImpl(AttentionImpl):
 
     def __init__(
-            self,
-            num_heads: int,
-            head_size: int,
-            scale: float,
-            num_kv_heads: int,
-            alibi_slopes: Optional[List[float]],
-            sliding_window: Optional[int],
-            kv_cache_dtype: str,
-            logits_soft_cap: Optional[float],
-            attn_type: str,
-            kv_sharing_target_layer_name: Optional[str],
-            **kwargs,
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        logits_soft_cap: Optional[float],
+        attn_type: str,
+        kv_sharing_target_layer_name: Optional[str],
+        **kwargs,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
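
Note: the constructor hunk above is whitespace-only; every parameter survives unchanged. For orientation, a minimal instantiation sketch follows. The concrete values, the attn_type string, and the import path are illustrative assumptions, not taken from this commit:

# Hypothetical usage of the signature re-indented above; all values are
# illustrative assumptions, not taken from this commit.
from vllm_ascend.attention.attention_v1 import AscendAttentionBackendImpl

impl = AscendAttentionBackendImpl(
    num_heads=32,                       # query heads
    head_size=128,                      # per-head dim
    scale=1.0 / 128**0.5,               # typically 1/sqrt(head_size)
    num_kv_heads=8,                     # fewer than num_heads under GQA
    alibi_slopes=None,                  # Optional[List[float]]
    sliding_window=None,                # Optional[int]
    kv_cache_dtype="auto",              # assumed dtype tag
    logits_soft_cap=None,               # Optional[float]
    attn_type="decoder",                # assumed AttentionType.DECODER value
    kv_sharing_target_layer_name=None,  # Optional[str]
)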
@@ -313,11 +313,8 @@ def __init__(
         self.key_cache = None
         self.value_cache = None
 
-    def _forward_prefill(self,
-                         query: torch.Tensor,
-                         key: torch.Tensor,
-                         value: torch.Tensor,
-                         attn_metadata: AscendMetadata,
+    def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
+                         value: torch.Tensor, attn_metadata: AscendMetadata,
                          output: torch.Tensor):
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             block_size = 128
@@ -365,18 +362,19 @@ def _forward_prefill(self,
             sparse_mode=3,
         )
 
-        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
+        attn_output = attn_output.view(num_tokens, self.num_heads,
+                                       self.head_size)
         output[:num_tokens] = attn_output[:num_tokens]
         return output
 
     def _forward_decode_only(
-            self,
-            query: torch.Tensor,
-            attn_metadata: AscendMetadata,
-            output: torch.Tensor,
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: torch.Tensor,
     ) -> torch.Tensor:
         if self.sliding_window is not None and attn_metadata.seq_lens.shape[
-            0] == query.size(0):
+                0] == query.size(0):
             batch_size = attn_metadata.seq_lens.shape[0]
             block_size = 128
             query = query.view(batch_size, 1, self.num_heads * self.head_size)
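
The re-indented guard above encodes the decode fast path: with one new token per sequence, the number of rows in query equals the number of entries in seq_lens. A self-contained illustration (all shapes invented; only the condition and the reshape mirror the hunk):

import torch

num_heads, head_size, batch_size = 8, 128, 3

# One query row per sequence, as in a pure decode step.
query = torch.randn(batch_size, num_heads * head_size)
seq_lens = torch.tensor([17, 5, 42])  # context length per sequence

# The guard from the hunk: the sliding-window decode branch applies only
# when each sequence contributes exactly one query token.
assert seq_lens.shape[0] == query.size(0)

# The decode path then views the batch as (batch, 1, hidden).
query = query.view(batch_size, 1, num_heads * head_size)
print(query.shape)  # torch.Size([3, 1, 1024])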
@@ -470,12 +468,12 @@ def _forward_decode_only(
         return output
 
     def _forward_encode(
-            self,
-            query: torch.Tensor,
-            key: torch.Tensor,
-            value: torch.Tensor,
-            attn_metadata: AscendMetadata,
-            output: torch.Tensor,
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: torch.Tensor,
     ) -> torch.Tensor:
         cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
         output = torch_npu.npu_fusion_attention(
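
_forward_encode feeds torch_npu.npu_fusion_attention cumulative sequence lengths derived from query_start_loc. A plain-torch sketch of just that slicing step (the offsets are made up; the [1:].tolist() pattern comes from the hunk):

import torch

# Prefix offsets of a packed batch: three sequences of lengths 4, 2, 5
# (illustrative values).
query_start_loc = torch.tensor([0, 4, 6, 11])

# Dropping the leading zero yields the cumulative lengths, exactly the
# `attn_metadata.query_start_loc[1:].tolist()` expression in the diff.
cum_seq_len = query_start_loc[1:].tolist()
assert cum_seq_len == [4, 6, 11]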
@@ -495,15 +493,15 @@ def _forward_encode(
         return output
 
     def forward(
-            self,
-            layer: AttentionLayer,
-            query: torch.Tensor,
-            key: torch.Tensor,
-            value: torch.Tensor,
-            kv_cache: Tuple[torch.Tensor],
-            attn_metadata: AscendMetadata,
-            output: Optional[torch.Tensor] = None,
-            trace_flag: bool = True,
+        self,
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor],
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+        trace_flag: bool = True,
     ) -> torch.Tensor:
         """Forward pass with Ascend attention.
         Args:
@@ -546,7 +544,7 @@ def forward(
 
         if attn_metadata is None:
             return output.view(num_tokens, self.hidden_size).fill_(0)
-        # ori_output = output
+
         if hasattr(layer, 'quant_method') and use_kv_cache_int8:
             output = layer.quant_method.apply(layer, query, key, value,
                                               kv_cache, attn_metadata,
@@ -563,7 +561,8 @@ def forward(
 
         if self.attn_type == AttentionType.ENCODER_ONLY:
             ori_output = output
-            output = self._forward_encode(query, key, value, attn_metadata, output)
+            output = self._forward_encode(query, key, value, attn_metadata,
+                                          output)
             ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
             return ori_output.view(num_tokens, self.hidden_size)
 

@@ -582,7 +581,8 @@ def forward(
         if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            output = self._forward_decode_only(query, attn_metadata, output)
         else:
-            output = self._forward_prefill(query, key, value, attn_metadata, output)
+            output = self._forward_prefill(query, key, value, attn_metadata,
+                                           output)
 
         return output.view(num_tokens, self.hidden_size)
 
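
The last two hunks merely rewrap the call sites at the tail of forward. A toy reduction of that branch order, using stubs in place of the NPU kernels (only the two state names visible in this diff are modeled; the real enum has more members):

from enum import Enum, auto


class AscendAttentionState(Enum):
    # Subset of the real enum: only the states named in this diff.
    PrefillNoCache = auto()
    DecodeOnly = auto()


def dispatch(state: AscendAttentionState) -> str:
    # Mirrors the tail of forward(): DecodeOnly goes to the decode-only
    # kernel; every other state falls through to the prefill path.
    if state == AscendAttentionState.DecodeOnly:
        return "_forward_decode_only"
    return "_forward_prefill"


assert dispatch(AscendAttentionState.DecodeOnly) == "_forward_decode_only"
assert dispatch(AscendAttentionState.PrefillNoCache) == "_forward_prefill"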