 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          maybe_save_kv_layer_to_connector,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 
 from ..utils import weak_ref_tensors
 
 
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
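
`update_graph_params_workspaces` is imported above but defined elsewhere in `vllm_ascend.compilation.acl_graph`. Judging from the call sites below (the diff reads the cache back via `graph_params.workspaces.get(num_tokens)`), a minimal sketch could look like this; only the name and call shape come from the diff, the body is an assumption:

```python
def update_graph_params_workspaces(num_tokens: int, workspace) -> None:
    # Assumed sketch: GraphParams keeps a dict of workspaces keyed by the
    # captured token count, so captures/replays at the same size reuse it.
    graph_params = get_graph_params()
    graph_params.workspaces[num_tokens] = workspace
```
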
@@ -142,6 +146,9 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
 
     query_start_loc: torch.Tensor = None
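+    # Host-side mirrors of seq_lens / query_start_loc as Python ints, for
+    # kernels whose actual_seq_lengths arguments take lists, not tensors.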
+    seq_lens_list: List[int] = None
+
+    query_start_loc_list: List[int] = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
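
For concreteness, an illustrative batch (values not from this diff):

```python
# Three requests with query lengths [2, 1, 1] and KV (context) lengths
# [10, 7, 5]:
#   query_start_loc_cpu       -> tensor([0, 2, 3, 4])
#   query_start_loc_list [1:] -> [2, 3, 4]    # cumulative token end offsets
#   seq_lens_list             -> [10, 7, 5]   # per-request KV lengths
```
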
@@ -209,8 +216,6 @@ def build(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
                                                                        + 1]
-        query_start_loc = query_start_loc_cpu.to(self.device,
-                                                 non_blocking=True)
 
         if is_310p():
             if attn_state == AscendAttentionState.PrefillNoCache:
@@ -225,8 +230,10 @@ def build(
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
-            query_start_loc=query_start_loc,
+            query_start_loc=query_start_loc_cpu,
+            query_start_loc_list=query_start_loc_cpu[1:].cpu().int().tolist(),
             query_lens=query_lens,
+            seq_lens_list=seq_lens.cpu().int().tolist(),
             seq_lens=seq_lens,
             max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
@@ -394,51 +401,151 @@ def _forward_decode_only(
         else:
             graph_params = get_graph_params()
             forward_context: ForwardContext = get_forward_context()
-            num_tokens = query.shape[0]
             if forward_context.capturing:
-                stream = torch_npu.npu.current_stream()
-
-                event = torch.npu.ExternalEvent()
-                event.wait(stream)
-                event.reset(stream)
-                graph_params.events[num_tokens].append(event)
-
-                graph_params.attn_params[num_tokens].append((
-                    weak_ref_tensors(query),
-                    weak_ref_tensors(self.key_cache),
-                    weak_ref_tensors(self.value_cache),
-                    self.num_kv_heads,
-                    self.num_heads,
-                    self.scale,
-                    weak_ref_tensors(attn_metadata.block_tables),
-                    attn_metadata.seq_lens,
-                    weak_ref_tensors(output),
-                ))
-
-                torch.npu.graph_task_group_begin(stream)
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
-                handle = torch.npu.graph_task_group_end(stream)
-                graph_params.handles[num_tokens].append(handle)
+                if torch.version.cann.startswith("8.3"):
+                    # Prepare tensors for attention output
+                    query_start_loc = attn_metadata.query_start_loc_list
+                    seq_lens = attn_metadata.seq_lens_list
+                    num_tokens = query_start_loc[-1]
+                    query = query[:num_tokens]
+
+                    # Get workspace from cache or calculate it if not present.
+                    workspace = graph_params.workspaces.get(num_tokens)
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    softmax_lse = torch.empty(num_tokens,
+                                              dtype=query.dtype,
+                                              device=query.device)
+                    if workspace is None:
+                        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                            query=query,
+                            key=key,
+                            value=value,
+                            block_table=attn_metadata.block_tables,
+                            input_layout="TND",
+                            block_size=block_size,
+                            actual_seq_lengths=query_start_loc,
+                            actual_seq_lengths_kv=seq_lens,
+                            num_key_value_heads=self.num_kv_heads,
+                            num_heads=self.num_heads,
+                            sparse_mode=0,
+                            scale=self.scale)
+                        update_graph_params_workspaces(num_tokens, workspace)
+
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
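+                    # Stash weak references to the kernel arguments; the
+                    # replay path can then refresh them (e.g. new sequence
+                    # lengths) without re-capturing the graph.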
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(key),
+                        weak_ref_tensors(value),
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        block_size,
+                        seq_lens,
+                        query_start_loc,
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(output),
+                        weak_ref_tensors(softmax_lse),
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu.npu_fused_infer_attention_score.out(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=query_start_loc,
+                        actual_seq_lengths_kv=seq_lens,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0,
+                        workspace=workspace,
+                        out=[output, softmax_lse],
+                    )
+
+                    output = output.view(num_tokens, self.num_heads,
+                                         self.head_size)
+
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+                else:
+                    num_tokens = query.shape[0]
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(self.key_cache),
+                        weak_ref_tensors(self.value_cache),
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        attn_metadata.seq_lens,
+                        weak_ref_tensors(output),
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
         else:
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            if torch.version.cann.startswith("8.3"):
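+                # Flatten the paged KV cache to (num_blocks, block_size,
+                # heads * head_size), the layout the TND fused kernel expects.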
+                num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                key = self.key_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                value = self.value_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+
+                output, _ = torch_npu.npu_fused_infer_attention_score(
+                    query=query,
+                    key=key,
+                    value=value,
+                    block_table=attn_metadata.block_tables,
+                    input_layout="TND",
+                    block_size=block_size,
+                    actual_seq_lengths=attn_metadata.query_start_loc_list,
+                    actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                    num_key_value_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale=self.scale,
+                    sparse_mode=0)
+            else:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
         return output
 
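The events, handles, and parameter tuples recorded under `forward_context.capturing` imply a replay-side consumer in `acl_graph` that this diff does not show. A plausible shape, with the function name and signature assumed (the `graph_task_update_begin`/`graph_task_update_end` pairing mirrors the capture-time task-group calls above):

```python
def update_attn_params(update_stream, graph_params, num_tokens: int):
    # Hypothetical sketch, not part of this diff.
    for idx, (query, key, value, block_tables, block_size, seq_lens,
              query_start_loc, num_kv_heads, num_heads, scale, output,
              softmax_lse) in enumerate(graph_params.attn_params[num_tokens]):
        handle = graph_params.handles[num_tokens][idx]
        event = graph_params.events[num_tokens][idx]
        with torch.npu.stream(update_stream):
            # Patch the captured task group with this step's lengths, then
            # signal the event the captured graph is waiting on.
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu.npu_fused_infer_attention_score.out(
                query=query, key=key, value=value,
                block_table=block_tables, input_layout="TND",
                block_size=block_size,
                actual_seq_lengths=query_start_loc,
                actual_seq_lengths_kv=seq_lens,
                num_key_value_heads=num_kv_heads, num_heads=num_heads,
                scale=scale, sparse_mode=0,
                workspace=graph_params.workspaces.get(num_tokens),
                out=[output, softmax_lse])
            torch.npu.graph_task_update_end(update_stream)
        event.record(update_stream)
```
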
     def _forward_v1_style(