 
 from typing import Tuple
 
-from einops import rearrange
+from einops import rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from gateloop_transformer.gateloop_transformer import RMSNorm
 from gateloop_transformer.associative_scan import associative_scan
 
 # plain pytorch non-fused associative scan
 
-def gate_loop_operator(q, kv, a):
+def exists(v):
+    return v is not None
+
+def gate_loop_operator(q, kv, a, cache = None):
 
     @torch.jit.script
     def binary_operator(
@@ -23,9 +26,18 @@ def binary_operator(
         a_j, kv_j = b
         return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)
 
-    _, kv = associative_scan(binary_operator, (a, kv))
+    if exists(cache):
+        cache_a, cache_kv = cache
+        a, a_ps = pack([cache_a, a], 'b * d')
+        kv, kv_ps = pack([cache_kv, kv], 'b * d')
+
+    a, kv = associative_scan(binary_operator, (a, kv))
+
+    if exists(cache):
+        _, a = unpack(a, a_ps, 'b * d')
+        _, kv = unpack(kv, kv_ps, 'b * d')
 
-    return q * kv
+    return q * kv, (a[:, -1], kv[:, -1])
 
 # using jax associative scan
 
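The cache returned by gate_loop_operator is simply the last cumulative gate a[:, -1] and the last scanned state kv[:, -1], so a later call can resume the recurrence where the previous one stopped. A minimal sketch of that contract, using a plain PyTorch loop instead of the library's associative scan (the naive_gate_loop helper and the tensor shapes below are illustrative only, not part of the repo):

    import torch

    def naive_gate_loop(q, kv, a, cache = None):
        # q, kv, a: (batch, seq, dim); cache: (cumulative gate, scanned kv state), each (batch, dim)
        if cache is not None:
            a_run, kv_run = cache
        else:
            a_run = torch.ones_like(a[:, 0])
            kv_run = torch.zeros_like(kv[:, 0])

        outs = []
        for t in range(kv.shape[1]):
            a_run = a[:, t] * a_run                # running product of gates
            kv_run = a[:, t] * kv_run + kv[:, t]   # new state = a_t * previous state + kv_t
            outs.append(q[:, t] * kv_run)

        return torch.stack(outs, dim = 1), (a_run, kv_run)

    # splitting a sequence in two and passing the cache matches one full pass
    q, kv, a = (torch.randn(2, 8, 4) for _ in range(3))
    a = a.sigmoid()

    full, _ = naive_gate_loop(q, kv, a)
    head, state = naive_gate_loop(q[:, :5], kv[:, :5], a[:, :5])
    tail, _ = naive_gate_loop(q[:, 5:], kv[:, 5:], a[:, 5:], cache = state)

    assert torch.allclose(full, torch.cat((head, tail), dim = 1), atol = 1e-6)
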
@@ -48,7 +60,7 @@ def binary_operator(e_i, e_j):
4860
4961 return q * y
5062
51- return jax2torch (jax_gate_loop_operator )
63+ return jax2torch (jax_gate_loop_operator ), None
5264
5365# simple gate loop layer
5466
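Note that the jax2torch-wrapped operator keeps the old (q, kv, a) -> out calling convention and takes no cache keyword, while forward now passes one; the assert added to forward below makes the same point. A hedged sketch of one way such a cache-less operator could be adapted to the new (out, cache) interface (the ignore_cache wrapper is hypothetical, not something this commit adds):

    # hypothetical adapter: give a cache-less scan function the new (out, cache) interface
    def ignore_cache(fn):
        def inner(q, kv, a, cache = None):
            assert cache is None, 'this operator does not support caching'
            return fn(q, kv, a), None
        return inner

    # e.g. gate_loop_fn = ignore_cache(jax2torch(jax_gate_loop_operator))
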
@@ -75,6 +87,8 @@ def __init__(
             Rearrange('b n (qkva d) -> qkva (b d) n 1', qkva = 3)
         )
 
+        self.use_jax = use_jax_associative_scan
+
         if use_jax_associative_scan:
             self.gate_loop_fn = get_jax_gate_loop_operator()
         else:
@@ -84,20 +98,30 @@ def __init__(
 
         self.reverse = reverse
 
-    def forward(self, x):
-
+    def forward(
+        self,
+        x,
+        cache = None,
+        return_cache = False
+    ):
         if self.reverse:
             x = torch.flip(x, dims = (-2,))
 
         x = self.norm(x)
 
         q, kv, a = self.to_qkva(x)
 
-        out = self.gate_loop_fn(q, kv, a.sigmoid())
+        out, cache = self.gate_loop_fn(q, kv, a.sigmoid(), cache = cache)
 
         out = self.split_heads(out)
 
         if self.reverse:
             out = torch.flip(out, dims = (-2,))
 
-        return out
+        if not return_cache:
+            return out
+
+        assert not self.reverse, 'caching only works with non-reversed seq'
+        assert not self.use_jax, 'jax associative scan does not have caching yet'
+
+        return out, cache
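Taken together, the new forward arguments allow incremental decoding with the plain PyTorch scan: prime the layer on a prefix with return_cache = True, then feed later tokens one at a time with the returned cache. A hedged usage sketch (it assumes the class is SimpleGateLoopLayer exported by this package with a dim keyword, and the tensor sizes are made up):

    import torch
    from gateloop_transformer import SimpleGateLoopLayer  # assumed export of this repo

    layer = SimpleGateLoopLayer(dim = 64)

    prefix = torch.randn(1, 16, 64)
    out, cache = layer(prefix, return_cache = True)   # full pass over the prefix, keep the scan state

    next_token = torch.randn(1, 1, 64)
    step_out, cache = layer(next_token, cache = cache, return_cache = True)   # single-token step resumes from the cache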