4343
4444from ..utils import weak_ref_tensors
4545
46-
4746class AscendAttentionBackend (AttentionBackend ):
4847 accept_output_buffer : bool = True
4948
@@ -144,6 +143,9 @@ class AscendMetadata:
144143 seq_lens : torch .Tensor = None
145144
146145 query_start_loc : torch .Tensor = None
146+ seq_lens_list : List [int ] = None
147+
148+ query_start_loc_list : List [int ] = None
147149 query_lens : torch .Tensor = None
148150 # Maximum query length in the batch (None for decoding).
149151 max_query_len : Optional [int ] = None
@@ -211,8 +213,6 @@ def build(
211213 query_start_loc_cpu = common_attn_metadata .query_start_loc_cpu [:
212214 num_reqs
213215 + 1 ]
214- query_start_loc = query_start_loc_cpu .to (self .device ,
215- non_blocking = True )
216216
217217 if is_310p ():
218218 if attn_state == AscendAttentionState .PrefillNoCache :
@@ -227,8 +227,10 @@ def build(
227227 attn_metadata = AscendMetadata (
228228 num_actual_tokens = num_actual_tokens ,
229229 block_tables = block_table ,
230- query_start_loc = query_start_loc ,
230+ query_start_loc = query_start_loc_cpu ,
231+ query_start_loc_list = query_start_loc_cpu [1 :].cpu ().int ().tolist (),
231232 query_lens = query_lens ,
233+ seq_lens_list = seq_lens .cpu ().int ().tolist (),
232234 seq_lens = seq_lens ,
233235 max_query_len = common_attn_metadata .max_query_len ,
234236 slot_mapping = slot_mapping ,
@@ -397,13 +399,136 @@ def _forward_decode_only(
397399 else :
398400 graph_params = get_graph_params ()
399401 forward_context : ForwardContext = get_forward_context ()
400- num_tokens = query .shape [0 ]
401402 if forward_context .capturing :
402- if self .torch_npu_check :
403+ if torch .version .cann .startswith ("8.3" ):
404+ # Prepare tensors for attention output
405+ query_start_loc = attn_metadata .query_start_loc_list
406+ seq_lens = attn_metadata .seq_lens_list
407+ num_tokens = query_start_loc [- 1 ]
408+ query = query [:num_tokens ]
409+
403410 # Get workspace from cache or calculate it if not present.
404411 workspace = graph_params .workspaces .get (num_tokens )
412+ num_block , block_size , _ , _ = self .key_cache .shape # type: ignore
413+ key = self .key_cache .view ( # type: ignore
414+ num_block , block_size , - 1 )
415+ value = self .value_cache .view ( # type: ignore
416+ num_block , block_size , - 1 )
417+ softmax_lse = torch .empty (num_tokens ,
418+ dtype = query .dtype ,
419+ device = query .device )
405420 if workspace is None :
406- workspace = torch_npu ._npu_paged_attention_get_workspace (
421+ workspace = torch_npu ._npu_fused_infer_attention_score_get_max_workspace (
422+ query = query ,
423+ key = key ,
424+ value = value ,
425+ block_table = attn_metadata .block_tables ,
426+ input_layout = "TND" ,
427+ block_size = block_size ,
428+ actual_seq_lengths = query_start_loc ,
429+ actual_seq_lengths_kv = seq_lens ,
430+ num_key_value_heads = self .num_kv_heads ,
431+ num_heads = self .num_heads ,
432+ sparse_mode = 0 ,
433+ scale = self .scale ,)
434+ update_graph_params_workspaces (num_tokens , workspace )
435+
436+ # Handle graph capturing mode
437+ stream = torch_npu .npu .current_stream ()
438+
439+ event = torch .npu .ExternalEvent ()
440+ event .wait (stream )
441+ event .reset (stream )
442+ graph_params .events [num_tokens ].append (event )
443+ graph_params .attn_params [num_tokens ].append ((
444+ weak_ref_tensors (query ),
445+ weak_ref_tensors (key ),
446+ weak_ref_tensors (value ),
447+ weak_ref_tensors (attn_metadata .block_tables ),
448+ block_size ,
449+ seq_lens ,
450+ query_start_loc ,
451+ self .num_kv_heads ,
452+ self .num_heads ,
453+ self .scale ,
454+ weak_ref_tensors (output ),
455+ weak_ref_tensors (softmax_lse )
456+ ))
457+
458+ torch .npu .graph_task_group_begin (stream )
459+ torch_npu .npu_fused_infer_attention_score .out (
460+ query = query ,
461+ key = key ,
462+ value = value ,
463+ block_table = attn_metadata .block_tables ,
464+ input_layout = "TND" ,
465+ block_size = block_size ,
466+ actual_seq_lengths = query_start_loc ,
467+ actual_seq_lengths_kv = seq_lens ,
468+ num_key_value_heads = self .num_kv_heads ,
469+ num_heads = self .num_heads ,
470+ scale = self .scale ,
471+ sparse_mode = 0 ,
472+ workspace = workspace ,
473+ out = [output , softmax_lse ],
474+ )
475+
476+ output = output .view (num_tokens , self .num_heads ,
477+ self .head_size )
478+
479+ handle = torch .npu .graph_task_group_end (stream )
480+ graph_params .handles [num_tokens ].append (handle )
481+ else :
482+ if self .torch_npu_check :
483+ # Get workspace from cache or calculate it if not present.
484+ workspace = graph_params .workspaces .get (num_tokens )
485+ if workspace is None :
486+ workspace = torch_npu ._npu_paged_attention_get_workspace (
487+ query = query ,
488+ key_cache = self .key_cache ,
489+ value_cache = self .value_cache ,
490+ num_kv_heads = self .num_kv_heads ,
491+ num_heads = self .num_heads ,
492+ scale_value = self .scale ,
493+ block_table = attn_metadata .block_tables ,
494+ context_lens = attn_metadata .seq_lens ,
495+ out = output )
496+ update_graph_params_workspaces (num_tokens , workspace )
497+ # Handle graph capturing mode
498+ stream = torch_npu .npu .current_stream ()
499+
500+ event = torch .npu .ExternalEvent ()
501+ event .wait (stream )
502+ event .reset (stream )
503+ graph_params .events [num_tokens ].append (event )
504+ graph_params .attn_params [num_tokens ].append ((
505+ weak_ref_tensors (query ),
506+ weak_ref_tensors (self .key_cache ),
507+ weak_ref_tensors (self .value_cache ),
508+ self .num_kv_heads ,
509+ self .num_heads ,
510+ self .scale ,
511+ weak_ref_tensors (attn_metadata .block_tables ),
512+ attn_metadata .seq_lens ,
513+ weak_ref_tensors (output ),
514+ ))
515+
516+ torch .npu .graph_task_group_begin (stream )
517+
518+ if self .torch_npu_check :
519+ torch_npu ._npu_paged_attention (
520+ query = query ,
521+ key_cache = self .key_cache ,
522+ value_cache = self .value_cache ,
523+ num_kv_heads = self .num_kv_heads ,
524+ num_heads = self .num_heads ,
525+ scale_value = self .scale ,
526+ block_table = attn_metadata .block_tables ,
527+ context_lens = attn_metadata .seq_lens ,
528+ out = output ,
529+ workspace = workspace )
530+ else :
531+ torch_npu ._npu_paged_attention (
407532 query = query ,
408533 key_cache = self .key_cache ,
409534 value_cache = self .value_cache ,
@@ -413,41 +538,27 @@ def _forward_decode_only(
413538 block_table = attn_metadata .block_tables ,
414539 context_lens = attn_metadata .seq_lens ,
415540 out = output )
416- update_graph_params_workspaces (num_tokens , workspace )
417-
418- # Handle graph capturing mode
419- stream = torch_npu .npu .current_stream ()
420-
421- event = torch .npu .ExternalEvent ()
422- event .wait (stream )
423- event .reset (stream )
424- graph_params .events [num_tokens ].append (event )
425- graph_params .attn_params [num_tokens ].append ((
426- weak_ref_tensors (query ),
427- weak_ref_tensors (self .key_cache ),
428- weak_ref_tensors (self .value_cache ),
429- self .num_kv_heads ,
430- self .num_heads ,
431- self .scale ,
432- weak_ref_tensors (attn_metadata .block_tables ),
433- attn_metadata .seq_lens ,
434- weak_ref_tensors (output ),
435- ))
436-
437- torch .npu .graph_task_group_begin (stream )
438-
439- if self .torch_npu_check :
440- torch_npu ._npu_paged_attention (
541+ else :
542+ if torch .version .cann .startswith ("8.3" ):
543+ num_block , block_size , _ , _ = self .key_cache .shape # type: ignore
544+ key = self .key_cache .view (
545+ num_block , block_size , - 1 )
546+ value = self .value_cache .view (
547+ num_block , block_size , - 1 )
548+ output , _ = torch_npu .npu_fused_infer_attention_score (
441549 query = query ,
442- key_cache = self .key_cache ,
443- value_cache = self .value_cache ,
444- num_kv_heads = self .num_kv_heads ,
445- num_heads = self .num_heads ,
446- scale_value = self .scale ,
550+ key = key ,
551+ value = value ,
447552 block_table = attn_metadata .block_tables ,
448- context_lens = attn_metadata .seq_lens ,
449- out = output ,
450- workspace = workspace )
553+ input_layout = "TND" ,
554+ block_size = block_size ,
555+ actual_seq_lengths = attn_metadata .query_start_loc_list ,
556+ actual_seq_lengths_kv = attn_metadata .seq_lens_list ,
557+ num_key_value_heads = self .num_kv_heads ,
558+ num_heads = self .num_heads ,
559+ scale = self .scale ,
560+ sparse_mode = 0
561+ )
451562 else :
452563 torch_npu ._npu_paged_attention (
453564 query = query ,
@@ -459,19 +570,6 @@ def _forward_decode_only(
459570 block_table = attn_metadata .block_tables ,
460571 context_lens = attn_metadata .seq_lens ,
461572 out = output )
462- handle = torch .npu .graph_task_group_end (stream )
463- graph_params .handles [num_tokens ].append (handle )
464- else :
465- torch_npu ._npu_paged_attention (
466- query = query ,
467- key_cache = self .key_cache ,
468- value_cache = self .value_cache ,
469- num_kv_heads = self .num_kv_heads ,
470- num_heads = self .num_heads ,
471- scale_value = self .scale ,
472- block_table = attn_metadata .block_tables ,
473- context_lens = attn_metadata .seq_lens ,
474- out = output )
475573 return output
476574
477575 def _forward_v1_style (
0 commit comments