@@ -329,6 +329,82 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.key_cache = None
         self.value_cache = None
+
+    def full_graph_attention(self,
+                             query: torch.Tensor,
+                             key: torch.Tensor,
+                             value: torch.Tensor,
+                             kv_cache: Tuple[torch.Tensor],
+                             attn_metadata: AscendMetadata,
+                             output: torch.Tensor,
+                             num_tokens=0):
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            block_size = 128
+            block_table = None
+            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
+            if is_310p():
+                # Align q, k, v and output tensors for 310P.
+                query = aligned_16(query)
+                key = aligned_16(key)
+                value = aligned_16(value)
+                output = aligned_16(output)
+                # Reformat the attention mask in case of broadcasted tensors.
+                mask = attn_metadata.attn_mask.repeat(
+                    attn_metadata.seq_lens.size(0), 1, 1, 1)
+                attn_metadata.attn_mask = torch_npu.npu_format_cast(
+                    mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
+        elif attn_metadata.attn_state == \
+                AscendAttentionState.PrefillCacheHit:
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+        # Normal V1 situation.
+        else:
+            if is_310p():
+                # Do reformat in case of broadcasted tensors.
+                attn_metadata.attn_mask = \
+                    torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
+                                              ACL_FORMAT_FRACTAL_NZ)
+                attn_metadata.seq_lens = \
+                    attn_metadata.seq_lens.to(device=query.device)
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            block_table = attn_metadata.block_tables
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+
+        # npu_fused_infer_attention_score does not support cases where
+        # query.shape[0] != actual_seq_lengths_q[-1], so unpad the query here.
+        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
+        query = query[:num_tokens]
+
+        # TODO: Refactor this to step-level instead of layer-level.
+        output, _ = torch_npu.npu_fused_infer_attention_score(
+            query=query,
+            key=key,
+            value=value,
+            atten_mask=attn_metadata.attn_mask,
+            block_table=block_table,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            num_key_value_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale=self.scale,
+            sparse_mode=3,
+        )
+
+        output = output.view(num_tokens, self.num_heads, self.head_size)
+
+        return output, num_tokens
+
 
     def _forward_prefill_no_cache(
         self,
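
The new helper above unpads the query before calling `npu_fused_infer_attention_score`: in the `TND` layout the kernel expects `query.shape[0]` to equal the last entry of `actual_seq_lengths`, so the query is sliced down to the real token count first. Below is a minimal sketch of that unpadding step in plain PyTorch (no NPU required); the head counts, lengths, and padding budget are made-up example values, and `actual_seq_lengths_q` is assumed to hold cumulative per-request query lengths, analogous to `query_start_loc` in the removed V1 path.

```python
import torch

# Hypothetical per-request query lengths and a padded token budget.
query_lens = [3, 5, 2]
padded_num_tokens = 16
num_heads, head_size = 8, 128

# Assumed to be cumulative lengths, so the last entry is the number of
# real (unpadded) tokens in the batch.
actual_seq_lengths_q = torch.tensor(query_lens).cumsum(0).tolist()  # [3, 8, 10]
num_tokens = actual_seq_lengths_q[-1]                                # 10

query = torch.randn(padded_num_tokens, num_heads, head_size)
query = query[:num_tokens]  # drop the padded tail before the fused kernel call
assert query.shape[0] == actual_seq_lengths_q[-1]
```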
@@ -662,24 +738,14 @@ def forward(
             )
             output = attn_out[0]
         # V0-Style scheduler situation.
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            output = self._forward_prefill_no_cache(
-                query, key, value, attn_metadata, output, num_tokens)
-        elif attn_metadata.attn_state == \
-                AscendAttentionState.PrefillCacheHit:
-            output = self._forward_prefill_cache_hit(
-                query, attn_metadata, output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
             output = self._forward_decode_only(query, attn_metadata,
                                                output)
         # Normal V1 situation.
         else:
-            # npu_fused_infer_attention_score does not support cases
-            # where query.shape[0] != attn_metadata.query_start_loc[-1].
-            # Thus we need unpad it here.
-            num_tokens = attn_metadata.query_start_loc[-1]
-            query = query[:num_tokens]
-            output = self._forward_v1_style(query, attn_metadata, output)
+            intermediate_output, query_num_tokens = self.full_graph_attention(
+                query, key, value, kv_cache, attn_metadata, output)
+            output[:query_num_tokens] = intermediate_output[:query_num_tokens]
683749
684750 # to make in-place change to the output tensor
685751 if hasattr (layer , 'quant_method' ) and use_kv_cache_int8 :
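
In the dispatch change above, `full_graph_attention` returns both the attention result and the number of valid query tokens, and the caller copies only those rows back into the pre-allocated `output` buffer, leaving its padded tail untouched (presumably so the fixed-shape buffer stays intact for the full-graph path). A minimal PyTorch sketch of that partial in-place write-back, using made-up shapes:

```python
import torch

padded_num_tokens, num_heads, head_size = 16, 8, 128  # example shapes
query_num_tokens = 10                                  # valid tokens this step

# Pre-allocated, fixed-shape output buffer (stands in for the real one).
output = torch.zeros(padded_num_tokens, num_heads, head_size)
# Result returned by the attention helper for the valid tokens only.
intermediate_output = torch.randn(query_num_tokens, num_heads, head_size)

# Copy back only the valid rows; the padded tail of `output` is untouched.
output[:query_num_tokens] = intermediate_output[:query_num_tokens]
assert output[query_num_tokens:].abs().sum() == 0
```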