Skip to content

Commit 1c00da3

Browse files
committed
add post ln as an option for simple gateloop
1 parent 3d83895 commit 1c00da3

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

gateloop_transformer/simplified_gate_loop.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(
111111
prenorm = True,
112112
use_heinsen = False,
113113
use_jax_associative_scan = False,
114+
post_ln = False,
114115
reverse = False
115116
):
116117
super().__init__()
@@ -135,6 +136,7 @@ def __init__(
135136
else:
136137
self.gate_loop_fn = gate_loop_operator
137138

139+
self.maybe_post_ln = nn.LayerNorm(dim) if post_ln else nn.Identity()
138140
self.split_heads = Rearrange('(b d) n 1 -> b n d', d = dim)
139141

140142
self.reverse = reverse
@@ -155,6 +157,7 @@ def forward(
155157
out, cache = self.gate_loop_fn(q, kv, a.sigmoid(), cache = cache)
156158

157159
out = self.split_heads(out)
160+
out = self.maybe_post_ln(out)
158161

159162
if self.reverse:
160163
out = torch.flip(out, dims = (-2,))

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'gateloop-transformer',
55
packages = find_packages(exclude=[]),
6-
version = '0.2.3',
6+
version = '0.2.4',
77
license='MIT',
88
description = 'GateLoop Transformer',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)