
Commit 055f6d7

make tree attention decoding work with triton flash attention forward

1 parent d42c0c5

4 files changed: 58 additions & 23 deletions

assert_tree_attn.py

Lines changed: 6 additions & 4 deletions

@@ -57,9 +57,9 @@ def start(
 
     # inputs
 
-    q = torch.randn(batch, heads, 1, dim)
-    k = torch.randn(batch, heads, seq_len, dim)
-    v = torch.randn(batch, heads, seq_len, dim)
+    q = torch.randn(batch, heads, 1, dim).half()
+    k = torch.randn(batch, heads, seq_len, dim).half()
+    v = torch.randn(batch, heads, seq_len, dim).half()
 
     if use_cuda:
         q, k, v = tuple(t.cuda(rank) for t in (q, k, v))
@@ -75,6 +75,8 @@ def start(
     out = regular_decode(q, k, v)
     tree_out = tree_attn_decode(q, k, v)
 
+    out = out.to(tree_out.dtype)
+
     # if not main early return
 
     if not is_main:
@@ -95,7 +97,7 @@ def start(
 
 @click.command()
 @click.option('--world-size', default = 8, help = 'number of machines / processes')
-@click.option('--dim', default = 512, help = 'dimension')
+@click.option('--dim', default = 64, help = 'dimension')
 @click.option('--heads', default = 8, help = 'dimension')
 @click.option('--batch', default = 1, help = 'dimension')
 @click.option('--use-cuda', is_flag = True, help = 'whether to test with CUDA and NCCL')
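The inputs are now fp16 while the triton decode path may hand back a different dtype, hence the added cast of the reference output before comparison. Below is a minimal sketch of the kind of check the script performs; the shapes and tolerance are illustrative assumptions, and `tree_out` is stood in by a cast since the real value comes from `tree_attn_decode`:

    import torch

    # fp16 inputs, mirroring the updated test defaults (batch = 1, heads = 8, dim = 64; seq_len here is illustrative)
    q = torch.randn(1, 8, 1, 64).half()
    k = torch.randn(1, 8, 256, 64).half()
    v = torch.randn(1, 8, 256, 64).half()

    # single-device reference: plain softmax attention for one decoding step, computed in fp32
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q.float(), k.float()) * scale
    out = torch.einsum('b h i j, b h j d -> b h i d', sim.softmax(dim = -1), v.float())

    # stand-in for tree_attn_decode(q, k, v), which may return a different dtype than the reference
    tree_out = out.half()

    # cast the reference to the tree-decoded dtype before comparing, as the diff now does
    out = out.to(tree_out.dtype)
    assert torch.allclose(out, tree_out, atol = 1e-2)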

ring_attention_pytorch/tree_attn_decoding.py

Lines changed: 34 additions & 16 deletions

@@ -4,14 +4,22 @@
 
 from ring_attention_pytorch.distributed import get_rank, get_world_size
 
+# functions
+
 def exists(v):
     return v is not None
 
+def default(v, d):
+    return v if exists(v) else d
+
+# main function
+
 @torch.no_grad()
 def tree_attn_decode(
     q, k, v,
     eps = 1e-8,
-    shard_kv_seq = False
+    shard_kv_seq = False,
+    use_triton = None
 ):
 
     assert k.shape[:-1] == v.shape[:-1]
@@ -23,34 +31,44 @@ def tree_attn_decode(
     https://arxiv.org/abs/2408.04093
     """
 
-    device, dim_v = q.device, v.shape[-1]
-
-    rank = get_rank()
-    world_size = get_world_size()
-
-    # scale queries
-
-    scale = q.shape[-1] ** -0.5
-    q *= scale
+    dim_v = v.shape[-1]
 
     # each machine (rank) takes care of a chunk of kv sequence within the world of many machines
 
     if shard_kv_seq:
+        rank, world_size = get_rank(), get_world_size()
         k = k.chunk(world_size, dim = -2)
         v = v.chunk(world_size, dim = -2)
+
         k, v = (k[rank], v[rank]) if rank < len(k) else (None, None)
 
     if exists(k) and exists(v):
         # calculate local output and derive numerator and denominator
 
-        sim = einsum('... i d, ... j d -> ... i j', q, k)
+        use_triton = default(use_triton, q.is_cuda)
+
+        if use_triton and q.is_cuda:
+            from ring_attention_pytorch.triton_flash_attn import flash_attn_forward
+
+            out, local_max, lse = flash_attn_forward(
+                q, k, v,
+                causal = False,
+                return_normalized_output = True,
+                load_accumulated = False,
+                head_first_dim = True,
+                remove_padding = True
+            )
+
+        else:
+            scale = q.shape[-1] ** -0.5
+            sim = einsum('... i d, ... j d -> ... i j', q, k) * scale
 
-        local_max = sim.amax(dim = -1, keepdim = True)
-        sim -= local_max
-        lse = sim.logsumexp(dim = -1, keepdim = True)
+            local_max = sim.amax(dim = -1, keepdim = True)
+            sim -= local_max
+            lse = sim.logsumexp(dim = -1, keepdim = True)
 
-        attn = sim.softmax(dim = -1)
-        out = einsum('... i j, ... j d -> ... i d', attn, v)
+            attn = sim.softmax(dim = -1)
+            out = einsum('... i j, ... j d -> ... i d', attn, v)
 
         den = lse.exp()
         num = out * den
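For context only, since the combine step is not touched by this diff: whichever branch runs, each rank ends up with a local numerator, denominator, and running max, and tree attention decoding then merges those partials across ranks with a max all-reduce followed by sum all-reduces, per the referenced paper. A rough sketch of that merge, with hypothetical names:

    import torch
    import torch.distributed as dist

    def combine_partials(num, den, local_max, eps = 1e-8):
        # global running max across ranks (single all-reduce with MAX)
        global_max = local_max.clone()
        dist.all_reduce(global_max, op = dist.ReduceOp.MAX)

        # rescale each rank's partial numerator / denominator from its local max to the global max
        renorm = (local_max - global_max).exp()
        num = num * renorm
        den = den * renorm

        # sum the rescaled partials across ranks (two all-reduces with SUM)
        dist.all_reduce(num)
        dist.all_reduce(den)

        # final attention output for the decoded token
        return num / den.clamp(min = eps)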

ring_attention_pytorch/triton_flash_attn.py

Lines changed: 17 additions & 2 deletions

@@ -8,7 +8,7 @@
 import torch
 from torch import Tensor
 
-from einops import repeat
+from einops import repeat, rearrange
 
 def exists(v):
     return v is not None
@@ -315,10 +315,18 @@ def flash_attn_forward(
     return_normalized_output = False,
     load_accumulated = True,
     softclamp_qk_sim = False,
-    softclamp_value = 50.
+    softclamp_value = 50.,
+    head_first_dim = False,
+    remove_padding = False
 ):
     q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
 
+    if head_first_dim:
+        q, k, v = tuple(rearrange(t, 'b n h d -> b h n d') for t in (q, k, v))
+
+        if exists(o):
+            o = rearrange(o, 'b n h d -> b h n d')
+
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
 
@@ -412,6 +420,13 @@ def flash_attn_forward(
         num_stages = 1,
     )
 
+    if head_first_dim:
+        o = rearrange(o, 'b h n d -> b n h d')
+
+    if remove_padding:
+        m = m[..., :seqlen_q]
+        lse = lse[..., :seqlen_q]
+
     return o, m, lse
 
 @triton.jit
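The two new flags exist so a caller like `tree_attn_decode`, which keeps heads before the sequence dimension, can call the kernel directly. In plain PyTorch terms they amount to roughly the following; the shapes and padded length are illustrative assumptions, and the rearrange is simply a transpose of dims 1 and 2, so the axis labels in the pattern do not change what it does:

    import torch
    from einops import rearrange

    # caller layout when head_first_dim = True: (batch, heads, seq, dim)
    q = torch.randn(2, 8, 1, 64)

    # swap into the (batch, seq, heads, dim) layout the kernel indexes internally
    q_kernel = rearrange(q, 'b h n d -> b n h d')
    assert q_kernel.shape == (2, 1, 8, 64)

    # remove_padding = True: the row-max and logsumexp buffers are typically allocated at a
    # sequence length rounded up to a block multiple, so they get sliced back to seqlen_q
    seqlen_q, padded_len = 1, 128
    lse_padded = torch.randn(2, 8, padded_len)   # hypothetical padded buffer
    lse = lse_padded[..., :seqlen_q]
    assert lse.shape == (2, 8, 1)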

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.6',
+  version = '0.5.8',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
