 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          maybe_save_kv_layer_to_connector,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 
 from ..utils import weak_ref_tensors
 
-
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
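The second name added to this import, update_graph_params_workspaces, lives in vllm_ascend/compilation/acl_graph.py and is not shown in this diff. A minimal sketch of what it is assumed to do, based on how it is called below (cache a fused-attention workspace per captured token count):

    # Sketch only: the real helper may differ. Assumes graph_params exposes a
    # `workspaces` dict keyed by the number of tokens in the captured graph.
    def update_graph_params_workspaces(num_tokens: int, workspace) -> None:
        graph_params = get_graph_params()
        # Reuse the workspace for later captures/replays of the same size
        # instead of re-querying the max workspace each time.
        graph_params.workspaces[num_tokens] = workspace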
@@ -142,6 +142,9 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
 
     query_start_loc: torch.Tensor = None
+    seq_lens_list: List[int] = None
+
+    query_start_loc_list: List[int] = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
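The two new fields mirror query_start_loc and seq_lens as host-side Python lists: the fused kernel used below (npu_fused_infer_attention_score) takes actual_seq_lengths and actual_seq_lengths_kv as plain int lists rather than device tensors, so the metadata builder precomputes them once per batch. Illustrative values:

    # query_start_loc (cumulative, length num_reqs + 1): [0, 3, 8, 10]
    # query_start_loc_list drops the leading zero:       [3, 8, 10]
    # seq_lens_list is seq_lens as host ints:            [7, 9, 4]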
@@ -209,8 +212,6 @@ def build(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
                                                                        + 1]
-        query_start_loc = query_start_loc_cpu.to(self.device,
-                                                 non_blocking=True)
 
         if is_310p():
             if attn_state == AscendAttentionState.PrefillNoCache:
@@ -225,8 +226,10 @@
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
-            query_start_loc=query_start_loc,
+            query_start_loc=query_start_loc_cpu,
+            query_start_loc_list=query_start_loc_cpu[1:].cpu().int().tolist(),
             query_lens=query_lens,
+            seq_lens_list=seq_lens.cpu().int().tolist(),
             seq_lens=seq_lens,
             max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
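With the device copy removed in the previous hunk, query_start_loc now carries the CPU tensor, and the list variants are derived from it. A worked example of the construction, using hypothetical values for a batch of three requests:

    import torch

    # Three requests with 3, 5 and 2 new tokens; context lengths 7, 9 and 4.
    query_start_loc_cpu = torch.tensor([0, 3, 8, 10])
    seq_lens = torch.tensor([7, 9, 4])

    query_start_loc_list = query_start_loc_cpu[1:].cpu().int().tolist()  # [3, 8, 10]
    seq_lens_list = seq_lens.cpu().int().tolist()                        # [7, 9, 4]

    # The last entry doubles as the total token count (num_tokens below).
    assert query_start_loc_list[-1] == 10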
@@ -394,51 +397,151 @@ def _forward_decode_only(
         else:
             graph_params = get_graph_params()
             forward_context: ForwardContext = get_forward_context()
-            num_tokens = query.shape[0]
             if forward_context.capturing:
-                stream = torch_npu.npu.current_stream()
-
-                event = torch.npu.ExternalEvent()
-                event.wait(stream)
-                event.reset(stream)
-                graph_params.events[num_tokens].append(event)
-
-                graph_params.attn_params[num_tokens].append((
-                    weak_ref_tensors(query),
-                    weak_ref_tensors(self.key_cache),
-                    weak_ref_tensors(self.value_cache),
-                    self.num_kv_heads,
-                    self.num_heads,
-                    self.scale,
-                    weak_ref_tensors(attn_metadata.block_tables),
-                    attn_metadata.seq_lens,
-                    weak_ref_tensors(output),
-                ))
-
-                torch.npu.graph_task_group_begin(stream)
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
-                handle = torch.npu.graph_task_group_end(stream)
-                graph_params.handles[num_tokens].append(handle)
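+                # CANN 8.3+ ships npu_fused_infer_attention_score, which can be
+                # captured with an externally managed workspace; older toolkits
+                # keep the original paged-attention capture path in the else
+                # branch below.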
+                if torch.version.cann.startswith("8.3"):
+                    # Pull host-side lengths and trim query to the actual
+                    # token count.
+                    query_start_loc = attn_metadata.query_start_loc_list
+                    seq_lens = attn_metadata.seq_lens_list
+                    num_tokens = query_start_loc[-1]
+                    query = query[:num_tokens]
+
+                    # Get workspace from cache or calculate it if not present.
+                    workspace = graph_params.workspaces.get(num_tokens)
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
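+                    # Preallocate the op's second output (per-token
+                    # log-sum-exp); the .out overload writes into it.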
+                    softmax_lse = torch.empty(num_tokens,
+                                              dtype=query.dtype,
+                                              device=query.device)
+                    if workspace is None:
+                        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                            query=query,
+                            key=key,
+                            value=value,
+                            block_table=attn_metadata.block_tables,
+                            input_layout="TND",
+                            block_size=block_size,
+                            actual_seq_lengths=query_start_loc,
+                            actual_seq_lengths_kv=seq_lens,
+                            num_key_value_heads=self.num_kv_heads,
+                            num_heads=self.num_heads,
+                            sparse_mode=0,
+                            scale=self.scale)
+                        update_graph_params_workspaces(num_tokens, workspace)
+
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
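+                    # Stash weak refs to every kernel argument so the
+                    # replay-time update hook can patch this captured call
+                    # with fresh buffers.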
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(key),
+                        weak_ref_tensors(value),
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        block_size,
+                        seq_lens,
+                        query_start_loc,
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(output),
+                        weak_ref_tensors(softmax_lse),
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu.npu_fused_infer_attention_score.out(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=query_start_loc,
+                        actual_seq_lengths_kv=seq_lens,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0,
+                        workspace=workspace,
+                        out=[output, softmax_lse],
+                    )
+
+                    output = output.view(num_tokens, self.num_heads,
+                                         self.head_size)
+
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+                else:
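+                    # Pre-8.3 CANN: original paged-attention capture path.
+                    # num_tokens is recomputed here because the assignment
+                    # before the capture branch was removed in this change
+                    # (assumes no earlier definition in this scope).
+                    num_tokens = query.shape[0]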
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(self.key_cache),
+                        weak_ref_tensors(self.value_cache),
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        attn_metadata.seq_lens,
+                        weak_ref_tensors(output),
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
             else:
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
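+                # Eager (non-capturing) decode takes the same CANN version
+                # split as the capture path above.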
+                if torch.version.cann.startswith("8.3"):
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+
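+                    # The fused op returns (attention_out, softmax_lse); the
+                    # log-sum-exp is unused on the eager path.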
+                    output, _ = torch_npu.npu_fused_infer_attention_score(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=attn_metadata.query_start_loc_list,
+                        actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0)
+                else:
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
         return output
 
     def _forward_v1_style(