@@ -66,40 +66,30 @@ def token_shift(t):
     t_shift = F.pad(t_shift, (0, 0, 1, -1))
     return torch.cat((t, t_shift), dim = -1)

-# positional bias
+# rotary positional embedding

-class Alibi(nn.Module):
-    def __init__(self, heads, **kwargs):
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim, theta = 10000):
         super().__init__()
-        self.heads = heads
-        slopes = torch.Tensor(self._get_slopes(heads))
-        slopes = rearrange(slopes, 'h -> h 1 1')
-        self.register_buffer('slopes', slopes, persistent = False)
-        self.register_buffer('bias', None, persistent = False)
-
-    @staticmethod
-    def _get_slopes(heads):
-        def get_slopes_power_of_2(n):
-            start = (2 ** (-2 ** -(math.log2(n) - 3)))
-            ratio = start
-            return [start * ratio ** i for i in range(n)]
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)

-        if math.log2(heads).is_integer():
-            return get_slopes_power_of_2(heads)
+    @property
+    def device(self):
+        return next(self.buffers()).device

-        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
-        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads - closest_power_of_2]
+    def forward(self, seq_len):
+        t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
+        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
+        freqs = torch.cat((freqs, freqs), dim = -1)
+        return freqs

-    def forward(self, i, j, device):
-        if exists(self.bias) and self.bias.shape[-1] >= j:
-            return self.bias[..., :j]
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim = -1)
+    return torch.cat((-x2, x1), dim = -1)

-        bias = torch.arange(j, device = device)
-        bias = rearrange(bias, 'j -> 1 1 j')
-        bias = bias * self.slopes
-
-        self.register_buffer('bias', bias, persistent = False)
-        return self.bias
+def apply_rotary_pos_emb(pos, t):
+    return t * pos.cos() + rotate_half(t) * pos.sin()

 # norm

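Editor's note, not part of the commit: a minimal, self-contained sketch of how the new rotary pieces combine, assuming the `(batch, heads, seq, dim_head)` layout that `Attention.forward` uses further down in this diff. The standalone `inv_freq` / `pos` construction here simply mirrors `RotaryEmbedding.forward`.

# sketch only -- illustrative usage of the rotary functions introduced above
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(pos, t):
    return t * pos.cos() + rotate_half(t) * pos.sin()

dim_head, seq_len = 64, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_head, 2).float() / dim_head))
pos = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]   # angles, shape (seq, dim_head // 2)
pos = torch.cat((pos, pos), dim = -1)                              # (seq, dim_head), as in RotaryEmbedding.forward

q = torch.randn(2, 8, seq_len, dim_head)                           # queries: (batch, heads, seq, dim_head)
k = torch.randn(2, 8, seq_len, dim_head)                           # keys
q, k = map(lambda t: apply_rotary_pos_emb(pos, t), (q, k))         # rotate both before computing attention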
@@ -152,14 +142,17 @@ def __init__(
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)

-    def forward(self, x, attn_bias = None):
+    def forward(self, x, rotary_emb = None):
         h, device = self.heads, x.device

         x = self.norm(x)
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
         q = rearrange(q, 'b n (h d) -> b h n d', h = h)

-        out = self.attend(q, k, v, attn_bias = attn_bias)
+        if exists(rotary_emb):
+            q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k))
+
+        out = self.attend(q, k, v)

         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
@@ -175,11 +168,11 @@ def __init__(
         attn_dropout = 0.,
         ff_dropout = 0.,
         ff_mult = 4,
-        rel_pos_bias = True,
+        rel_pos = True,
         flash_attn = False
     ):
         super().__init__()
-        self.alibi = Alibi(heads = heads) if rel_pos_bias else None
+        self.rotary_emb = RotaryEmbedding(dim_head) if rel_pos else None
         self.layers = nn.ModuleList([])

         for _ in range(layers):
@@ -192,10 +185,10 @@ def __init__(

     def forward(self, x):
         n = x.shape[-2]
-        attn_bias = self.alibi(n, n, device = x.device) if exists(self.alibi) else None
+        rotary_emb = self.rotary_emb(n) if exists(self.rotary_emb) else None

         for attn, ff in self.layers:
-            x = attn(token_shift(x), attn_bias = attn_bias) + x
+            x = attn(token_shift(x), rotary_emb = rotary_emb) + x
             x = ff(token_shift(x)) + x

         return self.norm(x)
@@ -218,7 +211,7 @@ def __init__(
         ff_mult = 4,
         ff_dropout = 0.,
         pad_id = 0,
-        rel_pos_bias = False,
+        rel_pos = False,
         pos_emb = False,
         flash_attn = False
     ):
@@ -264,7 +257,7 @@ def __init__(
                 attn_dropout = attn_dropout,
                 ff_dropout = ff_dropout,
                 ff_mult = ff_mult,
-                rel_pos_bias = rel_pos_bias,
+                rel_pos = rel_pos,
                 flash_attn = flash_attn
             ))

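For context (editor's sketch, not from the commit): a rotary embedding can stand in for the additive ALiBi bias because rotating q and k makes their dot product depend only on the relative offset between positions. A quick numerical check of that property, restating the functions introduced in this diff for a single head dimension:

# sketch only -- verifies the relative-position property of the rotary scheme above
import torch

dim_head = 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_head, 2).float() / dim_head))

def pos_emb(position):                      # angles for a single scalar position
    freqs = position * inv_freq
    return torch.cat((freqs, freqs), dim = -1)

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(pos, t):
    return t * pos.cos() + rotate_half(t) * pos.sin()

q, k = torch.randn(dim_head), torch.randn(dim_head)

# q at position 3 against k at position 7, then both shifted by 10:
# the attention score is unchanged because only the offset 7 - 3 matters
a = apply_rotary_pos_emb(pos_emb(torch.tensor(3.)), q) @ apply_rotary_pos_emb(pos_emb(torch.tensor(7.)), k)
b = apply_rotary_pos_emb(pos_emb(torch.tensor(13.)), q) @ apply_rotary_pos_emb(pos_emb(torch.tensor(17.)), k)
assert torch.allclose(a, b, atol = 1e-4)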