Commit 8864b69

make sure it supports greater than 2 hierarchies
1 parent: 4066a0c

File tree: 2 files changed (+15, -9 lines changed)

  MEGABYTE_pytorch/megabyte.py
  setup.py


MEGABYTE_pytorch/megabyte.py

Lines changed: 14 additions & 8 deletions

@@ -237,14 +237,18 @@ def __init__(
         self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)]) if pos_emb else None

         self.token_embs = nn.ModuleList([])
+
+        patch_size = 1
         self.token_embs.append(nn.Embedding(num_tokens, fine_dim))

-        for dim_out, seq_len in zip(dim[:-1], max_seq_len[1:]):
+        for dim_out, seq_len in zip(reversed(dim[:-1]), reversed(max_seq_len[1:])):
+            patch_size *= seq_len
+
             self.token_embs.append(nn.Sequential(
                 nn.Embedding(num_tokens, fine_dim),
                 Rearrange('... r d -> ... (r d)'),
-                nn.LayerNorm(seq_len * fine_dim),
-                nn.Linear(seq_len * fine_dim, dim_out),
+                nn.LayerNorm(patch_size * fine_dim),
+                nn.Linear(patch_size * fine_dim, dim_out),
                 nn.LayerNorm(dim_out)
             ))

@@ -268,8 +272,9 @@ def __init__(

             if exists(next_h_dim) and next_h_dim != dim:
                 proj = nn.Sequential(
+                    Rearrange('b ... d -> b (...) d'),
                     nn.Linear(h_dim, next_h_dim * next_seq_len),
-                    Rearrange('b m (n d) -> b (m n) d', n = next_seq_len)
+                    Rearrange('b m (n d) -> (b m) n d', n = next_seq_len)
                 )

             self.to_next_transformer_projections.append(proj)

@@ -344,17 +349,18 @@ def forward(self, ids, return_loss = False):
         tokens_at_stages = []
         pos_embs = default(self.pos_embs, (None,))

-        for ind, pos_emb, token_emb in zip_longest(range(len(prec_dims)), reversed(pos_embs), reversed(self.token_embs)):
-            is_last = ind == (len(prec_dims) - 1)
+        for ind, pos_emb, token_emb in zip_longest(range(len(prec_dims)), pos_embs, self.token_embs):
+            is_first = ind == 0
+
             tokens = token_emb(ids)

             if exists(pos_emb):
                 positions = pos_emb(torch.arange(tokens.shape[-2], device = device))
                 tokens = tokens + positions

-            tokens_at_stages.append(tokens)
+            tokens_at_stages.insert(0, tokens)

-            if is_last:
+            if is_first:
                 continue

             ids = rearrange(ids, '... m n -> ... (m n)')

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.2.0',
+  version = '0.2.1',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',
