Skip to content

Commit 42b9590

Browse files
committed
prepare gate loop transformer for experiments
1 parent 044c266 commit 42b9590

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,28 @@ Implementation of <a href="https://arxiv.org/abs/2311.01927">GateLoop</a> Transf
66

77
Jax version will be done with the <a href="https://github.com/patrick-kidger/equinox">Equinox</a> framework
88

9+
## Install
10+
11+
```bash
12+
$ pip install gateloop-transformer
13+
```
14+
15+
## Usage
16+
17+
```python
18+
import torch
19+
from gateloop_transformer import Transformer
20+
21+
model = Transformer(
22+
num_tokens = 256,
23+
dim = 624,
24+
depth = 6
25+
)
26+
27+
ids = torch.randint(0, 256, (1, 1024))
28+
logits = model(ids) # (1, 1024, 256)
29+
```
30+
931
## Citations
1032

1133
```bibtex

gateloop_transformer/gateloop_transformer.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,23 +197,36 @@ def __init__(
197197
dim_head = 64,
198198
heads = 8,
199199
ff_mult = 4,
200+
use_gate_looped_attn = True,
201+
dim_gate_looped_attn = None,
200202
data_dependent_rel_pos = False,
201-
frac_gradient_data_dependent_rel_pos = 0.5
203+
frac_gradient_state_transition = 0.5
202204
):
203205
super().__init__()
204206

205207
self.token_emb = nn.Embedding(num_tokens, dim)
206208

207209
layers = ModuleList([])
210+
208211
for _ in range(depth):
209-
layers.append(ModuleList([
210-
CausalFullAttention(
212+
213+
if use_gate_looped_attn:
214+
spatial_mixer = GateLoopedAttention(
215+
dim = dim,
216+
dim_inner = dim_gate_looped_attn,
217+
frac_gradient_state_transition = frac_gradient_state_transition
218+
)
219+
else:
220+
spatial_mixer = CausalFullAttention(
211221
dim = dim,
212222
dim_head = dim_head,
213223
heads = heads,
214224
data_dependent_rel_pos = data_dependent_rel_pos,
215-
frac_gradient_data_dependent_rel_pos = frac_gradient_data_dependent_rel_pos
216-
),
225+
frac_gradient_data_dependent_rel_pos = frac_gradient_state_transition
226+
)
227+
228+
layers.append(ModuleList([
229+
spatial_mixer,
217230
FeedForward(
218231
dim = dim,
219232
mult = ff_mult

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'gateloop-transformer',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.2',
6+
version = '0.0.3',
77
license='MIT',
88
description = 'GateLoop Transformer',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)