
Commit 810d77a

switch to rotary embeddings, as they did in the paper
1 parent 3eafa96 commit 810d77a

4 files changed: +44 additions, -69 deletions

MEGABYTE_pytorch/attend.py

Lines changed: 3 additions & 23 deletions
@@ -89,34 +89,19 @@ def flash_attn(self, q, k, v, mask = None, attn_bias = None):
 
         config = self.cuda_config if is_cuda else self.cpu_config
 
-        causal = self.causal
-
-        # handle attention bias
-
-        if exists(attn_bias):
-            mask_value = -torch.finfo(q.dtype).max // 2
-            causal_mask = self.get_mask(q_len, k_len, device)
-            attn_bias = attn_bias.masked_fill(causal_mask, mask_value)
-
-            if exists(mask):
-                attn_bias = attn_bias.masked_fill(~mask, mask_value)
-
-            mask = attn_bias
-            causal = False
-
         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
 
         with torch.backends.cuda.sdp_kernel(**config._asdict()):
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
                 dropout_p = self.dropout if self.training else 0.,
-                is_causal = causal
+                is_causal = self.causal
             )
 
         return out
 
-    def forward(self, q, k, v, mask = None, attn_bias = None):
+    def forward(self, q, k, v, mask = None):
         """
         einstein notation
         b - batch
@@ -132,17 +117,12 @@ def forward(self, q, k, v, mask = None, attn_bias = None):
         kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
 
         if self.flash:
-            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
+            return self.flash_attn(q, k, v, mask = mask)
 
         # similarity
 
         sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
 
-        # attention bias
-
-        if exists(attn_bias):
-            sim = sim + attn_bias
-
         # causal mask
 
         if self.causal:
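
For context (not part of the commit): with the ALiBi `attn_bias` path removed, the only mask the flash path still needs is the causal one, which `F.scaled_dot_product_attention` handles natively through `is_causal`. A minimal sketch of that equivalence, mirroring the non-flash branch of `Attend.forward`:

```python
import torch
import torch.nn.functional as F

b, h, n, d = 2, 8, 64, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# fused path, as in the simplified Attend.flash_attn
out_flash = F.scaled_dot_product_attention(q, k, v, is_causal = True)

# reference path with an explicit causal mask, as in the non-flash branch
scale = d ** -0.5
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
causal_mask = torch.ones(n, n).triu(1).bool()
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
out_manual = torch.einsum('b h i j, b h j d -> b h i d', sim.softmax(dim = -1), v)

assert torch.allclose(out_flash, out_manual, atol = 1e-5)
```

Since rotary embeddings act on the queries and keys before they ever reach `Attend`, nothing positional needs to flow through the attention mask anymore.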

MEGABYTE_pytorch/megabyte.py

Lines changed: 29 additions & 36 deletions
@@ -66,40 +66,30 @@ def token_shift(t):
     t_shift = F.pad(t_shift, (0, 0, 1, -1))
     return torch.cat((t, t_shift), dim = -1)
 
-# positional bias
+# rotary positional embedding
 
-class Alibi(nn.Module):
-    def __init__(self, heads, **kwargs):
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim, theta = 10000):
         super().__init__()
-        self.heads = heads
-        slopes = torch.Tensor(self._get_slopes(heads))
-        slopes = rearrange(slopes, 'h -> h 1 1')
-        self.register_buffer('slopes', slopes, persistent = False)
-        self.register_buffer('bias', None, persistent = False)
-
-    @staticmethod
-    def _get_slopes(heads):
-        def get_slopes_power_of_2(n):
-            start = (2**(-2**-(math.log2(n)-3)))
-            ratio = start
-            return [start*ratio**i for i in range(n)]
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
 
-        if math.log2(heads).is_integer():
-            return get_slopes_power_of_2(heads)
+    @property
+    def device(self):
+        return next(self.buffers()).device
 
-        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
-        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
+    def forward(self, seq_len):
+        t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
+        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
+        freqs = torch.cat((freqs, freqs), dim = -1)
+        return freqs
 
-    def forward(self, i, j, device):
-        if exists(self.bias) and self.bias.shape[-1] >= j:
-            return self.bias[..., :j]
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
 
-        bias = torch.arange(j, device = device)
-        bias = rearrange(bias, 'j -> 1 1 j')
-        bias = bias * self.slopes
-
-        self.register_buffer('bias', bias, persistent = False)
-        return self.bias
+def apply_rotary_pos_emb(pos, t):
+    return t * pos.cos() + rotate_half(t) * pos.sin()
 
 # norm
 
@@ -152,14 +142,17 @@ def __init__(
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
-    def forward(self, x, attn_bias = None):
+    def forward(self, x, rotary_emb = None):
         h, device = self.heads, x.device
 
         x = self.norm(x)
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
         q = rearrange(q, 'b n (h d) -> b h n d', h = h)
 
-        out = self.attend(q, k, v, attn_bias = attn_bias)
+        if exists(rotary_emb):
+            q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k))
+
+        out = self.attend(q, k, v)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
@@ -175,11 +168,11 @@ def __init__(
         attn_dropout = 0.,
         ff_dropout = 0.,
         ff_mult = 4,
-        rel_pos_bias = True,
+        rel_pos = True,
         flash_attn = False
     ):
         super().__init__()
-        self.alibi = Alibi(heads = heads) if rel_pos_bias else None
+        self.rotary_emb = RotaryEmbedding(dim_head) if rel_pos else None
         self.layers = nn.ModuleList([])
 
         for _ in range(layers):
@@ -192,10 +185,10 @@ def __init__(
 
     def forward(self, x):
         n = x.shape[-2]
-        attn_bias = self.alibi(n, n, device = x.device) if exists(self.alibi) else None
+        rotary_emb = self.rotary_emb(n) if exists(self.rotary_emb) else None
 
         for attn, ff in self.layers:
-            x = attn(token_shift(x), attn_bias = attn_bias) + x
+            x = attn(token_shift(x), rotary_emb = rotary_emb) + x
             x = ff(token_shift(x)) + x
 
         return self.norm(x)
@@ -218,7 +211,7 @@ def __init__(
         ff_mult = 4,
         ff_dropout = 0.,
         pad_id = 0,
-        rel_pos_bias = False,
+        rel_pos = False,
         pos_emb = False,
         flash_attn = False
     ):
@@ -264,7 +257,7 @@ def __init__(
                 attn_dropout = attn_dropout,
                 ff_dropout = ff_dropout,
                 ff_mult = ff_mult,
-                rel_pos_bias = rel_pos_bias,
+                rel_pos = rel_pos,
                 flash_attn = flash_attn
             ))
 
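
A minimal usage sketch, not part of the commit: the two helpers are copied from the hunk above (inlined here rather than imported, since their being importable from `MEGABYTE_pytorch.megabyte` is an assumption), and the shapes mirror `Attention.forward`, where queries are multi-headed and keys share a single head.

```python
import torch

# helpers as introduced in this commit
def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(pos, t):
    return t * pos.cos() + rotate_half(t) * pos.sin()

dim_head, seq_len, theta = 64, 16, 10000

# what RotaryEmbedding(dim_head)(seq_len) computes
inv_freq = 1.0 / (theta ** (torch.arange(0, dim_head, 2).float() / dim_head))
t = torch.arange(seq_len).float()
freqs = torch.einsum('i , j -> i j', t, inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)       # (seq_len, dim_head)

q = torch.randn(2, 8, seq_len, dim_head)          # (batch, heads, seq, dim_head)
k = torch.randn(2, seq_len, dim_head)             # single shared key head, as in Attention.forward

# freqs broadcasts over the leading dimensions of both q and k
q, k = map(lambda x: apply_rotary_pos_emb(freqs, x), (q, k))

# attention logits now depend only on relative offsets (the RoFormer property),
# so no additive positional bias has to be threaded through Attend
sim = torch.einsum('b h i d, b j d -> b h i j', q, k)
print(sim.shape)                                  # torch.Size([2, 8, 16, 16])
```

Because the rotation operates on the head dimension, `RotaryEmbedding(dim_head)` is constructed once per `Transformer` and the same `freqs` tensor is reused across all of its layers, as in `Transformer.forward` above.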

README.md

Lines changed: 11 additions & 9 deletions
@@ -89,15 +89,6 @@ $ python train.py
 }
 ```
 
-```bibtex
-@misc{press2021ALiBi,
-    title   = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
-    author  = {Ofir Press and Noah A. Smith and Mike Lewis},
-    year    = {2021},
-    url     = {https://ofir.io/train_short_test_long.pdf}
-}
-```
-
 ```bibtex
 @software{peng_bo_2021_5196578,
     author = {PENG Bo},
@@ -120,3 +111,14 @@ $ python train.py
     volume  = {abs/2305.19466}
 }
 ```
+
+```bibtex
+@misc{su2021roformer,
+    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
+    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
+    year    = {2021},
+    eprint  = {2104.09864},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL}
+}
+```

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.1.6',
+  version = '0.1.7',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',
