@@ -505,9 +505,9 @@ def forward(
505505 ) -> torch .Tensor :
506506 """Forward pass with Ascend attention.
507507 Args:
508- query: shape = [batch_size, seq_len, num_heads * head_size]
509- key: shape = [batch_size, seq_len, num_kv_heads * head_size]
510- value: shape = [batch_size, seq_len, num_kv_heads * head_size]
508+ query: shape = [num_tokens, num_heads, head_size]
509+ key: shape = [num_tokens, num_kv_heads, head_size]
510+ value: shape = [num_tokens, num_kv_heads, head_size]
511511 kv_cache: shape = [key_cache, value_cache]
512512 key_cache = [num_blocks, block_size,
513513 num_kv_heads, head_size]
@@ -543,14 +543,14 @@ def forward(
543543 return output .view (num_tokens , self .hidden_size )
544544
545545 if attn_metadata is None :
546- return output .view ( num_tokens , self . hidden_size ). fill_ (0 )
546+ return output .fill_ (0 )
547547
548548 if hasattr (layer , 'quant_method' ) and use_kv_cache_int8 :
549549 output = layer .quant_method .apply (layer , query , key , value ,
550550 kv_cache , attn_metadata ,
551551 self .attn_type , self .scale ,
552552 output )
553- return output . view ( num_tokens , self . hidden_size )
553+ return output
554554
555555 # View q k v to BSH.
556556 query = query .view (- 1 , self .num_heads , self .head_size )
@@ -560,11 +560,9 @@ def forward(
560560 value = value .contiguous ()
561561
562562 if self .attn_type == AttentionType .ENCODER_ONLY :
563- ori_output = output
564563 output = self ._forward_encode (query , key , value , attn_metadata ,
565564 output )
566- ori_output [:num_tokens , :, :] = output [:num_tokens , :, :]
567- return ori_output .view (num_tokens , self .hidden_size )
565+ return output
568566
569567 if len (kv_cache ) > 1 :
570568 if self .key_cache is None :
@@ -583,8 +581,7 @@ def forward(
583581 else :
584582 output = self ._forward_prefill (query , key , value , attn_metadata ,
585583 output )
586-
587- return output .view (num_tokens , self .hidden_size )
584+ return output
588585
589586
590587def unified_ascend_attention_with_output (
0 commit comments