
Commit d42c0c5

account for scenario where keys and values are already sharded in tree attn decoding
1 parent f2f6d7a commit d42c0c5

File tree

ring_attention_pytorch/distributed.py
ring_attention_pytorch/ring.py
ring_attention_pytorch/tree_attn_decoding.py
setup.py

4 files changed: +38 −26 lines changed

ring_attention_pytorch/distributed.py

Lines changed: 18 additions & 0 deletions
@@ -1,3 +1,5 @@
+from functools import partial, lru_cache
+
 import torch
 from torch import nn
 from torch.nn import Module
@@ -20,6 +22,22 @@ def pad_dim_to(t, length, dim = 0):
     zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
     return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
 
+cache = partial(lru_cache, maxsize = None)
+
+# distributed helpers
+
+@cache()
+def get_rank():
+    return dist.get_rank() if dist.is_initialized() else 0
+
+@cache()
+def get_world_size():
+    return dist.get_world_size() if dist.is_initialized() else 1
+
+@cache()
+def is_distributed():
+    return dist.is_initialized() and dist.get_world_size() > 1
+
 def all_gather_same_dim(t):
     t = t.contiguous()
     world_size = dist.get_world_size()
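
The helpers above memoize their results with functools.lru_cache, so the torch.distributed lookups happen at most once per process. A minimal usage sketch (illustrative only, not part of this diff); note that because the results are cached, any process group setup such as torch.distributed.init_process_group should run before the first call, otherwise the single-process fallback gets frozen in:

from ring_attention_pytorch.distributed import get_rank, get_world_size, is_distributed

# calling these before init_process_group would cache rank 0 / world size 1
if is_distributed():
    print(f'rank {get_rank()} of {get_world_size()} ranks')
else:
    print('running single-process: rank 0, world size 1')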

ring_attention_pytorch/ring.py

Lines changed: 2 additions & 17 deletions
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from functools import lru_cache, partial, wraps
+from functools import wraps, partial
 from collections import namedtuple
 
 import torch
@@ -9,6 +9,7 @@
 from torch.autograd import Function
 
 import torch.distributed as dist
+from ring_attention_pytorch.distributed import get_rank, get_world_size, is_distributed
 
 # helper functions
 
@@ -21,22 +22,6 @@ def default(v, d):
 def cast_tuple(t, length = 1):
     return t if isinstance(t, tuple) else ((t,) * length)
 
-cache = partial(lru_cache, maxsize = None)
-
-# distributed globals
-
-@cache()
-def get_rank():
-    return dist.get_rank() if dist.is_initialized() else 0
-
-@cache()
-def get_world_size():
-    return dist.get_world_size() if dist.is_initialized() else 1
-
-@cache()
-def is_distributed():
-    return dist.is_initialized() and dist.get_world_size() > 1
-
 # ring functions
 
 def circular_index_left(pos, ring_size, num = 1):
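
Net effect of this hunk: ring.py stops carrying its own copy of the distributed globals and instead shares the cached helpers from ring_attention_pytorch.distributed, the same import now used by the tree attention decoding path:

# single shared source for the distributed helpers after this commit
from ring_attention_pytorch.distributed import get_rank, get_world_size, is_distributed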

ring_attention_pytorch/tree_attn_decoding.py

Lines changed: 17 additions & 8 deletions
@@ -2,8 +2,17 @@
 from torch import einsum
 import torch.distributed as dist
 
+from ring_attention_pytorch.distributed import get_rank, get_world_size
+
+def exists(v):
+    return v is not None
+
 @torch.no_grad()
-def tree_attn_decode(q, k, v, eps = 1e-8):
+def tree_attn_decode(
+    q, k, v,
+    eps = 1e-8,
+    shard_kv_seq = False
+):
 
     assert k.shape[:-1] == v.shape[:-1]
     assert q.shape[-2:] == (1, k.shape[-1])
@@ -16,8 +25,8 @@ def tree_attn_decode(q, k, v, eps = 1e-8):
 
     device, dim_v = q.device, v.shape[-1]
 
-    rank = dist.get_rank() if dist.is_initialized() else 0
-    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    rank = get_rank()
+    world_size = get_world_size()
 
     # scale queries
 
@@ -26,12 +35,12 @@ def tree_attn_decode(q, k, v, eps = 1e-8):
 
     # each machine (rank) takes care of a chunk of kv sequence within the world of many machines
 
-    k = k.chunk(world_size, dim = -2)
-    v = v.chunk(world_size, dim = -2)
-
-    if rank < len(k):
-        k, v = k[rank], v[rank]
+    if shard_kv_seq:
+        k = k.chunk(world_size, dim = -2)
+        v = v.chunk(world_size, dim = -2)
+        k, v = (k[rank], v[rank]) if rank < len(k) else (None, None)
 
+    if exists(k) and exists(v):
         # calculate local output and derive numerator and denominator
 
         sim = einsum('... i d, ... j d -> ... i j', q, k)
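
This is the change the commit message refers to: with the new default shard_kv_seq = False, tree_attn_decode treats the keys and values it receives as already being the rank-local shard (or as needing no sharding at all), while shard_kv_seq = True keeps the previous behavior of chunking the full sequence and selecting the chunk for this rank; ranks that end up without a chunk simply skip the local attention computation. A hedged usage sketch, where the tensor shapes and any process-group setup are illustrative assumptions rather than part of this diff:

import torch
from ring_attention_pytorch.tree_attn_decoding import tree_attn_decode

q = torch.randn(1, 8, 1, 64)          # single decoding query: (batch, heads, 1, dim)

# case 1 (new default): k / v already sharded, pass only this rank's slice
k_local = torch.randn(1, 8, 256, 64)
v_local = torch.randn(1, 8, 256, 64)
out = tree_attn_decode(q, k_local, v_local)                    # shard_kv_seq = False

# case 2: every rank holds the full kv sequence, let the function chunk it by rank
k_full = torch.randn(1, 8, 1024, 64)
v_full = torch.randn(1, 8, 1024, 64)
out = tree_attn_decode(q, k_full, v_full, shard_kv_seq = True)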

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.5',
+  version = '0.5.6',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
