
Commit 0a4b809

a simplified gateloop for personal use in other projects
1 parent d170c86 commit 0a4b809

4 files changed: +52, -9 lines

gateloop_transformer/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -3,3 +3,7 @@
     GateLoopedAttention,
     Transformer
 )
+
+from gateloop_transformer.simplified_gate_loop import (
+    SimpleGateLoopLayer
+)

gateloop_transformer/gateloop_transformer.py

Lines changed: 1 addition & 8 deletions
@@ -189,12 +189,6 @@ def forward(

 # data gated linear attention with "gateloop operator"

-def maybe_real(t):
-    if not torch.is_complex(t):
-        return t
-
-    return t.real
-
 def gate_loop_operator(q, k, v, a):
     """
     the pseudocode in section 3.2 of the paper

@@ -205,8 +199,7 @@ def gate_loop_operator(q, k, v, a):
     def binary_operator(a, b):
         a_i, kv_i = a
         a_j, kv_j = b
-
-        return a_j * a_i, maybe_real(a_j) * kv_i + kv_j
+        return a_j * a_i, a_j.real * kv_i + kv_j

     _, kv = associative_scan(binary_operator, (a, kv))

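As context for the change above, binary_operator together with associative_scan computes, in parallel, the gated recurrence state_t = a_t * state_{t-1} + kv_t with outputs q_t * state_t. Below is a minimal sequential sketch of that recurrence, assuming real-valued gates and an already-formed key-value term per position; the function name and shapes are illustrative and not part of this commit.

import torch

def gate_loop_reference(q, kv, a):
    # illustrative only: q, kv, a all have shape (batch, seq, dim), gates a in [0, 1]
    out = []
    state = torch.zeros_like(kv[:, 0])

    for t in range(kv.shape[1]):
        # state_t = a_t * state_{t-1} + kv_t
        state = a[:, t] * state + kv[:, t]
        # out_t = q_t * state_t
        out.append(q[:, t] * state)

    return torch.stack(out, dim = 1)
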
gateloop_transformer/simplified_gate_loop.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+from torch import nn
+from torch.nn import Module
+
+from einops import rearrange
+from einops.layers.torch import Rearrange
+
+from gateloop_transformer.gateloop_transformer import RMSNorm
+from gateloop_transformer.associative_scan import associative_scan
+
+def gate_loop_operator(q, kv, a):
+    def binary_operator(a, b):
+        a_i, kv_i = a
+        a_j, kv_j = b
+        return a_j * a_i, a_j * kv_i + kv_j
+
+    _, kv = associative_scan(binary_operator, (a, kv))
+
+    return q * kv
+
+class SimpleGateLoopLayer(Module):
+    """
+    simplified gate loop
+    seeing if it can supplement attention as shown in https://github.com/lucidrains/mega-pytorch
+    """
+
+    def __init__(self, dim):
+        super().__init__()
+        self.norm = RMSNorm(dim)
+
+        self.dim = dim
+
+        self.to_qkva = nn.Sequential(
+            nn.Linear(dim, dim * 3, bias = False),
+            Rearrange('b n (qkva d) -> qkva (b d) n 1', qkva = 3)
+        )
+
+        self.split_heads = Rearrange('(b d) n 1 -> b n d', d = dim)
+
+    def forward(self, x):
+        x = self.norm(x)
+
+        q, kv, a = self.to_qkva(x)
+
+        out = gate_loop_operator(q, kv, a.sigmoid())
+
+        return self.split_heads(out)
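
As an illustration of how the new layer might be used (not prescribed by the commit), it can be dropped into a model as a residual branch alongside attention, echoing the mega-pytorch experiment referenced in the docstring; the dimensions below are arbitrary.

import torch
from gateloop_transformer import SimpleGateLoopLayer

gate_loop = SimpleGateLoopLayer(dim = 512)

x = torch.randn(2, 1024, 512)   # (batch, seq, dim)

# gateloop branch applied residually, e.g. alongside an attention block
x = gate_loop(x) + x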

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'gateloop-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.0.23',
+  version = '0.0.24',
   license='MIT',
   description = 'GateLoop Transformer',
   author = 'Phil Wang',
