
Commit 7263045

updated: algorithm 3 in the tree attn decoding paper is more concise
1 parent d49499f commit 7263045

File tree

2 files changed: +8, -17 lines


ring_attention_pytorch/tree_attn_decoding.py

Lines changed: 7 additions & 16 deletions
@@ -59,7 +59,7 @@ def tree_attn_decode(
     if use_triton and q.is_cuda:
         from ring_attention_pytorch.triton_flash_attn import flash_attn_forward
 
-        local_out, local_max, lse = flash_attn_forward(
+        local_out, _, lse = flash_attn_forward(
             q, k, v,
             causal = False,
             return_normalized_output = True,
@@ -72,34 +72,25 @@ def tree_attn_decode(
         scale = q.shape[-1] ** -0.5
         sim = einsum('... i d, ... j d -> ... i j', q, k) * scale
 
-        local_max = sim.amax(dim = -1, keepdim = True)
-        sim -= local_max
         lse = sim.logsumexp(dim = -1, keepdim = True)
-
         attn = sim.softmax(dim = -1)
         local_out = einsum('... i j, ... j d -> ... i d', attn, v)
 
-        den = lse.exp()
-        num = local_out.float() * den
-
     else:
         # handle edge case where seq length < world size
 
-        num = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
-        den = q.new_zeros((*q.shape[:-1], 1), dtype = torch.float32)
-        local_max = torch.zeros_like(den)
+        local_out = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
+        lse = torch.full_like(den, -torch.finfo(torch.float32).max)
 
     # first get global max through an all reduce (max)
 
-    global_max = local_max.clone()
-    dist.all_reduce(global_max, dist.ReduceOp.MAX)
+    global_lse = lse.clone()
+    dist.all_reduce(global_lse, dist.ReduceOp.MAX)
 
     # renormalize the numerator and denominators
 
-    renorm_factor = (local_max - global_max).exp()
-
-    den *= renorm_factor
-    num *= renorm_factor
+    den = (lse - global_lse).exp()
+    num = local_out * den
 
     # second and third all reduce (sum)
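What the commit buys: flash_attn_forward already returns its output normalized by exp(lse), so the log-sum-exp can serve as the numerical stabilizer directly, and the separate running-max bookkeeping drops out. Below is a minimal sketch of the resulting three-all-reduce combine, assuming torch.distributed is initialized and each rank holds the attention output and log-sum-exp over its own shard of the keys/values; the helper name and tensor shapes are illustrative, not from the repo.

import torch
import torch.distributed as dist

def combine_partial_attention(local_out, lse):
    # illustrative helper (not from the repo): `local_out` is this rank's
    # attention output over its key/value shard, already normalized by
    # exp(lse); `lse` is that shard's log-sum-exp, broadcastable to local_out

    # all reduce 1 (max): the global log-sum-exp acts as the numerical
    # stabilizer, replacing the running max removed by this commit
    global_lse = lse.clone()
    dist.all_reduce(global_lse, dist.ReduceOp.MAX)

    # renormalize: exp(lse - global_lse) puts every rank's numerator and
    # denominator on the common scale exp(global_lse)
    den = (lse - global_lse).exp()
    num = local_out * den

    # all reduces 2 and 3 (sum): combine numerators and denominators
    dist.all_reduce(num)
    dist.all_reduce(den)

    return num / den

Correctness follows from local_out = (sum_j exp(s_j) v_j) / exp(lse) on each rank: after rescaling, num sums across ranks to (sum exp(s) v) / exp(global_lse) and den to (sum exp(s)) / exp(global_lse), so num / den recovers the exact global softmax attention.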

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.10',
+  version = '0.5.12',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
