Commit ab57c27
fix issue with seq len < world size in tree attn decoding after new update, also start using tensor typing
1 parent f3ed323 commit ab57c27

File tree

4 files changed, +47 -18 lines changed

ring_attention_pytorch/tensor_typing.py
ring_attention_pytorch/tree_attn_decoding.py
ring_attention_pytorch/triton_flash_attn.py
setup.py

ring_attention_pytorch/tensor_typing.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+from torch import Tensor
+
+from jaxtyping import (
+    Float,
+    Int,
+    Bool
+)
+
+# jaxtyping is a misnomer, works for pytorch
+
+class TorchTyping:
+    def __init__(self, abstract_dtype):
+        self.abstract_dtype = abstract_dtype
+
+    def __getitem__(self, shapes: str):
+        return self.abstract_dtype[Tensor, shapes]
+
+Float = TorchTyping(Float)
+Int = TorchTyping(Int)
+Bool = TorchTyping(Bool)
+
+__all__ = [
+    Float,
+    Int,
+    Bool
+]
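
The new module above wraps jaxtyping so shape-annotated tensor types can be written as Float['b h n d'] instead of Float[Tensor, 'b h n d']. A minimal usage sketch, where the scale_queries helper and the concrete shapes are illustrative and not part of the commit:

import torch
from ring_attention_pytorch.tensor_typing import Float

# Float['b h 1 d'] expands to jaxtyping's Float[Tensor, 'b h 1 d'],
# the same style of annotation tree_attn_decode adopts in the next file
def scale_queries(q: Float['b h 1 d']) -> Float['b h 1 d']:
    return q * (q.shape[-1] ** -0.5)

q = torch.randn(2, 8, 1, 64)  # (batch, heads, one decoding position, head dim)
assert scale_queries(q).shape == q.shape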

ring_attention_pytorch/tree_attn_decoding.py

Lines changed: 15 additions & 13 deletions

@@ -4,8 +4,12 @@
 from torch import einsum, Tensor
 import torch.distributed as dist
 
+from einops import rearrange
+
 from ring_attention_pytorch.distributed import get_rank, get_world_size
 
+from ring_attention_pytorch.tensor_typing import Float
+
 # functions
 
 def exists(v):
@@ -18,21 +22,17 @@ def default(v, d):
 
 @torch.no_grad()
 def tree_attn_decode(
-    q: Tensor,
-    k: Tensor | None = None,
-    v: Tensor | None = None,
+    q: Float['b h 1 d'],
+    k: Float['b h n d'] | None = None,
+    v: Float['b h n dv'] | None = None,
     eps = 1e-8,
-    shard_kv_seq = False,
+    shard_kv_seq = True,
     use_triton = None
-):
-    dtype = q.dtype
+) -> Float['b h 1 dv']:
 
-    assert not (exists(k) ^ exists(v)), 'keys and values are either both None, or both present'
+    q_prec_dims, dtype = q.shape[:-1], q.dtype
 
-    if exists(k):
-        assert k.shape[:-1] == v.shape[:-1]
-        assert q.shape[-2:] == (1, k.shape[-1])
-        assert q.shape[:-2] == k.shape[:-2]
+    assert not (exists(k) ^ exists(v)), 'keys and values are either both None, or both present'
 
     """
     Algorithm 3 proposed in Tree Attention
@@ -43,6 +43,7 @@ def tree_attn_decode(
 
     if shard_kv_seq:
         assert exists(k), 'keys and values must be passed if not already sharded across sequence'
+        dim_v = v.shape[-1]
 
         rank, world_size = get_rank(), get_world_size()
         k = k.chunk(world_size, dim = -2)
@@ -68,6 +69,7 @@ def tree_attn_decode(
                 remove_padding = True
             )
 
+            lse = rearrange(lse, '... -> ... 1')
         else:
             scale = q.shape[-1] ** -0.5
             sim = einsum('... i d, ... j d -> ... i j', q, k) * scale
@@ -79,8 +81,8 @@ def tree_attn_decode(
     else:
         # handle edge case where seq length < world size
 
-        local_out = q.new_zeros((*q.shape[:-1], v.shape[-1]), dtype = torch.float32)
-        lse = torch.full_like(den, -torch.finfo(torch.float32).max)
+        local_out = q.new_zeros((*q_prec_dims, dim_v), dtype = torch.float32)
+        lse = torch.full((*q_prec_dims, 1), -torch.finfo(torch.float32).max, device = q.device, dtype = torch.float32)
 
     # first get max(lse) through an all reduce
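For context on the edge-case fix above: the kv sequence is sharded with Tensor.chunk, which can return fewer chunks than requested, so when the sequence length is smaller than the world size the higher ranks get no kv shard at all and must contribute a neutral partial result (zero local output, lse filled with -max), now shaped from q_prec_dims and dim_v recorded before chunking. A small illustration, with made-up batch/head/sequence sizes and world_size / rank as stand-ins for get_world_size() / get_rank():

import torch

world_size = 8                          # stand-in for get_world_size()
k = torch.randn(1, 8, 3, 64)            # (batch, heads, seq len 3, head dim)

chunks = k.chunk(world_size, dim = -2)
print(len(chunks))                      # prints 3, not 8

for rank in range(world_size):          # stand-in for get_rank()
    if rank < len(chunks):
        print(rank, 'owns a kv shard of shape', tuple(chunks[rank].shape))
    else:
        print(rank, 'has no kv shard -> zero local_out and -max lse')
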
ring_attention_pytorch/triton_flash_attn.py

Lines changed: 3 additions & 3 deletions

@@ -322,10 +322,10 @@ def flash_attn_forward(
     q, k, v = [x if is_contiguous(x) else x.contiguous() for x in (q, k, v)]
 
     if head_first_dim:
-        q, k, v = tuple(rearrange(t, 'b n h d -> b h n d') for t in (q, k, v))
+        q, k, v = tuple(rearrange(t, 'b h n d -> b n h d') for t in (q, k, v))
 
         if exists(o):
-            o = rearrange(o, 'b n h d -> b h n d')
+            o = rearrange(o, 'b h n d -> b n h d')
 
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
@@ -421,7 +421,7 @@ def flash_attn_forward(
     )
 
     if head_first_dim:
-        o = rearrange(o, 'b h n d -> b n h d')
+        o = rearrange(o, 'b n h d -> b h n d')
 
     if remove_padding:
         m = m[..., :seqlen_q]
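The two changes above fix the direction of the layout conversion when head_first_dim is set: tree_attn_decode passes tensors head-first as (b, h, n, d), while flash_attn_forward internally unpacks (b, n, h, d), so the entry rearrange must be 'b h n d -> b n h d' and the exit rearrange the inverse. A quick shape check, with illustrative sizes:

import torch
from einops import rearrange

q_head_first = torch.randn(2, 8, 1, 64)                   # (b, h, n, d) as tree_attn_decode supplies it

q_kernel = rearrange(q_head_first, 'b h n d -> b n h d')  # into the kernel's (b, n, h, d) layout
assert q_kernel.shape == (2, 1, 8, 64)

q_back = rearrange(q_kernel, 'b n h d -> b h n d')        # back to head-first on the way out
assert torch.equal(q_back, q_head_first)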

setup.py

Lines changed: 3 additions & 2 deletions

@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.12',
+  version = '0.5.17',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
@@ -18,7 +18,8 @@
   install_requires=[
     'beartype',
     'einops>=0.8.0',
-    'torch>=2.0'
+    'jaxtyping',
+    'torch>=2.0',
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
