
Commit 786b201 (parent: c919871)

fix a misunderstanding, thanks to main author @tobiaskatsch for the discussion and code review

3 files changed (+15, -2 lines)

README.md (2 additions, 0 deletions)

```diff
@@ -12,6 +12,8 @@ Update 3: Fixed a misunderstanding and definitely seems to be converging better
 
 Update 4: <a href="https://api.wandb.ai/links/lucidrains/ysbz84fn">Ongoing experiments</a>
 
+Update 5: Author has reviewed the code, and there was another misunderstanding. They use the maximum number of heads (heads == dimension). This is kind of a plot twist, as it is infeasible for normal attention. It also obviates the need for a fused CUDA kernel as in autoregressive linear attention.
+
 ### Install
 
 ```bash
```
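For context on why heads == dimension sidesteps the fused-kernel problem: with one channel per head, the gated state is a single scalar per channel rather than a d × d matrix per head, so the recurrence can be written directly in PyTorch. The sketch below is purely illustrative and is not code from this repository; the function name is made up, and it uses a real-valued gate `a` for simplicity, whereas the repository's state transition has two components.

```python
import torch

def naive_scalar_gateloop(q, k, v, a):
    # q, k, v: queries / keys / values, shape (batch, seq, dim)
    # a: data-dependent state transition gate, same shape (real-valued here for simplicity)
    b, n, d = q.shape
    s = torch.zeros(b, d, dtype = q.dtype, device = q.device)  # one scalar state per channel
    out = []
    for t in range(n):
        s = a[:, t] * s + k[:, t] * v[:, t]   # gated elementwise accumulation
        out.append(q[:, t] * s)               # elementwise readout
    return torch.stack(out, dim = 1)

# toy usage
q, k, v = (torch.randn(2, 16, 64) for _ in range(3))
a = torch.rand(2, 16, 64)  # gates in [0, 1)
y = naive_scalar_gateloop(q, k, v, a)
print(y.shape)  # torch.Size([2, 16, 64])
```

The sequential loop is only for clarity; because the recurrence is elementwise, it can also be computed with an associative scan, which is what makes the heads == dimension setting practical without a custom kernel.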

gateloop_transformer/gateloop_transformer.py (12 additions, 1 deletion)

```diff
@@ -178,6 +178,7 @@ class GateLoopedAttention(Module):
     def __init__(
         self,
         dim,
+        heads = None,
         dim_inner = None,
         checkpoint_gate_looped_attn = True,
         frac_gradient_state_transition = 0.5
@@ -187,16 +188,21 @@ def __init__(
         self.checkpoint_gate_looped_attn = checkpoint_gate_looped_attn
 
         dim_inner = default(dim_inner, dim)
+        heads = default(heads, dim_inner)
 
         self.norm = RMSNorm(dim)
 
+        self.heads = heads
+        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
+
         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
 
         self.to_a = nn.Sequential(
             nn.Linear(dim, dim_inner * 2),
-            Rearrange('... (d c) -> ... d c', c = 2)
+            Rearrange('b n (h d c) -> (b h) n d c', h = heads, c = 2)
         )
 
+        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
         self.to_out = nn.Linear(dim_inner, dim, bias = False) if dim_inner != dim else nn.Identity()
 
     def forward(
@@ -211,6 +217,8 @@ def forward(
 
         q, k, v = self.to_qkv(x).chunk(3, dim = -1)
 
+        q, k, v = map(self.split_heads, (q, k, v))
+
         a = self.to_a(x)
         a = a * frac_gradient + a.detach() * (1 - frac_gradient)
 
@@ -228,6 +236,7 @@ def forward(
 
         out = fn(q, k, v, a)
 
+        out = self.merge_heads(out)
         return self.to_out(out)
 
 # main class
@@ -244,6 +253,7 @@ def __init__(
         ff_mult = 4,
         checkpoint_gate_looped_attn = True,
         use_gate_looped_attn = True,
+        gate_loop_heads = None,
         dim_gate_looped_attn = None,
         attn_softmax_normalize = None,
         data_dependent_rel_pos = False,
@@ -265,6 +275,7 @@ def __init__(
             if use_gate_looped_attn:
                 spatial_mixer = GateLoopedAttention(
                     dim = dim,
+                    heads = gate_loop_heads,
                     dim_inner = dim_gate_looped_attn,
                     checkpoint_gate_looped_attn = checkpoint_gate_looped_attn,
                     frac_gradient_state_transition = frac_gradient_state_transition
```
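A quick sanity-check sketch (not part of the commit) of what the new `split_heads` / `merge_heads` rearranges do when `heads` falls back to `dim_inner`: every channel becomes its own head with head dimension 1, folded into the batch axis. The dimensions here (batch 2, sequence 16, `dim_inner` 8) are made up for illustration.

```python
import torch
from einops.layers.torch import Rearrange

dim_inner = 8
heads = dim_inner  # default(heads, dim_inner): maximum heads, so head_dim == 1

split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)

x = torch.randn(2, 16, dim_inner)
x_split = split_heads(x)
print(x_split.shape)               # torch.Size([16, 16, 1]) - each head carries one scalar channel
print(merge_heads(x_split).shape)  # torch.Size([2, 16, 8])  - round-trips back to the input shape
```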

setup.py (1 addition, 1 deletion)

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'gateloop-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.0.11',
+  version = '0.0.12',
   license='MIT',
   description = 'GateLoop Transformer',
   author = 'Phil Wang',
```
