@@ -441,6 +441,110 @@ def __init__(
         ) if self.dcp_size > 1 else 0
         self.dcp_group = get_dcp_group(
         ).device_group if self.dcp_size > 1 else None
+
+    def full_graph_attention(self,
+                             query: torch.Tensor,
+                             key: torch.Tensor,
+                             value: torch.Tensor,
+                             attn_metadata: AscendMetadata,
+                             block_size: int,
+                             output: Optional[torch.Tensor] = None,
+                             num_tokens=0):
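+        # Shared fused-attention helper: under graph capture it records
+        # workspace/event/handle bookkeeping in graph_params; otherwise it
+        # issues a plain npu_fused_infer_attention_score call.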
+        num_tokens = query.shape[0]
+        forward_context: ForwardContext = get_forward_context()
+        if forward_context.capturing:
+            graph_params = get_graph_params()
+            query_start_loc = attn_metadata.actual_seq_lengths_q
+            seq_lens = attn_metadata.seq_lens_list
+            # Prepare tensors for attention output
+            # TODO: Refactor this to step-level instead of layer-level
+
+            # Get workspace from cache or calculate it if not present.
+            workspace = graph_params.workspaces.get(num_tokens)
+            softmax_lse = torch.empty(num_tokens,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            if workspace is None:
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    query=query,
+                    key=key,
+                    value=value,
+                    atten_mask=attn_metadata.attn_mask,
+                    block_table=attn_metadata.block_tables,
+                    input_layout="TND",
+                    block_size=block_size,
+                    actual_seq_lengths=query_start_loc,
+                    actual_seq_lengths_kv=seq_lens,
+                    num_key_value_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    sparse_mode=3,
+                    scale=self.scale)
+                graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
+
+            # Handle graph capturing mode
+            stream = torch_npu.npu.current_stream()
+
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
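+            # Stash (mostly weak) references to every tensor and scalar this
+            # call uses, so the captured kernel can be re-issued with updated
+            # metadata when the graph is replayed.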
+            graph_params.attn_params[num_tokens].append((
+                weak_ref_tensors(query),
+                weak_ref_tensors(key),
+                weak_ref_tensors(value),
+                weak_ref_tensors(attn_metadata.block_tables),
+                weak_ref_tensors(attn_metadata.attn_mask),
+                block_size,
+                seq_lens,
+                query_start_loc,
+                self.num_kv_heads,
+                self.num_heads,
+                self.scale,
+                weak_ref_tensors(output),
+                weak_ref_tensors(softmax_lse)
+            ))
+
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.npu_fused_infer_attention_score.out(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=attn_metadata.attn_mask,
+                block_table=attn_metadata.block_tables,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=query_start_loc,
+                actual_seq_lengths_kv=seq_lens,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+                workspace=workspace,
+                out=[output, softmax_lse],
+            )
+
+            output = output.view(num_tokens, self.num_heads,
+                                 self.head_size)
+
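+            # The task-group handle is stored so later replays can update the
+            # captured attention call.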
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
+        else:
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key,
+                value=value,
+                block_table=attn_metadata.block_tables,
+                atten_mask=attn_metadata.attn_mask,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=attn_metadata.query_start_loc_list,
+                actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3)
+        return output
 
     def _forward_prefill_no_cache(
         self,
@@ -467,15 +571,7 @@ def _forward_prefill_no_cache(
             mask = torch_npu.npu_format_cast(mask.contiguous(),
                                              ACL_FORMAT_FRACTAL_NZ)
 
-        torch_npu._npu_flash_attention(query=query,
-                                       key=key,
-                                       value=value,
-                                       mask=mask,
-                                       seq_len=attn_metadata.seq_lens,
-                                       scale_value=self.scale,
-                                       num_heads=self.num_heads,
-                                       num_kv_heads=self.num_kv_heads,
-                                       out=output)
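+        # Key/value here are the raw prefill tensors rather than the paged KV
+        # cache, so a fixed block_size of 128 is passed to the fused kernel.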
+        output = self.full_graph_attention(query, key, value, attn_metadata,
+                                           128, output)
         assert output is not None
         return output[:num_tokens]
 
@@ -569,84 +665,12 @@ def _forward_decode_only(
 
             output = output.view(batch_size, self.num_heads, self.head_size)
         else:
-            graph_params = get_graph_params()
-            forward_context: ForwardContext = get_forward_context()
-            num_tokens = query.shape[0]
-            if forward_context.capturing:
-                if self.torch_npu_check:
-                    # Get workspace from cache or calculate it if not present.
-                    workspace = graph_params.workspaces.get(num_tokens)
-                    if workspace is None:
-                        workspace = torch_npu._npu_paged_attention_get_workspace(
-                            query=query,
-                            key_cache=self.key_cache,
-                            value_cache=self.value_cache,
-                            num_kv_heads=self.num_kv_heads,
-                            num_heads=self.num_heads,
-                            scale_value=self.scale,
-                            block_table=attn_metadata.block_tables,
-                            context_lens=attn_metadata.seq_lens,
-                            out=output)
-                        update_graph_params_workspaces(
-                            num_tokens, weak_ref_tensors(workspace))
-
-                # Handle graph capturing mode
-                stream = torch_npu.npu.current_stream()
-
-                event = torch.npu.ExternalEvent()
-                event.wait(stream)
-                event.reset(stream)
-                graph_params.events[num_tokens].append(event)
-                graph_params.attn_params[num_tokens].append((
-                    weak_ref_tensors(query),
-                    weak_ref_tensors(self.key_cache),
-                    weak_ref_tensors(self.value_cache),
-                    self.num_kv_heads,
-                    self.num_heads,
-                    self.scale,
-                    attn_metadata.block_tables,
-                    attn_metadata.seq_lens,
-                    weak_ref_tensors(output),
-                ))
-
-                torch.npu.graph_task_group_begin(stream)
-
-                if self.torch_npu_check:
-                    torch_npu._npu_paged_attention(
-                        query=query,
-                        key_cache=self.key_cache,
-                        value_cache=self.value_cache,
-                        num_kv_heads=self.num_kv_heads,
-                        num_heads=self.num_heads,
-                        scale_value=self.scale,
-                        block_table=attn_metadata.block_tables,
-                        context_lens=attn_metadata.seq_lens,
-                        out=output,
-                        workspace=workspace)
-                else:
-                    torch_npu._npu_paged_attention(
-                        query=query,
-                        key_cache=self.key_cache,
-                        value_cache=self.value_cache,
-                        num_kv_heads=self.num_kv_heads,
-                        num_heads=self.num_heads,
-                        scale_value=self.scale,
-                        block_table=attn_metadata.block_tables,
-                        context_lens=attn_metadata.seq_lens,
-                        out=output)
-                handle = torch.npu.graph_task_group_end(stream)
-                graph_params.handles[num_tokens].append(handle)
-            else:
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
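+            # Flatten the paged KV cache from (num_blocks, block_size,
+            # num_kv_heads, head_size) to (num_blocks, block_size, -1) so it
+            # can be fed to the fused kernel together with the block table.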
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            output = self.full_graph_attention(query, key, value,
+                                               attn_metadata, block_size,
+                                               output)
         return output
 
     def _forward_v1_style(
@@ -687,43 +711,12 @@ def _forward_v1_style(
         attn_metadata.seq_lens = \
             attn_metadata.seq_lens.to(device=query.device)
 
-        if torch.version.cann.startswith("8.3"):
-            # TODO: The npu_fused_infer_attention_score op is planned to
-            # be utilized in a wider range in upcoming versions.
-            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
-            key = self.key_cache.view(  # type: ignore
-                num_block, block_size, -1)
-            value = self.value_cache.view(  # type: ignore
-                num_block, block_size, -1)
-
-            output, _ = torch_npu.npu_fused_infer_attention_score(
-                query=query,
-                key=key,
-                value=value,
-                atten_mask=attn_metadata.attn_mask,
-                block_table=attn_metadata.block_tables,
-                input_layout="TND",
-                block_size=block_size,
-                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
-                actual_seq_lengths_kv=attn_metadata.seq_lens_list,
-                num_key_value_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale=self.scale,
-                sparse_mode=3,
-            )
-        else:
-            torch_npu._npu_paged_attention_splitfuse(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                mask=attn_metadata.attn_mask,
-                block_table=attn_metadata.block_tables,
-                seq_len=attn_metadata.query_lens,
-                context_lens=attn_metadata.seq_lens,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                out=output)
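+        # Same flattened KV-cache view as the decode-only path; block_size is
+        # taken from the cache layout.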
+        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+        key = self.key_cache.view(  # type: ignore
+            num_block, block_size, -1)
+        value = self.value_cache.view(  # type: ignore
+            num_block, block_size, -1)
+        output = self.full_graph_attention(query, key, value, attn_metadata,
+                                           block_size, output)
         return output
 
     def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
@@ -1161,26 +1154,18 @@ def forward(
             )[0]
         # V0-Style scheduler situation.
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            intermediate_output = self._forward_prefill_no_cache(
+            output = self._forward_prefill_no_cache(
                 query, key, value, attn_metadata, output, num_tokens)
         elif attn_metadata.attn_state == \
                 AscendAttentionState.PrefillCacheHit:
-            intermediate_output = self._forward_prefill_cache_hit(
+            output = self._forward_prefill_cache_hit(
                 query, attn_metadata, output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            intermediate_output = self._forward_decode_only(
+            output = self._forward_decode_only(
                 query, attn_metadata, output)
         # Normal V1 situation.
         else:
-            if torch.version.cann.startswith("8.3"):
-                # npu_fused_infer_attention_score does not support cases
-                # where query.shape[0] != attn_metadata.query_start_loc[-1].
-                # Thus we need to unpad it here.
-                num_tokens = attn_metadata.query_start_loc[-1]
-                query = query[:num_tokens]
-            intermediate_output = self._forward_v1_style(
+            output = self._forward_v1_style(
                 query, attn_metadata, output)
 
-        output[:num_tokens] = intermediate_output[:num_tokens]
-
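+        # Each branch assigns its result to `output`, which is returned
+        # directly.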
         return output