@@ -701,28 +701,6 @@ def _forward_v1_style(
                 out=output)
         return output
 
-    def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
-                         lengths: List[int]) -> torch.Tensor:
-        max_len = max(lengths)
-        splits = torch.split(tensor_tnd, lengths, dim=0)
-
-        padded = []
-        for s in splits:
-            pad_len = max_len - s.shape[0]
-            s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
-            padded.append(s_pad)
-
-        tensor_bsnd = torch.stack(padded, dim=0)
-        return tensor_bsnd
-
-    def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
-                           lengths: List[int]) -> torch.Tensor:
-        slices = []
-        for i, length in enumerate(lengths):
-            slices.append(tensor_bsnd[i, :length])
-        tensor_tnd = torch.cat(slices, dim=0)
-        return tensor_tnd
-
     def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         q_seqlens: List[int],
                                         k_nomask: torch.Tensor,
@@ -732,17 +710,15 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         v_mask: torch.Tensor,
                                         kv_seqlens_mask: List[int],
                                         mask: torch.Tensor) -> torch.Tensor:
-        q = self._pack_tnd_2_bsnd(q, q_seqlens)
-
         # nomask Attention
         if k_nomask is not None:
             attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
                 q,
-                self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
-                self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
+                k_nomask,
+                v_nomask,
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
-                input_layout="BSND",
+                input_layout="TND",
                 atten_mask=None,
                 scale=self.scale,
                 sparse_mode=0,
@@ -751,38 +727,46 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                 softmax_lse_flag=True,
                 actual_seq_lengths_kv=kv_seqlens_nomask,
                 actual_seq_lengths=q_seqlens)
-            attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
-                                                      q_seqlens)
-            # (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
-            attn_lse_nomask = self._unpack_bsnd_2_tnd(
-                attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)
 
         # mask Attention
         attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
             q,
-            self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
-            self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
+            k_mask,
+            v_mask,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
+            input_layout="TND",
             atten_mask=mask,
             scale=self.scale,
-            sparse_mode=0,
+            sparse_mode=3,
            antiquant_mode=0,
            antiquant_scale=None,
            softmax_lse_flag=True,
            actual_seq_lengths_kv=kv_seqlens_mask,
            actual_seq_lengths=q_seqlens)
-        attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
-        attn_lse_mask = self._unpack_bsnd_2_tnd(
-            attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)
 
         # update
         output = attn_out_mask
         if k_nomask is not None:
-            output, _ = self._update_out_and_lse(
-                torch.stack([attn_out_nomask, attn_out_mask], dim=0),
-                torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
+            T = attn_out_mask.shape[0]
+            N = attn_out_mask.shape[1]
+            D = attn_out_mask.shape[2]
+
+            attn_out_mask, attn_lse_mask = self._out_lse_reshape(
+                attn_out_mask, attn_lse_mask)
+            attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
+                attn_out_nomask, attn_lse_nomask)
+            attn_out_mask = attn_out_mask.to(torch.float32)
+            attn_out_nomask = attn_out_nomask.to(torch.float32)
+            attn_lse_mask = attn_lse_mask.to(torch.float32)
+            attn_lse_nomask = attn_lse_nomask.to(torch.float32)
+
+            attn_output = [attn_out_nomask, attn_out_mask]
+            attn_lse = [attn_lse_nomask, attn_lse_mask]
+            update_type = 0
+            output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
                                                        update_type)
+            output = output.view(T, N, D)
 
         return output
 
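The hunks above drop the pad-and-stack round trip: with `input_layout="TND"` the packed `[T, N, D]` tensors plus `actual_seq_lengths`/`actual_seq_lengths_kv` are handed to the kernel directly. For reference, this is a minimal standalone sketch (illustrative shapes only, not part of the diff) of the BSND padding round trip that the removed `_pack_tnd_2_bsnd`/`_unpack_bsnd_2_tnd` helpers performed:

```python
# Sketch of the removed BSND round trip, versus the packed TND layout now used.
import torch
import torch.nn.functional as F

lengths = [3, 1, 2]                                    # per-sequence token counts
num_heads, head_dim = 2, 4
tnd = torch.randn(sum(lengths), num_heads, head_dim)   # packed TND: [T, N, D]

# Old path: pad every sequence to max(lengths) and stack -> BSND [B, S, N, D]
max_len = max(lengths)
bsnd = torch.stack([
    F.pad(s, (0, 0, 0, 0, 0, max_len - s.shape[0]))
    for s in torch.split(tnd, lengths, dim=0)
])
assert bsnd.shape == (len(lengths), max_len, num_heads, head_dim)

# Unpacking drops the padding again and recovers the packed tensor unchanged.
unpacked = torch.cat([bsnd[i, :l] for i, l in enumerate(lengths)], dim=0)
assert torch.equal(unpacked, tnd)
```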
@@ -838,20 +822,36 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
             torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
         return output
 
-    def _update_out_and_lse(self, out_list: torch.Tensor,
-                            lse_list: torch.Tensor) -> torch.Tensor:
-        """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
-        Args:
-            out_list: shape = [N, batch_size, num_heads, head_size]
-            lse_list: shape = [N, batch_size, num_heads, 1]
-        Returns:
-            out_final: shape = [batch_size, num_heads, head_size]
-            lse_final: shape = [batch_size, num_heads, 1]
-        """
-        lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
-        out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
-                              dim=0)
-        return out_final, lse_final
+    def _out_lse_reshape(self, attn_out: torch.Tensor,
+                         attn_lse: torch.Tensor) -> torch.Tensor:
+        attn_out = attn_out.contiguous().view(
+            attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
+        attn_lse = attn_lse.contiguous().view(
+            attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
+        return attn_out, attn_lse
+
+    def _npu_attention_update(
+            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
+        update_type = 0
+
+        batch = attn_out_lse_list[0].shape[0]
+        num_heads = attn_out_lse_list[0].shape[1]
+        head_dim = attn_out_lse_list[0].shape[2] - 1
+
+        attn_out_split_cp = []
+        attn_lse_split_cp = []
+
+        for i in attn_out_lse_list:
+            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
+                *torch.split(i, [self.head_size, 1], dim=-1))
+            attn_out_split_cp.append(attn_out_allgather)
+            attn_lse_split_cp.append(attn_lse_allgather)
+
+        attn_out, attn_lse = torch_npu.npu_attention_update(
+            attn_lse_split_cp, attn_out_split_cp, update_type)
+        attn_out = attn_out.view(batch, num_heads, head_dim)
+
+        return attn_out
 
     def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                 attn_metadata: AscendMetadata) -> torch.Tensor:
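The removed `_update_out_and_lse` merged partial attention results with the log-sum-exp combine stated in its docstring, `LSE_final = log(sum(exp(LSE_i)))` and `O_final = sum(exp(LSE_i - LSE_final) * O_i)`; the new `_npu_attention_update` delegates that merge to the fused `torch_npu.npu_attention_update` kernel on the flattened out/lse lists from `_out_lse_reshape`. A minimal reference sketch of that math, under the assumption that the fused op reproduces it, is:

```python
# Reference-only sketch of the merge the removed _update_out_and_lse implemented.
import torch

def merge_out_and_lse(out_list: torch.Tensor, lse_list: torch.Tensor):
    """out_list: [N, batch, num_heads, head_size]; lse_list: [N, batch, num_heads, 1]."""
    lse_final = torch.logsumexp(lse_list, dim=0)    # log(sum_i exp(LSE_i))
    out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list, dim=0)
    return out_final, lse_final

outs = torch.randn(2, 4, 8, 16)    # two partial attention outputs to merge
lses = torch.randn(2, 4, 8, 1)
out_final, lse_final = merge_out_and_lse(outs, lses)
assert out_final.shape == (4, 8, 16) and lse_final.shape == (4, 8, 1)
```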
@@ -864,9 +864,6 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
         else:
             num_heads = self.num_heads
 
-        # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
-        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -877,7 +874,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "BSND",
+            "TND",
             'atten_mask':
             None,
             'scale':
@@ -892,9 +889,11 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             attn_metadata.block_tables,
             'block_size':
             self.key_cache.shape[1],
-            "actual_seq_lengths_kv":
-            attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+            'actual_seq_lengths_kv':
+            attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
+            'actual_seq_lengths':
+            attn_metadata.actual_seq_lengths_q[:attn_metadata.
+                                               num_decode_tokens]
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
@@ -910,26 +909,23 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    q_nope, k_nope, value, **common_kwargs)
-                update_graph_params_workspaces(num_tokens,
-                                               weak_ref_tensors(workspace))
-            attn_out = torch.empty_like(q_nope)
+                    query, k_nope, value, **common_kwargs)
+                graph_params.workspaces[num_tokens] = workspace
+            attn_out = torch.empty_like(query)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
-                                   device=q_nope.device)
+                                   device=query.device)
 
             graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
-                 weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
+                (query, k_nope, value, self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
-                 self.key_cache.shape[1], attn_metadata.decode_meta.
-                 num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
-                                                self.dcp_rank],
-                 weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
-                 self.pcp_rank, self.dcp_rank, self.dcp_size))
+                 self.key_cache.shape[1], attn_metadata.decode.
+                 num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+                 workspace, attn_out, attn_lse, self.pcp_rank, self.dcp_rank,
+                 self.dcp_size))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
-                q_nope,
+                query,
                 k_nope,
                 value,
                 **common_kwargs,
@@ -939,14 +935,12 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                q_nope, k_nope, value, **common_kwargs)
+                query, k_nope, value, **common_kwargs)
 
-        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
-                                 attn_out.shape[3])
-        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
+        attn_out_lse_list = []
+        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
+        attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
         if self.dcp_size > 1:
-            # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
             # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
             attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
             attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -955,37 +949,29 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                    group=self.dcp_group)
             # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
             attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
-            attn_out_lse_split_on_seq = list(
+            if self.pcp_size > 1:
+                attn_out_lse = attn_out_lse_all2all.contiguous()
+            attn_out_lse_list = list(
                 torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
 
-            attn_out_lse_split_dcp = torch.stack(
-                attn_out_lse_split_on_seq,
-                dim=0)  # [dcp, batch_size, num_heads, head_size+1]
-            # Update out&lse
-            attn_out_split_dcp, attn_lse_split_dcp = torch.split(
-                attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
-            attn_out, attn_lse = self._update_out_and_lse(
-                attn_out_split_dcp, attn_lse_split_dcp)
         if self.pcp_size > 1:
-            # 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
-            # 3. AllGather out&lse within CP group
+            # AllGather out&lse within CP group
             attn_out_lse_list = [
                 torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
             ]
             dist.all_gather(attn_out_lse_list,
                             attn_out_lse,
                             group=self.pcp_group)
-            # 4. Update out&lse
-            attn_out_lse_allgather = torch.stack(
-                attn_out_lse_list,
-                dim=0)  # [pcp, batch_size, num_heads, head_size+1]
-            attn_out_allgather, attn_lse_allgather = torch.split(
-                attn_out_lse_allgather, [self.head_size, 1], dim=-1)
-            attn_out, _ = self._update_out_and_lse(attn_out_allgather,
-                                                   attn_lse_allgather)
+        if self.dcp_size > 1 and self.pcp_size > 1:
+            attn_out_lse_list_pcp_dcp = []
+            for s in attn_out_lse_list:
+                attn_out_lse_list_split = list(
+                    torch.chunk(s, self.dcp_size, dim=1))
+                attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
+            attn_out_lse_list = attn_out_lse_list_pcp_dcp
+        # Update out&lse
+        attn_out = self._npu_attention_update(attn_out_lse_list)
         return attn_out
-
     def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
                          value: torch.Tensor, attn_metadata: AscendMetadata,
                          output: torch.Tensor) -> torch.Tensor:
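In the reworked decode path, `attn_out` (`[bs, num_heads, v_head_dim]`) and `attn_lse` (`[bs, num_heads, 1]`) are concatenated into a single `[bs, num_heads, v_head_dim+1]` tensor so that one all-to-all or all-gather moves both, and `_npu_attention_update` later splits them apart again. A small shape-only sketch of that bookkeeping, with the collectives omitted and illustrative sizes:

```python
# Shape-only sketch of the out&lse packing used in the decode path above.
import torch

bs, num_heads, head_dim, dcp_size = 4, 8, 16, 2
attn_out = torch.randn(bs, num_heads, head_dim)
attn_lse = torch.randn(bs, num_heads, 1)

# Concat out&lse along the last dim so a single tensor carries both.
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)   # [bs, num_heads, head_dim+1]

# After the (omitted) all-to-all, the head dim is chunked per DCP rank ...
chunks = list(torch.chunk(attn_out_lse, dcp_size, dim=1))

# ... and each chunk is split back into its out and lse parts before the update.
for chunk in chunks:
    out_part, lse_part = torch.split(chunk, [head_dim, 1], dim=-1)
    assert out_part.shape == (bs, num_heads // dcp_size, head_dim)
    assert lse_part.shape == (bs, num_heads // dcp_size, 1)
```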