@@ -281,18 +281,18 @@ def build_for_graph_capture(
 class AscendAttentionBackendImpl(AttentionImpl):
 
     def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        num_kv_heads: int,
-        alibi_slopes: Optional[List[float]],
-        sliding_window: Optional[int],
-        kv_cache_dtype: str,
-        logits_soft_cap: Optional[float],
-        attn_type: str,
-        kv_sharing_target_layer_name: Optional[str],
-        **kwargs,
+            self,
+            num_heads: int,
+            head_size: int,
+            scale: float,
+            num_kv_heads: int,
+            alibi_slopes: Optional[List[float]],
+            sliding_window: Optional[int],
+            kv_cache_dtype: str,
+            logits_soft_cap: Optional[float],
+            attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
+            **kwargs,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -313,11 +313,8 @@ def __init__(
         self.key_cache = None
         self.value_cache = None
 
-    def _forward_prefill(self,
-                         query: torch.Tensor,
-                         key: torch.Tensor,
-                         value: torch.Tensor,
-                         attn_metadata: AscendMetadata,
+    def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
+                         value: torch.Tensor, attn_metadata: AscendMetadata,
                          output: torch.Tensor):
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             block_size = 128
@@ -365,18 +362,19 @@ def _forward_prefill(self,
                 sparse_mode=3,
             )
 
-        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
+        attn_output = attn_output.view(num_tokens, self.num_heads,
+                                       self.head_size)
         output[:num_tokens] = attn_output[:num_tokens]
         return output
 
     def _forward_decode_only(
-        self,
-        query: torch.Tensor,
-        attn_metadata: AscendMetadata,
-        output: torch.Tensor,
+            self,
+            query: torch.Tensor,
+            attn_metadata: AscendMetadata,
+            output: torch.Tensor,
     ) -> torch.Tensor:
         if self.sliding_window is not None and attn_metadata.seq_lens.shape[
-            0] == query.size(0):
+                0] == query.size(0):
             batch_size = attn_metadata.seq_lens.shape[0]
             block_size = 128
             query = query.view(batch_size, 1, self.num_heads * self.head_size)
@@ -470,12 +468,12 @@ def _forward_decode_only(
         return output
 
     def _forward_encode(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attn_metadata: AscendMetadata,
-        output: torch.Tensor,
+            self,
+            query: torch.Tensor,
+            key: torch.Tensor,
+            value: torch.Tensor,
+            attn_metadata: AscendMetadata,
+            output: torch.Tensor,
     ) -> torch.Tensor:
         cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
         output = torch_npu.npu_fusion_attention(
@@ -495,15 +493,15 @@ def _forward_encode(
         return output
 
     def forward(
-        self,
-        layer: AttentionLayer,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        kv_cache: Tuple[torch.Tensor],
-        attn_metadata: AscendMetadata,
-        output: Optional[torch.Tensor] = None,
-        trace_flag: bool = True,
+            self,
+            layer: AttentionLayer,
+            query: torch.Tensor,
+            key: torch.Tensor,
+            value: torch.Tensor,
+            kv_cache: Tuple[torch.Tensor],
+            attn_metadata: AscendMetadata,
+            output: Optional[torch.Tensor] = None,
+            trace_flag: bool = True,
     ) -> torch.Tensor:
         """Forward pass with Ascend attention.
         Args:
@@ -546,7 +544,7 @@ def forward(
 
         if attn_metadata is None:
             return output.view(num_tokens, self.hidden_size).fill_(0)
-        # ori_output = output
+
         if hasattr(layer, 'quant_method') and use_kv_cache_int8:
             output = layer.quant_method.apply(layer, query, key, value,
                                               kv_cache, attn_metadata,
@@ -563,7 +561,8 @@ def forward(
 
         if self.attn_type == AttentionType.ENCODER_ONLY:
             ori_output = output
-            output = self._forward_encode(query, key, value, attn_metadata, output)
+            output = self._forward_encode(query, key, value, attn_metadata,
+                                          output)
             ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
             return ori_output.view(num_tokens, self.hidden_size)
 
@@ -582,7 +581,8 @@ def forward(
         if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
             output = self._forward_decode_only(query, attn_metadata, output)
         else:
-            output = self._forward_prefill(query, key, value, attn_metadata, output)
+            output = self._forward_prefill(query, key, value, attn_metadata,
+                                           output)
 
         return output.view(num_tokens, self.hidden_size)
 
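Note for reviewers: every hunk above is a whitespace-only reformat (yapf-style continuation indents and packed signatures), so the control flow of `forward` is unchanged. For orientation, here is a minimal sketch of the prefill/decode dispatch those hunks touch. `AscendAttentionState` is stubbed as an assumption here; the real enum lives elsewhere in vllm-ascend and carries additional states (e.g. `PrefillCacheHit`).

```python
# Minimal sketch of the dispatch performed at the end of forward() above.
# The enum below is a stand-in, not the real vllm-ascend definition.
from enum import Enum, auto


class AscendAttentionState(Enum):
    PrefillNoCache = auto()
    DecodeOnly = auto()


def dispatch(attn_state: AscendAttentionState) -> str:
    # Decode-only batches take the dedicated decode path; everything
    # else is routed through the prefill path, as in the hunk above.
    if attn_state == AscendAttentionState.DecodeOnly:
        return "_forward_decode_only"
    return "_forward_prefill"


assert dispatch(AscendAttentionState.DecodeOnly) == "_forward_decode_only"
assert dispatch(AscendAttentionState.PrefillNoCache) == "_forward_prefill"
```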