Skip to content

Commit 1ded1cb

Browse files
committed
address #4
1 parent 1c00da3 commit 1ded1cb

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

gateloop_transformer/simplified_gate_loop.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,26 @@
1717
def exists(v):
1818
return v is not None
1919

20-
def abs_clamp_eps(t, eps = 1e-20):
20+
def default(v, d):
21+
return v if exists(v) else d
22+
23+
def eps_by_dtype(dtype):
24+
return 1e-7 if dtype == torch.float16 else 1e-20
25+
26+
def abs_clamp_eps(t, eps = None):
27+
eps = default(eps, eps_by_dtype(t.dtype))
2128
sign = torch.sign(t)
2229
return sign * t.abs().clamp(min = eps)
2330

2431
# associative scan using heinsen sequences
2532
# https://github.com/glassroom/heinsen_sequence
2633
# graciously shared to the world by Franz A. Heinsen in https://arxiv.org/abs/2311.06281 in October 2023
2734

28-
def heinsen_associative_scan(a, kv, eps = 1e-20):
35+
def heinsen_associative_scan(a, kv, eps = None):
36+
eps = default(eps, eps_by_dtype(a.dtype))
2937
log_a = a.clamp(min = eps).log()
30-
log_kv = abs_clamp_eps(kv, eps = eps).to(dtype = torch.complex64).log()
38+
39+
log_kv = abs_clamp_eps(kv).to(dtype = torch.complex64).log()
3140

3241
a_star = torch.cumsum(log_a, dim = 1)
3342
log_x0_plus_b_star = torch.logcumsumexp(log_kv - a_star, dim = 1)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'gateloop-transformer',
55
packages = find_packages(exclude=[]),
6-
version = '0.2.4',
6+
version = '0.2.5',
77
license='MIT',
88
description = 'GateLoop Transformer',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)