 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 
+from ..utils import weak_ref_tensors
 
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
@@ -140,6 +141,9 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
 
     query_start_loc: torch.Tensor = None
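+    # Host-side Python-list copies of seq_lens / query_start_loc, built in
+    # build() below via .cpu().int().tolist() for ops that take list arguments.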
+    seq_lens_list: List[int] = None
+
+    query_start_loc_list: List[int] = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
@@ -207,8 +211,6 @@ def build(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
                                                                        + 1]
-        query_start_loc = query_start_loc_cpu.to(self.device,
-                                                 non_blocking=True)
 
         if is_310p():
             if attn_state == AscendAttentionState.PrefillNoCache:
@@ -223,8 +225,10 @@ def build(
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
-            query_start_loc=query_start_loc,
+            query_start_loc=query_start_loc_cpu,
+            query_start_loc_list=query_start_loc_cpu[1:].cpu().int().tolist(),
             query_lens=query_lens,
+            seq_lens_list=seq_lens.cpu().int().tolist(),
             seq_lens=seq_lens,
             max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
@@ -391,51 +395,151 @@ def _forward_decode_only(
         else:
             graph_params = get_graph_params()
             forward_context: ForwardContext = get_forward_context()
-            num_tokens = query.shape[0]
             if forward_context.capturing:
-                stream = torch_npu.npu.current_stream()
-
-                event = torch.npu.ExternalEvent()
-                event.wait(stream)
-                event.reset(stream)
-                graph_params.events[num_tokens].append(event)
-
-                graph_params.attn_params[num_tokens].append((
-                    query,
-                    self.key_cache,
-                    self.value_cache,
-                    self.num_kv_heads,
-                    self.num_heads,
-                    self.scale,
-                    attn_metadata.block_tables,
-                    attn_metadata.seq_lens,
-                    output,
-                ))
-
-                torch.npu.graph_task_group_begin(stream)
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
-                handle = torch.npu.graph_task_group_end(stream)
-                graph_params.handles[num_tokens].append(handle)
+                if torch.version.cann.startswith("8.3"):
+                    # Prepare tensors for attention output
+                    query_start_loc = attn_metadata.query_start_loc_list
+                    seq_lens = attn_metadata.seq_lens_list
+                    num_tokens = query_start_loc[-1]
+                    query = query[:num_tokens]
+
+                    # Get workspace from cache or calculate it if not present.
+                    workspace = graph_params.workspaces.get(num_tokens)
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    softmax_lse = torch.empty(num_tokens,
+                                              dtype=query.dtype,
+                                              device=query.device)
+                    if workspace is None:
+                        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                            query=query,
+                            key=key,
+                            value=value,
+                            block_table=attn_metadata.block_tables,
+                            input_layout="TND",
+                            block_size=block_size,
+                            actual_seq_lengths=query_start_loc,
+                            actual_seq_lengths_kv=seq_lens,
+                            num_key_value_heads=self.num_kv_heads,
+                            num_heads=self.num_heads,
+                            sparse_mode=0,
+                            scale=self.scale)
+                        graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
+
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
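+                    # Record weak references to this call's arguments, keyed by
+                    # token count, so later updates to the captured graph can
+                    # reuse them without keeping the tensors alive.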
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(key),
+                        weak_ref_tensors(value),
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        block_size,
+                        seq_lens,
+                        query_start_loc,
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(output),
+                        weak_ref_tensors(softmax_lse),
+                    ))
+
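+                    # Launch the fused attention kernel inside a graph task
+                    # group and record its handle for this token count.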
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu.npu_fused_infer_attention_score.out(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=query_start_loc,
+                        actual_seq_lengths_kv=seq_lens,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0,
+                        workspace=workspace,
+                        out=[output, softmax_lse],
+                    )
+
+                    output = output.view(num_tokens, self.num_heads,
+                                         self.head_size)
+
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+                else:
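+                    # Older CANN: capture the legacy paged-attention kernel.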
+                    num_tokens = query.shape[0]
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(self.key_cache),
+                        weak_ref_tensors(self.value_cache),
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        attn_metadata.seq_lens,
+                        weak_ref_tensors(output),
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
             else:
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
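+                # Eager (non-capturing) path: choose the kernel by CANN version.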
+                if torch.version.cann.startswith("8.3"):
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+
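+                    # The fused kernel returns (attn_output, softmax_lse);
+                    # only the attention output is needed here.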
+                    output, _ = torch_npu.npu_fused_infer_attention_score(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=attn_metadata.query_start_loc_list,
+                        actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0)
+                else:
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
         return output
 
     def _forward_v1_style(