
Commit 4932c5f

take care of caching for simple gateloop

1 parent 652e8b6

2 files changed: 34 additions, 10 deletions

gateloop_transformer/simplified_gate_loop.py (33 additions, 9 deletions)
@@ -4,15 +4,18 @@

 from typing import Tuple

-from einops import rearrange
+from einops import rearrange, pack, unpack
 from einops.layers.torch import Rearrange

 from gateloop_transformer.gateloop_transformer import RMSNorm
 from gateloop_transformer.associative_scan import associative_scan

 # plain pytorch non-fused associative scan

-def gate_loop_operator(q, kv, a):
+def exists(v):
+    return v is not None
+
+def gate_loop_operator(q, kv, a, cache = None):

     @torch.jit.script
     def binary_operator(
@@ -23,9 +26,18 @@ def binary_operator(
         a_j, kv_j = b
         return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)

-    _, kv = associative_scan(binary_operator, (a, kv))
+    if exists(cache):
+        cache_a, cache_kv = cache
+        a, a_ps = pack([cache_a, a], 'b * d')
+        kv, kv_ps = pack([cache_kv, kv], 'b * d')
+
+    a, kv = associative_scan(binary_operator, (a, kv))
+
+    if exists(cache):
+        _, a = unpack(a, a_ps, 'b * d')
+        _, kv = unpack(kv, kv_ps, 'b * d')

-    return q * kv
+    return q * kv, (a[:, -1], kv[:, -1])

 # using jax associative scan

@@ -48,7 +60,7 @@ def binary_operator(e_i, e_j):

         return q * y

-    return jax2torch(jax_gate_loop_operator)
+    return jax2torch(jax_gate_loop_operator), None

 # simple gate loop layer

@@ -75,6 +87,8 @@ def __init__(
             Rearrange('b n (qkva d) -> qkva (b d) n 1', qkva = 3)
         )

+        self.use_jax = use_jax_associative_scan
+
         if use_jax_associative_scan:
             self.gate_loop_fn = get_jax_gate_loop_operator()
         else:
@@ -84,20 +98,30 @@ def __init__(

         self.reverse = reverse

-    def forward(self, x):
-
+    def forward(
+        self,
+        x,
+        cache = None,
+        return_cache = False
+    ):
         if self.reverse:
             x = torch.flip(x, dims = (-2,))

         x = self.norm(x)

         q, kv, a = self.to_qkva(x)

-        out = self.gate_loop_fn(q, kv, a.sigmoid())
+        out, cache = self.gate_loop_fn(q, kv, a.sigmoid(), cache = cache)

         out = self.split_heads(out)

         if self.reverse:
             out = torch.flip(out, dims = (-2,))

-        return out
+        if not return_cache:
+            return out
+
+        assert not self.reverse, 'caching only works with non-reversed seq'
+        assert not self.use_jax, 'jax associative scan does not have caching yet'
+
+        return out, cache
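
What the cache buys: the plain-PyTorch operator computes the linear recurrence state_t = a_t * state_(t-1) + kv_t with an associative scan, so decoding can resume from the last state instead of re-scanning the whole prefix. The commit does this by packing the cached (a, kv) pair in front of the incoming tokens, scanning, then unpacking to drop that leading position. Below is a minimal sequential sketch of the same recurrence, for illustration only: the helper name and (batch, seq, dim) shapes are assumptions, and the loop carries just the kv state, whereas the commit also keeps the cumulative gate a[:, -1] that the scan needs to resume.

import torch

# illustrative sequential reference for the cached gate loop recurrence
def sequential_gate_loop(q, kv, a, cache = None):
    # `cache` plays the role of the kv state returned by gate_loop_operator
    state = cache if cache is not None else torch.zeros_like(kv[:, 0])
    out = []
    for t in range(kv.shape[1]):
        state = a[:, t] * state + kv[:, t]  # data-dependent decay, then write
        out.append(q[:, t] * state)         # read the state with the query
    return torch.stack(out, dim = 1), state

b, n, d = 2, 8, 4
q, kv = torch.randn(b, n, d), torch.randn(b, n, d)
a = torch.rand(b, n, d)  # gates in (0, 1), as after a.sigmoid()

# two chunks stitched through the cache must match one full pass
full, _ = sequential_gate_loop(q, kv, a)
first, cache = sequential_gate_loop(q[:, :4], kv[:, :4], a[:, :4])
second, _ = sequential_gate_loop(q[:, 4:], kv[:, 4:], a[:, 4:], cache = cache)
assert torch.allclose(full, torch.cat((first, second), dim = 1), atol = 1e-5)

At the layer level the same pattern would look like `out, cache = layer(x, return_cache = True)` followed by `layer(next_x, cache = cache, return_cache = True)`, which is the interface the new forward signature exposes (for the non-reversed, non-jax path, per the asserts above).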

setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'gateloop-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.1.1',
+  version = '0.1.4',
   license='MIT',
   description = 'GateLoop Transformer',
   author = 'Phil Wang',
