4242
4343from ..utils import weak_ref_tensors
4444
45-
4645class AscendAttentionBackend (AttentionBackend ):
4746 accept_output_buffer : bool = True
4847
@@ -149,6 +148,9 @@ class AscendMetadata:
149148 actual_seq_lengths_q : List [int ] = None # type: ignore
150149
151150 query_start_loc : torch .Tensor = None
151+ seq_lens_list : List [int ] = None
152+
153+ query_start_loc_list : List [int ] = None
152154 query_lens : torch .Tensor = None
153155 # Maximum query length in the batch (None for decoding).
154156 max_query_len : Optional [int ] = None
@@ -255,8 +257,10 @@ def build(
255257 attn_metadata = AscendMetadata (
256258 num_actual_tokens = num_actual_tokens ,
257259 block_tables = block_table ,
258- query_start_loc = query_start_loc ,
260+ query_start_loc = query_start_loc_cpu ,
261+ query_start_loc_list = query_start_loc_cpu [1 :].cpu ().int ().tolist (),
259262 query_lens = query_lens ,
263+ seq_lens_list = seq_lens .cpu ().int ().tolist (),
260264 seq_lens = seq_lens ,
261265 seq_lens_list = seq_lens .tolist (),
262266 max_query_len = common_attn_metadata .max_query_len ,
@@ -427,13 +431,136 @@ def _forward_decode_only(
427431 else :
428432 graph_params = get_graph_params ()
429433 forward_context : ForwardContext = get_forward_context ()
430- num_tokens = query .shape [0 ]
431434 if forward_context .capturing :
432- if self .torch_npu_check :
435+ if torch .version .cann .startswith ("8.3" ):
436+ # Use the list-form query offsets and KV lengths; trim query to the last offset (num_tokens)
437+ query_start_loc = attn_metadata .query_start_loc_list
438+ seq_lens = attn_metadata .seq_lens_list
439+ num_tokens = query_start_loc [- 1 ]
440+ query = query [:num_tokens ]
441+
433442 # Get workspace from cache or calculate it if not present.
434443 workspace = graph_params .workspaces .get (num_tokens )
444+ num_block , block_size , _ , _ = self .key_cache .shape # type: ignore
445+ key = self .key_cache .view ( # type: ignore
446+ num_block , block_size , - 1 )
447+ value = self .value_cache .view ( # type: ignore
448+ num_block , block_size , - 1 )
449+ softmax_lse = torch .empty (num_tokens ,
450+ dtype = query .dtype ,
451+ device = query .device )
435452 if workspace is None :
436- workspace = torch_npu ._npu_paged_attention_get_workspace (
453+ workspace = torch_npu ._npu_fused_infer_attention_score_get_max_workspace (
454+ query = query ,
455+ key = key ,
456+ value = value ,
457+ block_table = attn_metadata .block_tables ,
458+ input_layout = "TND" ,
459+ block_size = block_size ,
460+ actual_seq_lengths = query_start_loc ,
461+ actual_seq_lengths_kv = seq_lens ,
462+ num_key_value_heads = self .num_kv_heads ,
463+ num_heads = self .num_heads ,
464+ sparse_mode = 0 ,
465+ scale = self .scale ,)
466+ update_graph_params_workspaces (num_tokens , workspace )
467+
468+ # Handle graph capturing mode
469+ stream = torch_npu .npu .current_stream ()
470+
471+ event = torch .npu .ExternalEvent ()
472+ event .wait (stream )
473+ event .reset (stream )
474+ graph_params .events [num_tokens ].append (event )
475+ graph_params .attn_params [num_tokens ].append ((
476+ weak_ref_tensors (query ),
477+ weak_ref_tensors (key ),
478+ weak_ref_tensors (value ),
479+ weak_ref_tensors (attn_metadata .block_tables ),
480+ block_size ,
481+ seq_lens ,
482+ query_start_loc ,
483+ self .num_kv_heads ,
484+ self .num_heads ,
485+ self .scale ,
486+ weak_ref_tensors (output ),
487+ weak_ref_tensors (softmax_lse )
488+ ))
489+
490+ torch .npu .graph_task_group_begin (stream )
491+ torch_npu .npu_fused_infer_attention_score .out (
492+ query = query ,
493+ key = key ,
494+ value = value ,
495+ block_table = attn_metadata .block_tables ,
496+ input_layout = "TND" ,
497+ block_size = block_size ,
498+ actual_seq_lengths = query_start_loc ,
499+ actual_seq_lengths_kv = seq_lens ,
500+ num_key_value_heads = self .num_kv_heads ,
501+ num_heads = self .num_heads ,
502+ scale = self .scale ,
503+ sparse_mode = 0 ,
504+ workspace = workspace ,
505+ out = [output , softmax_lse ],
506+ )
507+
508+ output = output .view (num_tokens , self .num_heads ,
509+ self .head_size )
510+
511+ handle = torch .npu .graph_task_group_end (stream )
512+ graph_params .handles [num_tokens ].append (handle )
513+ else :
514+ if self .torch_npu_check :
515+ # Get workspace from cache or calculate it if not present.
516+ workspace = graph_params .workspaces .get (num_tokens )
517+ if workspace is None :
518+ workspace = torch_npu ._npu_paged_attention_get_workspace (
519+ query = query ,
520+ key_cache = self .key_cache ,
521+ value_cache = self .value_cache ,
522+ num_kv_heads = self .num_kv_heads ,
523+ num_heads = self .num_heads ,
524+ scale_value = self .scale ,
525+ block_table = attn_metadata .block_tables ,
526+ context_lens = attn_metadata .seq_lens ,
527+ out = output )
528+ update_graph_params_workspaces (num_tokens , workspace )
529+ # Handle graph capturing mode
530+ stream = torch_npu .npu .current_stream ()
531+
532+ event = torch .npu .ExternalEvent ()
533+ event .wait (stream )
534+ event .reset (stream )
535+ graph_params .events [num_tokens ].append (event )
536+ graph_params .attn_params [num_tokens ].append ((
537+ weak_ref_tensors (query ),
538+ weak_ref_tensors (self .key_cache ),
539+ weak_ref_tensors (self .value_cache ),
540+ self .num_kv_heads ,
541+ self .num_heads ,
542+ self .scale ,
543+ weak_ref_tensors (attn_metadata .block_tables ),
544+ attn_metadata .seq_lens ,
545+ weak_ref_tensors (output ),
546+ ))
547+
548+ torch .npu .graph_task_group_begin (stream )
549+
550+ if self .torch_npu_check :
551+ torch_npu ._npu_paged_attention (
552+ query = query ,
553+ key_cache = self .key_cache ,
554+ value_cache = self .value_cache ,
555+ num_kv_heads = self .num_kv_heads ,
556+ num_heads = self .num_heads ,
557+ scale_value = self .scale ,
558+ block_table = attn_metadata .block_tables ,
559+ context_lens = attn_metadata .seq_lens ,
560+ out = output ,
561+ workspace = workspace )
562+ else :
563+ torch_npu ._npu_paged_attention (
437564 query = query ,
438565 key_cache = self .key_cache ,
439566 value_cache = self .value_cache ,
@@ -443,41 +570,27 @@ def _forward_decode_only(
443570 block_table = attn_metadata .block_tables ,
444571 context_lens = attn_metadata .seq_lens ,
445572 out = output )
446- update_graph_params_workspaces (num_tokens , workspace )
447-
448- # Handle graph capturing mode
449- stream = torch_npu .npu .current_stream ()
450-
451- event = torch .npu .ExternalEvent ()
452- event .wait (stream )
453- event .reset (stream )
454- graph_params .events [num_tokens ].append (event )
455- graph_params .attn_params [num_tokens ].append ((
456- weak_ref_tensors (query ),
457- weak_ref_tensors (self .key_cache ),
458- weak_ref_tensors (self .value_cache ),
459- self .num_kv_heads ,
460- self .num_heads ,
461- self .scale ,
462- weak_ref_tensors (attn_metadata .block_tables ),
463- attn_metadata .seq_lens ,
464- weak_ref_tensors (output ),
465- ))
466-
467- torch .npu .graph_task_group_begin (stream )
468-
469- if self .torch_npu_check :
470- torch_npu ._npu_paged_attention (
573+ else :
574+ if torch .version .cann .startswith ("8.3" ):
575+ num_block , block_size , _ , _ = self .key_cache .shape # type: ignore
576+ key = self .key_cache .view (
577+ num_block , block_size , - 1 )
578+ value = self .value_cache .view (
579+ num_block , block_size , - 1 )
580+ output , _ = torch_npu .npu_fused_infer_attention_score (
471581 query = query ,
472- key_cache = self .key_cache ,
473- value_cache = self .value_cache ,
474- num_kv_heads = self .num_kv_heads ,
475- num_heads = self .num_heads ,
476- scale_value = self .scale ,
582+ key = key ,
583+ value = value ,
477584 block_table = attn_metadata .block_tables ,
478- context_lens = attn_metadata .seq_lens ,
479- out = output ,
480- workspace = workspace )
585+ input_layout = "TND" ,
586+ block_size = block_size ,
587+ actual_seq_lengths = attn_metadata .query_start_loc_list ,
588+ actual_seq_lengths_kv = attn_metadata .seq_lens_list ,
589+ num_key_value_heads = self .num_kv_heads ,
590+ num_heads = self .num_heads ,
591+ scale = self .scale ,
592+ sparse_mode = 0
593+ )
481594 else :
482595 torch_npu ._npu_paged_attention (
483596 query = query ,
@@ -489,19 +602,6 @@ def _forward_decode_only(
489602 block_table = attn_metadata .block_tables ,
490603 context_lens = attn_metadata .seq_lens ,
491604 out = output )
492- handle = torch .npu .graph_task_group_end (stream )
493- graph_params .handles [num_tokens ].append (handle )
494- else :
495- torch_npu ._npu_paged_attention (
496- query = query ,
497- key_cache = self .key_cache ,
498- value_cache = self .value_cache ,
499- num_kv_heads = self .num_kv_heads ,
500- num_heads = self .num_heads ,
501- scale_value = self .scale ,
502- block_table = attn_metadata .block_tables ,
503- context_lens = attn_metadata .seq_lens ,
504- out = output )
505605 return output
506606
507607 def _forward_v1_style (
0 commit comments