                                update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d, nd_to_nz_spec,
+                               is_Ascend950, nd_to_nz_2d, nd_to_nz_spec,
                                prefill_context_parallel_enable,
                                weak_ref_tensors)
 
@@ -703,15 +703,29 @@ def _forward_prefill_no_cache(
             mask = torch_npu.npu_format_cast(mask.contiguous(),
                                              ACL_FORMAT_FRACTAL_NZ)
 
-        torch_npu._npu_flash_attention(query=query,
-                                       key=key,
-                                       value=value,
-                                       mask=mask,
-                                       seq_len=attn_metadata.seq_lens,
-                                       scale_value=self.scale,
-                                       num_heads=self.num_heads,
-                                       num_kv_heads=self.num_kv_heads,
-                                       out=output)
+        if is_Ascend950():
+            num_tokens = attn_metadata.query_start_loc[-1]
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query[:num_tokens],
+                key[:num_tokens],
+                value[:num_tokens],
+                atten_mask=mask.to(torch.bool),
+                actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
+                actual_seq_kvlen=attn_metadata.seq_lens.cumsum(0),
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="TND",
+                softmax_scale=self.scale)
+        else:
+            torch_npu._npu_flash_attention(query=query,
+                                           key=key,
+                                           value=value,
+                                           mask=mask,
+                                           seq_len=attn_metadata.seq_lens,
+                                           scale_value=self.scale,
+                                           num_heads=self.num_heads,
+                                           num_kv_heads=self.num_kv_heads,
+                                           out=output)
         assert output is not None
         return output[:num_tokens]
 
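The new TND path above feeds cumulative sequence lengths rather than per-request lengths. A minimal sketch of how those arguments relate, assuming toy per-request lengths (only the field names query_lens, seq_lens, and query_start_loc mirror the metadata used in the patch):

    import torch

    query_lens = torch.tensor([3, 5, 2])      # new tokens per request
    seq_lens = torch.tensor([3, 5, 2])        # total KV length per request (no cache hit here)

    actual_seq_qlen = query_lens.cumsum(0)    # tensor([3, 8, 10]): prefix sums over the packed token dim
    actual_seq_kvlen = seq_lens.cumsum(0)     # cumulative KV lengths, same convention
    num_tokens = int(actual_seq_qlen[-1])     # corresponds to attn_metadata.query_start_loc[-1] above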
@@ -729,6 +743,27 @@ def _forward_prefill_cache_hit(
         block_table = attn_metadata.block_tables[:batch_size, :]
         num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
 
+        if is_Ascend950():
+            compress_mask = compress_mask.to(torch.bool)
+            key = self.key_cache.transpose(1, 2)  # type: ignore
+            value = self.value_cache.transpose(1, 2)  # type: ignore
+            block_size = self.block_size
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query=query,
+                key=key,
+                value=value,
+                block_table=block_table,
+                atten_mask=compress_mask,
+                actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
+                actual_seq_kvlen=attn_metadata.seq_lens,
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                softmax_scale=self.scale,
+                spare_mode=2,
+                block_size=block_size,
+                input_layout="TND")
+            return output
+
         if block_size == 128:
             # TODO: The npu_fused_infer_attention_score op is planned to
             # be utilized in a wider range in upcoming versions.
@@ -777,18 +812,20 @@ def _forward_decode_only(
             # seq_lens_tensor needs to be transferred to the device for 310P.
             attn_metadata.seq_lens = \
                 attn_metadata.seq_lens.to(device=query.device)
+
+        batch_size = attn_metadata.seq_lens.shape[0]
+        block_size = 128
+        key = self.key_cache
+        value = self.value_cache
+        if self.key_cache is not None and self.value_cache is not None:
+            block_size = self.key_cache.shape[1]
+            key = self.key_cache.flatten(2, 3).contiguous()
+            value = self.value_cache.flatten(2, 3).contiguous()
+
         if self.sliding_window is not None and attn_metadata.seq_lens.shape[
                 0] == query.size(0):
-            batch_size = attn_metadata.seq_lens.shape[0]
-            block_size = 128
-            query = query.view(batch_size, 1, self.num_heads * self.head_size)
-            key = self.key_cache
-            value = self.value_cache
-            if self.key_cache is not None and self.value_cache is not None:
-                block_size = self.key_cache.shape[1]
-                key = self.key_cache.flatten(2, 3).contiguous()
-                value = self.value_cache.flatten(2, 3).contiguous()
-
+            query = query.view(batch_size, 1,
+                               self.num_heads * self.head_size)
             output, _ = torch_npu.npu_fused_infer_attention_score(
                 query,
                 key,
@@ -805,16 +842,33 @@ def _forward_decode_only(
 
             output = output.view(batch_size, self.num_heads, self.head_size)
         else:
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            if is_Ascend950():
+                query = query.view(batch_size, 1,
+                                   self.num_heads * self.head_size)
+                output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                    query=query,
+                    key=key,
+                    value=value,
+                    actual_seq_kvlen=attn_metadata.seq_lens,
+                    num_query_heads=self.num_heads,
+                    num_key_value_heads=self.num_kv_heads,
+                    block_table=attn_metadata.block_tables[:batch_size],
+                    block_size=block_size,
+                    softmax_scale=self.scale,
+                    input_layout="BSH")
+                output = output.view(batch_size, self.num_heads,
+                                     self.head_size)
+            else:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
         return output
 
     def _forward_v1_style(
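The decode hunk above hoists the cache flattening out of the sliding-window branch so both branches can consume a BSH-layout cache and a (batch, 1, hidden) query. A rough sketch of the shape manipulation involved, with illustrative shapes not taken from the patch:

    import torch

    num_blocks, block_size, num_kv_heads, head_dim = 8, 128, 2, 128
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim)

    # flatten(2, 3) merges the kv-head and head-dim axes, giving the
    # (num_blocks, block_size, hidden) layout that the "BSH" input_layout expects.
    key_bsh = key_cache.flatten(2, 3).contiguous()       # (8, 128, 256)

    batch_size, num_heads = 4, 8
    query = torch.zeros(batch_size, num_heads, head_dim)
    # one token per request at decode time: (batch, seq=1, num_heads * head_dim)
    query_bsh = query.view(batch_size, 1, num_heads * head_dim)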
@@ -862,7 +916,6 @@ def _forward_v1_style(
                 num_block, block_size, -1)
             value = self.value_cache.view(  # type: ignore
                 num_block, block_size, -1)
-
         output, _ = torch_npu.npu_fused_infer_attention_score(
             query=query,
             key=key,
@@ -1507,12 +1560,20 @@ def forward(
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size:self.pcp_size] \
                 if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_decode_tokens],
-                value=value[:num_decode_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slot_mapping)
+            if is_Ascend950():
+                num_tokens = slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)
 
         if has_prefill:
             if self.pcp_size > 1:
@@ -1526,18 +1587,27 @@ def forward(
                 key, value = all_kv.split([self.head_size, self.head_size],
                                           dim=-1)
 
-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
+            if is_Ascend950():
+                num_tokens = attn_metadata.slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=attn_metadata.slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size *
                             num_decode_tokens:attn_metadata.
                             num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])
 
         forward_context: ForwardContext = get_forward_context()
         if not forward_context.capturing:
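Both cache-write paths above consume the same slot_mapping convention: each token maps to a flat slot index into the paged KV cache, conventionally block_id * block_size + in-block offset. A toy illustration with invented values (none of these tensors come from the patch):

    import torch

    block_size = 128
    # Hypothetical placement: token i lives at (block_ids[i], block_offsets[i]).
    block_ids = torch.tensor([0, 0, 3])
    block_offsets = torch.tensor([126, 127, 0])
    slot_mapping = block_ids * block_size + block_offsets   # tensor([126, 127, 384])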