 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.nn.functional as F
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -318,6 +317,18 @@ def build(
         pcp_metadata = None
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         if common_long_seq_metadata is not None:
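+            # The TND-layout attention calls below consume cumulative
+            # (prefix-summed) sequence-length lists, so convert the
+            # per-request lengths up front when PCP is enabled.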
+            attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
+            head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
+            tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
+            pcp_size = get_prefill_context_model_parallel_world_size(
+            ) if prefill_context_parallel_enable() else 1
+            if pcp_size > 1:
+                attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0],
+                                                 dim=0).tolist()
+                head_attn_nomask_seqlens = torch.cumsum(
+                    head_attn_nomask_seqlens[1], dim=0).tolist()
+                tail_attn_nomask_seqlens = torch.cumsum(
+                    tail_attn_nomask_seqlens[1], dim=0).tolist()
             pcp_metadata = AscendPCPMetadata(
                 q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
                 q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
@@ -329,12 +340,9 @@ def build(
                 kv_with_q_tail_nomask_idx_tensor,
                 kv_with_q_tail_mask_idx=common_long_seq_metadata.
                 kv_with_q_tail_mask_idx_tensor,
-                attn_mask_seqlens=common_long_seq_metadata.
-                attn_mask_seqlens,
-                head_attn_nomask_seqlens=common_long_seq_metadata.
-                head_attn_nomask_seqlens,
-                tail_attn_nomask_seqlens=common_long_seq_metadata.
-                tail_attn_nomask_seqlens,
+                attn_mask_seqlens=attn_mask_seqlens,
+                head_attn_nomask_seqlens=head_attn_nomask_seqlens,
+                tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
                 q_full_idx=common_long_seq_metadata.q_full_idx,
                 pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
         prefill_metadata = AscendMetadataForPrefill(
@@ -726,28 +734,6 @@ def _forward_v1_style(
             out=output)
         return output
 
-    def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
-                         lengths: List[int]) -> torch.Tensor:
-        max_len = max(lengths)
-        splits = torch.split(tensor_tnd, lengths, dim=0)
-
-        padded = []
-        for s in splits:
-            pad_len = max_len - s.shape[0]
-            s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
-            padded.append(s_pad)
-
-        tensor_bsnd = torch.stack(padded, dim=0)
-        return tensor_bsnd
-
-    def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
-                           lengths: List[int]) -> torch.Tensor:
-        slices = []
-        for i, length in enumerate(lengths):
-            slices.append(tensor_bsnd[i, :length])
-        tensor_tnd = torch.cat(slices, dim=0)
-        return tensor_tnd
-
     def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         q_seqlens: List[int],
                                         k_nomask: torch.Tensor,
@@ -757,17 +743,15 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         v_mask: torch.Tensor,
                                         kv_seqlens_mask: List[int],
                                         mask: torch.Tensor) -> torch.Tensor:
-        q = self._pack_tnd_2_bsnd(q, q_seqlens)
-
         # nomask Attention
         if k_nomask is not None:
             attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
                 q,
-                self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
-                self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
+                k_nomask,
+                v_nomask,
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
-                input_layout="BSND",
+                input_layout="TND",
                 atten_mask=None,
                 scale=self.scale,
                 sparse_mode=0,
@@ -776,38 +760,46 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                 softmax_lse_flag=True,
                 actual_seq_lengths_kv=kv_seqlens_nomask,
                 actual_seq_lengths=q_seqlens)
-            attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
-                                                      q_seqlens)
-            # (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
-            attn_lse_nomask = self._unpack_bsnd_2_tnd(
-                attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)
 
         # mask Attention
         attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
             q,
-            self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
-            self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
+            k_mask,
+            v_mask,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
+            input_layout="TND",
             atten_mask=mask,
             scale=self.scale,
-            sparse_mode=0,
+            sparse_mode=3,
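+            # NOTE: sparse_mode=3 is assumed to select the right-down-aligned
+            # causal masking mode of the fused attention op.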
             antiquant_mode=0,
             antiquant_scale=None,
             softmax_lse_flag=True,
             actual_seq_lengths_kv=kv_seqlens_mask,
             actual_seq_lengths=q_seqlens)
-        attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
-        attn_lse_mask = self._unpack_bsnd_2_tnd(
-            attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)
 
         # update
         output = attn_out_mask
         if k_nomask is not None:
-            output, _ = self._update_out_and_lse(
-                torch.stack([attn_out_nomask, attn_out_mask], dim=0),
-                torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
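+            # Merge the nomask/mask partial results with npu_attention_update,
+            # which takes over the log-sum-exp combine
+            # (out = sum_i exp(lse_i - lse_final) * out_i) previously done in
+            # Python by _update_out_and_lse.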
+            T = attn_out_mask.shape[0]
+            N = attn_out_mask.shape[1]
+            D = attn_out_mask.shape[2]
+
+            attn_out_mask, attn_lse_mask = self._out_lse_reshape(
+                attn_out_mask, attn_lse_mask)
+            attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
+                attn_out_nomask, attn_lse_nomask)
+            attn_out_mask = attn_out_mask.to(torch.float32)
+            attn_out_nomask = attn_out_nomask.to(torch.float32)
+            attn_lse_mask = attn_lse_mask.to(torch.float32)
+            attn_lse_nomask = attn_lse_nomask.to(torch.float32)
+
+            attn_output = [attn_out_nomask, attn_out_mask]
+            attn_lse = [attn_lse_nomask, attn_lse_mask]
+            update_type = 0
+            output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
+                                                       update_type)
+            output = output.view(T, N, D)
 
         return output
 
@@ -832,29 +824,29 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
         # 1. Attention calculation in the first half of Q in load balancing
         output_head = self._attention_with_nomask_and_mask(
             q=torch.index_select(query, 0, q_head_idx),
-            q_seqlens=attn_mask_seqlens[0].tolist(),
+            q_seqlens=attn_mask_seqlens,
             k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
             if self.pcp_rank > 0 else None,
             v_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx)
             if self.pcp_rank > 0 else None,
-            kv_seqlens_nomask=head_attn_nomask_seqlens[1].tolist(),
+            kv_seqlens_nomask=head_attn_nomask_seqlens,
             k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
             v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
-            kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
+            kv_seqlens_mask=attn_mask_seqlens,
             mask=mask)
 
         # 2. the Attention calculation in the latter half of Q in load balancing
         # pcp_rank0: Q3*KV0~KV2 + Q3*KV3
         # pcp_rank1: Q2*KV0~KV1 + Q2*KV2
         output_tail = self._attention_with_nomask_and_mask(
             q=torch.index_select(query, 0, q_tail_idx),
-            q_seqlens=attn_mask_seqlens[0].tolist(),
+            q_seqlens=attn_mask_seqlens,
             k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
             v_nomask=torch.index_select(value, 0, kv_with_q_tail_nomask_idx),
-            kv_seqlens_nomask=tail_attn_nomask_seqlens[1].tolist(),
+            kv_seqlens_nomask=tail_attn_nomask_seqlens,
             k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
             v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
-            kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
+            kv_seqlens_mask=attn_mask_seqlens,
             mask=mask)
 
         # 3. Combine the output of the first half and second half.
@@ -863,20 +855,36 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
             torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
         return output
 
-    def _update_out_and_lse(self, out_list: torch.Tensor,
-                            lse_list: torch.Tensor) -> torch.Tensor:
-        """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
-        Args:
-            out_list: shape = [N, batch_size, num_heads, head_size]
-            lse_list: shape = [N, batch_size, num_heads, 1]
-        Returns:
-            out_final: shape = [batch_size, num_heads, head_size]
-            lse_final: shape = [batch_size, num_heads, 1]
-        """
-        lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
-        out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
-                              dim=0)
-        return out_final, lse_final
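+    # Flattens a (T, N, D) attention output to (T*N, D) and its (T, N, 1) LSE
+    # to a 1-D (T*N,) tensor, the shapes handed to npu_attention_update.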
+    def _out_lse_reshape(self, attn_out: torch.Tensor,
+                         attn_lse: torch.Tensor) -> torch.Tensor:
+        attn_out = attn_out.contiguous().view(
+            attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
+        attn_lse = attn_lse.contiguous().view(
+            attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
+        return attn_out, attn_lse
+
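+    # Merges a list of [batch, num_heads, head_size + 1] tensors, each holding
+    # an attention output with its LSE concatenated in the last channel, into
+    # a single [batch, num_heads, head_size] output via npu_attention_update.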
+    def _npu_attention_update(
+            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
+        update_type = 0
+
+        batch = attn_out_lse_list[0].shape[0]
+        num_heads = attn_out_lse_list[0].shape[1]
+        head_dim = attn_out_lse_list[0].shape[2] - 1
+
+        attn_out_split_cp = []
+        attn_lse_split_cp = []
+
+        for i in attn_out_lse_list:
+            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
+                *torch.split(i, [self.head_size, 1], dim=-1))
+            attn_out_split_cp.append(attn_out_allgather)
+            attn_lse_split_cp.append(attn_lse_allgather)
+
+        attn_out, attn_lse = torch_npu.npu_attention_update(
+            attn_lse_split_cp, attn_out_split_cp, update_type)
+        attn_out = attn_out.view(batch, num_heads, head_dim)
+
+        return attn_out
 
     def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                 attn_metadata: AscendMetadata) -> torch.Tensor:
@@ -889,9 +897,6 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
         else:
             num_heads = self.num_heads
 
-        # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
-        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -902,7 +907,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "BSND",
+            "TND",
             'atten_mask':
             None,
             'scale':
@@ -917,9 +922,11 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             attn_metadata.block_tables,
             'block_size':
             self.key_cache.shape[1],
-            "actual_seq_lengths_kv":
-            attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
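+            # With the TND layout the fused op takes explicit Q and KV sequence
+            # lengths, sliced to the decode portion of the batch.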
+            'actual_seq_lengths_kv':
+            attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
+            'actual_seq_lengths':
+            attn_metadata.actual_seq_lengths_q[:attn_metadata.
+                                               num_decode_tokens]
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
@@ -935,16 +942,16 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    q_nope, k_nope, value, **common_kwargs)
+                    query, k_nope, value, **common_kwargs)
                 update_graph_params_workspaces(num_tokens,
                                                weak_ref_tensors(workspace))
-            attn_out = torch.empty_like(q_nope)
+            attn_out = torch.empty_like(query)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
-                                   device=q_nope.device)
+                                   device=query.device)
 
             graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                (weak_ref_tensors(query), weak_ref_tensors(k_nope),
                  weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
                  self.key_cache.shape[1], attn_metadata.decode_meta.
@@ -954,7 +961,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                  self.pcp_rank, self.dcp_rank, self.dcp_size))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
-                q_nope,
+                query,
                 k_nope,
                 value,
                 **common_kwargs,
@@ -964,14 +971,12 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                q_nope, k_nope, value, **common_kwargs)
+                query, k_nope, value, **common_kwargs)
 
-        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
-                                 attn_out.shape[3])
-        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
+        attn_out_lse_list = []
+        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
+        attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
         if self.dcp_size > 1:
-            # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
             # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
             attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
             attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -980,35 +985,28 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                    group=self.dcp_group)
             # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
             attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
-            attn_out_lse_split_on_seq = list(
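+            # Split the all-to-all result along the head dim into one partial
+            # out/lse per DCP rank for the LSE merge below; when PCP is also
+            # enabled, keep the full tensor for the all-gather.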
+            if self.pcp_size > 1:
+                attn_out_lse = attn_out_lse_all2all.contiguous()
+            attn_out_lse_list = list(
                 torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
 
-            attn_out_lse_split_dcp = torch.stack(
-                attn_out_lse_split_on_seq,
-                dim=0)  # [dcp, batch_size, num_heads, head_size+1]
-            # Update out&lse
-            attn_out_split_dcp, attn_lse_split_dcp = torch.split(
-                attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
-            attn_out, attn_lse = self._update_out_and_lse(
-                attn_out_split_dcp, attn_lse_split_dcp)
         if self.pcp_size > 1:
-            # 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
-            # 3. AllGather out&lse within CP group
+            # AllGather out&lse within CP group
             attn_out_lse_list = [
                 torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
             ]
             dist.all_gather(attn_out_lse_list,
                             attn_out_lse,
                             group=self.pcp_group)
-            # 4. Update out&lse
-            attn_out_lse_allgather = torch.stack(
-                attn_out_lse_list,
-                dim=0)  # [pcp, batch_size, num_heads, head_size+1]
-            attn_out_allgather, attn_lse_allgather = torch.split(
-                attn_out_lse_allgather, [self.head_size, 1], dim=-1)
-            attn_out, _ = self._update_out_and_lse(attn_out_allgather,
-                                                   attn_lse_allgather)
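+        # When both PCP and DCP are enabled, split every gathered PCP result
+        # into its dcp_size head chunks so all pcp_size * dcp_size partial
+        # results are merged in a single npu_attention_update call.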
+        if self.dcp_size > 1 and self.pcp_size > 1:
+            attn_out_lse_list_pcp_dcp = []
+            for s in attn_out_lse_list:
+                attn_out_lse_list_split = list(
+                    torch.chunk(s, self.dcp_size, dim=1))
+                attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
+            attn_out_lse_list = attn_out_lse_list_pcp_dcp
+        # Update out&lse
+        attn_out = self._npu_attention_update(attn_out_lse_list)
         return attn_out
 
     def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,