File tree Expand file tree Collapse file tree 3 files changed +41
-6
lines changed Expand file tree Collapse file tree 3 files changed +41
-6
lines changed Original file line number Diff line number Diff line change @@ -6,6 +6,28 @@ Implementation of <a href="https://arxiv.org/abs/2311.01927">GateLoop</a> Transf
66
77Jax version will be done with the <a href =" https://github.com/patrick-kidger/equinox " >Equinox</a > framework
88
9+ ## Install
10+
11+ ``` bash
12+ $ pip install gateloop-transformr
13+ ```
14+
15+ ## Usage
16+
17+ ``` python
18+ import torch
19+ from gateloop_transformer import Transformer
20+
21+ model = Transformer(
22+ num_tokens = 256 ,
23+ dim = 624 ,
24+ depth = 6
25+ )
26+
27+ ids = torch.randint(0 , 256 , (1 , 1024 ))
28+ logits = model(ids) # (1, 1024, 256)
29+ ```
30+
931## Citations
1032
1133``` bibtex
Original file line number Diff line number Diff line change @@ -197,23 +197,36 @@ def __init__(
197197 dim_head = 64 ,
198198 heads = 8 ,
199199 ff_mult = 4 ,
200+ use_gate_looped_attn = True ,
201+ dim_gate_looped_attn = None ,
200202 data_dependent_rel_pos = False ,
201- frac_gradient_data_dependent_rel_pos = 0.5
203+ frac_gradient_state_transition = 0.5
202204 ):
203205 super ().__init__ ()
204206
205207 self .token_emb = nn .Embedding (num_tokens , dim )
206208
207209 layers = ModuleList ([])
210+
208211 for _ in range (depth ):
209- layers .append (ModuleList ([
210- CausalFullAttention (
212+
213+ if use_gate_looped_attn :
214+ spatial_mixer = GateLoopedAttention (
215+ dim = dim ,
216+ dim_inner = dim_gate_looped_attn ,
217+ frac_gradient_state_transition = frac_gradient_state_transition
218+ )
219+ else :
220+ spatial_mixer = CausalFullAttention (
211221 dim = dim ,
212222 dim_head = dim_head ,
213223 heads = heads ,
214224 data_dependent_rel_pos = data_dependent_rel_pos ,
215- frac_gradient_data_dependent_rel_pos = frac_gradient_data_dependent_rel_pos
216- ),
225+ frac_gradient_data_dependent_rel_pos = frac_gradient_state_transition
226+ )
227+
228+ layers .append (ModuleList ([
229+ spatial_mixer ,
217230 FeedForward (
218231 dim = dim ,
219232 mult = ff_mult
Original file line number Diff line number Diff line change 33setup (
44 name = 'gateloop-transformer' ,
55 packages = find_packages (exclude = []),
6- version = '0.0.2 ' ,
6+ version = '0.0.3 ' ,
77 license = 'MIT' ,
88 description = 'GateLoop Transformer' ,
99 author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments