@@ -59,7 +59,7 @@ def tree_attn_decode(
         if use_triton and q.is_cuda:
             from ring_attention_pytorch.triton_flash_attn import flash_attn_forward

-            local_out, local_max, lse = flash_attn_forward(
+            local_out, _, lse = flash_attn_forward(
                 q, k, v,
                 causal = False,
                 return_normalized_output = True,
@@ -72,34 +72,25 @@ def tree_attn_decode(
             scale = q.shape[-1] ** -0.5
             sim = einsum('... i d, ... j d -> ... i j', q, k) * scale

-            local_max = sim.amax(dim = -1, keepdim = True)
-            sim -= local_max
             lse = sim.logsumexp(dim = -1, keepdim = True)
-
             attn = sim.softmax(dim = -1)
             local_out = einsum('... i j, ... j d -> ... i d', attn, v)

-            den = lse.exp()
-            num = local_out.float() * den
-
     else:
         # handle edge case where seq length < world size

-        num = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
-        den = q.new_zeros((*q.shape[:-1], 1), dtype = torch.float32)
-        local_max = torch.zeros_like(den)
+        local_out = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
+        lse = torch.full((*q.shape[:-1], 1), -torch.finfo(torch.float32).max, dtype = torch.float32, device = q.device)

     # first get global max through an all reduce (max)

-    global_max = local_max.clone()
-    dist.all_reduce(global_max, dist.ReduceOp.MAX)
+    global_lse = lse.clone()
+    dist.all_reduce(global_lse, dist.ReduceOp.MAX)

     # renormalize the numerator and denominators

-    renorm_factor = (local_max - global_max).exp()
-
-    den *= renorm_factor
-    num *= renorm_factor
+    den = (lse - global_lse).exp()
+    num = local_out * den

     # second and third all reduce (sum)

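For reference, a minimal single-process sketch of the log-sum-exp combine this commit switches to, with the `dist.all_reduce` calls emulated by a max/sum over per-shard results (tensor sizes and variable names here are illustrative, not taken from the repository):

```python
import torch
from torch import einsum

# illustrative sizes: a single decoding query, kv sequence split across "ranks"
b, h, d = 1, 2, 16
seq_len, world_size = 12, 4

q = torch.randn(b, h, 1, d)
k = torch.randn(b, h, seq_len, d)
v = torch.randn(b, h, seq_len, d)
scale = d ** -0.5

# reference: ordinary attention over the full sequence
sim = einsum('... i d, ... j d -> ... i j', q, k) * scale
expected = einsum('... i j, ... j d -> ... i d', sim.softmax(dim = -1), v)

# each shard keeps only its local attention output and the logsumexp of its scores
outs, lses = [], []
for k_shard, v_shard in zip(k.chunk(world_size, dim = -2), v.chunk(world_size, dim = -2)):
    sim = einsum('... i d, ... j d -> ... i j', q, k_shard) * scale
    lses.append(sim.logsumexp(dim = -1, keepdim = True))
    outs.append(einsum('... i j, ... j d -> ... i d', sim.softmax(dim = -1), v_shard))

# stand-ins for the three all-reduces: MAX over lse, then SUM of den and num
global_lse = torch.stack(lses).amax(dim = 0)
den = sum((lse - global_lse).exp() for lse in lses)
num = sum(out * (lse - global_lse).exp() for out, lse in zip(outs, lses))

assert torch.allclose(num / den, expected, atol = 1e-5)
```

Keeping only `lse` per shard is enough because `local_out * exp(lse - global_lse)` recovers each shard's unnormalized numerator up to a common factor of `exp(global_lse)`, which cancels when dividing `num` by `den`; the separate running max the old code tracked is already folded into the log-sum-exp, and subtracting `global_lse` keeps the exponentials numerically stable.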