File tree: 2 files changed, +13 -4 lines changed
Original file line number Diff line number Diff line change
1717def exists (v ):
1818 return v is not None
1919
20- def abs_clamp_eps (t , eps = 1e-20 ):
20+ def default (v , d ):
21+ return v if exists (v ) else d
22+
23+ def eps_by_dtype (dtype ):
24+ return 1e-7 if dtype == torch .float16 else 1e-20
25+
26+ def abs_clamp_eps (t , eps = None ):
27+ eps = default (eps , eps_by_dtype (t .dtype ))
2128 sign = torch .sign (t )
2229 return sign * t .abs ().clamp (min = eps )
2330
2431# associative scan using heinsen sequences
2532# https://github.com/glassroom/heinsen_sequence
2633# graciously shared to the world by Franz A. Heinsen in https://arxiv.org/abs/2311.06281 in October 2023
2734
28- def heinsen_associative_scan (a , kv , eps = 1e-20 ):
35+ def heinsen_associative_scan (a , kv , eps = None ):
36+ eps = default (eps , eps_by_dtype (a .dtype ))
2937 log_a = a .clamp (min = eps ).log ()
30- log_kv = abs_clamp_eps (kv , eps = eps ).to (dtype = torch .complex64 ).log ()
38+
39+ log_kv = abs_clamp_eps (kv ).to (dtype = torch .complex64 ).log ()
3140
3241 a_star = torch .cumsum (log_a , dim = 1 )
3342 log_x0_plus_b_star = torch .logcumsumexp (log_kv - a_star , dim = 1 )
Original file line number Diff line number Diff line change
33setup (
44 name = 'gateloop-transformer' ,
55 packages = find_packages (exclude = []),
6- version = '0.2.4',
6+ version = '0.2.5',
77 license = 'MIT' ,
88 description = 'GateLoop Transformer' ,
99 author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments