
Commit 4066a0c

move closer to what the paper did, with local and global token embeddings not shared
1 parent 810d77a commit 4066a0c

2 files changed (+24, -25 lines)

MEGABYTE_pytorch/megabyte.py

Lines changed: 23 additions & 24 deletions
@@ -231,19 +231,22 @@ def __init__(
 
         coarsest_dim, *_, fine_dim = dim
 
-        self.token_emb = nn.Embedding(num_tokens, fine_dim)
-
         self.max_seq_len = max_seq_len
 
         self.start_tokens = nn.ParameterList([nn.Parameter(torch.randn(h_dim)) for h_dim, seq_len in zip(dim, max_seq_len)])
         self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)]) if pos_emb else None
 
-        self.patch_embedders = nn.ModuleList([nn.Sequential(
-            Rearrange('... r d -> ... (r d)'),
-            nn.LayerNorm(seq_len * dim_in),
-            nn.Linear(seq_len * dim_in, dim_out),
-            nn.LayerNorm(dim_out)
-        ) for dim_in, dim_out, seq_len in zip(dim[1:], dim[:-1], max_seq_len[1:])])
+        self.token_embs = nn.ModuleList([])
+        self.token_embs.append(nn.Embedding(num_tokens, fine_dim))
+
+        for dim_out, seq_len in zip(dim[:-1], max_seq_len[1:]):
+            self.token_embs.append(nn.Sequential(
+                nn.Embedding(num_tokens, fine_dim),
+                Rearrange('... r d -> ... (r d)'),
+                nn.LayerNorm(seq_len * fine_dim),
+                nn.Linear(seq_len * fine_dim, dim_out),
+                nn.LayerNorm(dim_out)
+            ))
 
         self.transformers = nn.ModuleList([])
         self.to_next_transformer_projections = nn.ModuleList([])
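The constructor change above drops the single shared `token_emb` and the separate `patch_embedders` in favour of one embedder per stage: the local (finest) stage keeps a plain `nn.Embedding`, while each coarser stage gets its own `nn.Embedding` whose per-byte vectors are concatenated over a patch and projected to that stage's width, so the local and global token embeddings are no longer shared. Below is a minimal standalone sketch of what one coarse-stage entry of `token_embs` computes; every size in it is invented for illustration and not taken from the repository.

# Illustrative sketch only: one coarse-stage entry of `self.token_embs`.
# Hypothetical sizes: 256 byte vocabulary, fine_dim = 32, patches of
# seq_len = 4 bytes, coarse stage width dim_out = 512.
import torch
from torch import nn
from einops.layers.torch import Rearrange

num_tokens, fine_dim, seq_len, dim_out = 256, 32, 4, 512

coarse_embedder = nn.Sequential(
    nn.Embedding(num_tokens, fine_dim),      # embed every byte of the patch separately
    Rearrange('... r d -> ... (r d)'),       # concatenate the seq_len byte embeddings of each patch
    nn.LayerNorm(seq_len * fine_dim),
    nn.Linear(seq_len * fine_dim, dim_out),  # project the concatenated patch to the coarse model width
    nn.LayerNorm(dim_out)
)

ids = torch.randint(0, num_tokens, (2, 16, seq_len))  # (batch, patches, bytes per patch)
print(coarse_embedder(ids).shape)                     # torch.Size([2, 16, 512])

The extra embedding tables cost some parameters, but they let each global stage learn its own byte representation rather than reusing the local stage's, which is what the commit message refers to.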
@@ -266,7 +269,7 @@ def __init__(
             if exists(next_h_dim) and next_h_dim != dim:
                 proj = nn.Sequential(
                     nn.Linear(h_dim, next_h_dim * next_seq_len),
-                    Rearrange('... (n d) -> (...) n d', n = next_seq_len)
+                    Rearrange('b m (n d) -> b (m n) d', n = next_seq_len)
                 )
 
                 self.to_next_transformer_projections.append(proj)
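The second hunk changes only the einops pattern in the projection that turns a coarser stage's outputs into per-patch inputs for the next finer stage: the old pattern folded the patch axis into the batch axis, while the new one keeps the batch axis and unrolls the patches into one longer sequence. The following shape check is purely illustrative, with dummy sizes that are not from the repository.

# Illustrative shape check of the two einops patterns used before and after this commit.
import torch
from einops import rearrange

b, m, n, d = 2, 8, 4, 16        # hypothetical batch, patches, next_seq_len, next_h_dim
x = torch.randn(b, m, n * d)    # shaped like the output of nn.Linear(h_dim, next_h_dim * next_seq_len)

old = rearrange(x, '... (n d) -> (...) n d', n = n)  # folds batch and patch axes together
new = rearrange(x, 'b m (n d) -> b (m n) d', n = n)  # keeps batch, flattens patches into one sequence

print(old.shape)  # torch.Size([16, 4, 16])
print(new.shape)  # torch.Size([2, 32, 16])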
@@ -335,29 +338,26 @@ def forward(self, ids, return_loss = False):
         assert prec_dims[0] <= self.max_seq_len[0], 'the first dimension of your axial autoregressive transformer must be less than the first tuple element of max_seq_len (like any autoregressive transformer)'
         assert tuple(prec_dims[1:]) == tuple(self.max_seq_len[1:]), 'all subsequent dimensions must match exactly'
 
-        # get token embeddings
-
-        tokens = self.token_emb(ids)
-
         # get tokens for all hierarchical stages, reducing by appropriate dimensions
         # and adding the absolute positional embeddings
 
         tokens_at_stages = []
-        reduced_tokens = tokens
-
         pos_embs = default(self.pos_embs, (None,))
 
-        for ind, pos_emb, patch_emb in zip_longest(range(len(prec_dims)), reversed(pos_embs), reversed((*self.patch_embedders, None))):
-            is_first = ind == 0
-
-            if not is_first:
-                reduced_tokens = patch_emb(reduced_tokens)
+        for ind, pos_emb, token_emb in zip_longest(range(len(prec_dims)), reversed(pos_embs), reversed(self.token_embs)):
+            is_last = ind == (len(prec_dims) - 1)
+            tokens = token_emb(ids)
 
             if exists(pos_emb):
-                positions = pos_emb(torch.arange(reduced_tokens.shape[-2], device = device))
-                reduced_tokens = reduced_tokens + positions
+                positions = pos_emb(torch.arange(tokens.shape[-2], device = device))
+                tokens = tokens + positions
+
+            tokens_at_stages.append(tokens)
 
-            tokens_at_stages.insert(0, reduced_tokens)
+            if is_last:
+                continue
+
+            ids = rearrange(ids, '... m n -> ... (m n)')
 
         # the un-pixelshuffled representations of the previous hierarchy, starts with None
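In the rewritten forward loop, each stage embeds the raw byte `ids` with its own entry of `token_embs` instead of patch-embedding the previous stage's reduced tokens, adds that stage's positional embedding, and then, for every stage except the last, merges the two innermost axes of `ids` before the next iteration. The toy snippet below only illustrates that reshaping step; the batch size and `max_seq_len` values are invented.

# Illustrative only: how rearrange('... m n -> ... (m n)') progressively flattens `ids`
# across stages, assuming a batch of 2 and max_seq_len = (4, 8, 16).
import torch
from einops import rearrange

ids = torch.randint(0, 256, (2, 4, 8, 16))   # (batch, 4, 8, 16) byte ids
shapes_seen = []

for stage in range(3):
    shapes_seen.append(tuple(ids.shape))              # shape handed to this stage's token_emb
    if stage != 2:                                    # every stage except the last
        ids = rearrange(ids, '... m n -> ... (m n)')  # merge the two innermost axes

print(shapes_seen)  # [(2, 4, 8, 16), (2, 4, 128), (2, 512)]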
@@ -367,7 +367,6 @@ def forward(self, ids, return_loss = False):
 
         for stage_start_tokens, stage_tokens, transformer, proj in zip(self.start_tokens, tokens_at_stages, self.transformers, self.to_next_transformer_projections):
             stage_tokens, ps = pack_one(stage_tokens, '* n d')
-
             stage_start_tokens = repeat(stage_start_tokens, 'f -> b 1 f', b = stage_tokens.shape[0])
 
             # concat start token

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.1.7',
+  version = '0.2.0',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',
