Commit fe82b1d

Merge pull request #8 from lucidrains/paper-variation
Paper variation
2 parents cdfa143 + 0277167 commit fe82b1d

2 files changed (+37, -14 lines)

MEGABYTE_pytorch/megabyte.py

Lines changed: 36 additions & 13 deletions
@@ -238,10 +238,10 @@ def __init__(
         coarsest_dim, *_, fine_dim = dim

         self.token_emb = nn.Embedding(num_tokens, fine_dim)
-        self.start_tokens = nn.Parameter(torch.randn(coarsest_dim))

         self.max_seq_len = max_seq_len

+        self.start_tokens = nn.ParameterList([nn.Parameter(torch.randn(h_dim)) for h_dim, seq_len in zip(dim, max_seq_len)])
         self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)])

         self.patch_embedders = nn.ModuleList([nn.Sequential(
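
The single start token for the coarsest stage becomes one learned start token per hierarchy stage. A minimal sketch of what this construction yields, assuming toy values for dim and max_seq_len (real values come from the model configuration):

```python
import torch
from torch import nn

# assumed toy hyperparameters, coarse-to-fine, one entry per stage
dim = (512, 256)
max_seq_len = (1024, 4)

# one learned start token per stage, sized to that stage's width
start_tokens = nn.ParameterList([
    nn.Parameter(torch.randn(h_dim))
    for h_dim, seq_len in zip(dim, max_seq_len)
])

print([p.shape for p in start_tokens])  # [torch.Size([512]), torch.Size([256])]
```
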
@@ -254,7 +254,7 @@ def __init__(
         self.transformers = nn.ModuleList([])
         self.to_next_transformer_projections = nn.ModuleList([])

-        for h_dim, next_h_dim, stage_depth in zip_longest(dim, dim[1:], depth):
+        for h_dim, next_h_dim, stage_depth, next_seq_len in zip_longest(dim, dim[1:], depth, max_seq_len[1:]):
             self.transformers.append(Transformer(
                 dim = h_dim,
                 layers = stage_depth,
@@ -267,7 +267,14 @@ def __init__(
                 flash_attn = flash_attn
             ))

-            proj = nn.Linear(h_dim, next_h_dim) if exists(next_h_dim) and next_h_dim != dim else nn.Identity()
+            proj = nn.Identity()
+
+            if exists(next_h_dim) and next_h_dim != dim:
+                proj = nn.Sequential(
+                    nn.Linear(h_dim, next_h_dim * (next_seq_len + 1)),
+                    Rearrange('... (n d) -> (...) n d', n = next_seq_len + 1)
+                )
+
             self.to_next_transformer_projections.append(proj)

         self.to_logits = nn.Linear(fine_dim, num_tokens)
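
The inter-stage projection no longer maps a coarse token straight to the next stage's width; it expands each coarse position into next_seq_len + 1 embeddings of the finer stage. A shape-only sketch, with made-up dimensions, of what the Linear + Rearrange pair does:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

h_dim, next_h_dim = 512, 256   # assumed widths of the current and next stage
next_seq_len = 4               # assumed patch length of the next (finer) stage

proj = nn.Sequential(
    nn.Linear(h_dim, next_h_dim * (next_seq_len + 1)),
    Rearrange('... (n d) -> (...) n d', n = next_seq_len + 1)
)

coarse = torch.randn(2, 8, h_dim)   # (batch, coarse positions, h_dim)
out = proj(coarse)

# each coarse position becomes next_seq_len + 1 embeddings for the finer stage,
# with the batch and position axes flattened together
print(out.shape)                    # torch.Size([16, 5, 256])
```
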
@@ -295,10 +302,16 @@ def forward_empty(self, batch_size):
         # take care of special case
         # where you sample from input of 0 (start token only)

-        tokens = repeat(self.start_tokens, 'd -> b 1 d', b = batch_size)
+        prev_stage_tokens_repr = None
+
+        for stage_start_tokens, transformer, proj in zip(self.start_tokens, self.transformers, self.to_next_transformer_projections):
+            tokens = repeat(stage_start_tokens, 'd -> b 1 d', b = batch_size)
+
+            if exists(prev_stage_tokens_repr):
+                tokens = tokens + prev_stage_tokens_repr[..., :tokens.shape[-2], :]

-        for transformer in self.transformers:
             tokens = transformer(tokens)
+            prev_stage_tokens_repr = proj(tokens)

         return self.to_logits(tokens)
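
With per-stage start tokens, forward_empty now runs the whole hierarchy even for an empty prompt, adding each stage's projected output to the next stage's start token. A toy walk-through with nn.Identity standing in for the per-stage Transformers (all sizes here are assumptions):

```python
import torch
from torch import nn
from einops import repeat
from einops.layers.torch import Rearrange

dim = (512, 256)         # assumed coarse and fine widths
max_seq_len = (1024, 4)  # assumed coarse and fine sequence lengths
batch_size = 2

start_tokens = nn.ParameterList([nn.Parameter(torch.randn(d)) for d in dim])
transformers = [nn.Identity(), nn.Identity()]  # stand-ins for the real Transformers

projs = [
    nn.Sequential(  # coarse -> fine projection, as built in the new __init__
        nn.Linear(dim[0], dim[1] * (max_seq_len[1] + 1)),
        Rearrange('... (n d) -> (...) n d', n = max_seq_len[1] + 1)
    ),
    nn.Identity()   # the finest stage has nothing to project to
]

prev_stage_tokens_repr = None

for stage_start_tokens, transformer, proj in zip(start_tokens, transformers, projs):
    tokens = repeat(stage_start_tokens, 'd -> b 1 d', b = batch_size)

    if prev_stage_tokens_repr is not None:
        # seed this stage with the coarser stage's representation, truncated to length 1
        tokens = tokens + prev_stage_tokens_repr[..., :tokens.shape[-2], :]

    tokens = transformer(tokens)
    prev_stage_tokens_repr = proj(tokens)

print(tokens.shape)  # torch.Size([2, 1, 256]), what to_logits would receive
```
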

@@ -348,28 +361,38 @@ def forward(self, ids, return_loss = False):
             tokens_with_position = reduced_tokens + positions
             tokens_at_stages.insert(0, tokens_with_position)

-        # get start tokens and append to the coarsest stage
+        # the un-pixelshuffled representations of the previous hierarchy, starts with None

-        start_tokens = repeat(self.start_tokens, 'f -> b 1 f', b = b)
+        prev_stage_tokens_repr = None

         # spatial tokens is tokens with depth pos reduced along depth dimension + spatial positions

-        for ind, (stage_tokens, transformer, proj) in enumerate(zip(tokens_at_stages, self.transformers, self.to_next_transformer_projections)):
-            is_last = ind == (self.stages - 1)
+        for stage_start_tokens, stage_tokens, transformer, proj in zip(self.start_tokens, tokens_at_stages, self.transformers, self.to_next_transformer_projections):
+            stage_tokens, ps = pack_one(stage_tokens, '* n d')
+
+            stage_start_tokens = repeat(stage_start_tokens, 'f -> b 1 f', b = stage_tokens.shape[0])
+
+            # concat start token

             stage_tokens = torch.cat((
-                start_tokens,
+                stage_start_tokens,
                 stage_tokens,
             ), dim = -2)

-            stage_tokens, ps = pack_one(stage_tokens, '* n d')
+            # sum the previous hierarchy's representation
+
+            if exists(prev_stage_tokens_repr):
+                stage_tokens = stage_tokens + prev_stage_tokens_repr[..., :stage_tokens.shape[-2], :]

             attended = transformer(stage_tokens)
-            attended = proj(attended)

             attended = unpack_one(attended, ps, '* n d')

-            start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')
+            # project for next stage in the hierarchy
+
+            prev_stage_tokens_repr = proj(attended[..., :-1, :])
+
+        # project to logits

         logits = self.to_logits(attended)
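
The same idea in the full forward pass: each stage folds its patch dimension into the batch, prepends its start token, adds the coarser stage's projected representation, runs its transformer, then projects everything but the last position to seed the next stage. A shape-only sketch with toy sizes and nn.Identity in place of the Transformers (pack_one / unpack_one here are thin wrappers over einops pack / unpack, mirroring the helpers in megabyte.py):

```python
import torch
from torch import nn
from einops import repeat, pack, unpack
from einops.layers.torch import Rearrange

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

b = 2
dim = (512, 256)           # assumed coarse and fine widths
patches, patch_len = 3, 4  # assumed patch count and patch length

# assumed per-stage token representations, as the patch embedders would produce
tokens_at_stages = [
    torch.randn(b, patches, dim[0]),             # coarse: one token per patch
    torch.randn(b, patches, patch_len, dim[1]),  # fine: one token per byte, grouped by patch
]

start_tokens = nn.ParameterList([nn.Parameter(torch.randn(d)) for d in dim])
transformers = [nn.Identity(), nn.Identity()]

projs = [
    nn.Sequential(
        nn.Linear(dim[0], dim[1] * (patch_len + 1)),
        Rearrange('... (n d) -> (...) n d', n = patch_len + 1)
    ),
    nn.Identity()
]

prev_stage_tokens_repr = None

for stage_start_tokens, stage_tokens, transformer, proj in zip(start_tokens, tokens_at_stages, transformers, projs):
    stage_tokens, ps = pack_one(stage_tokens, '* n d')                     # fold patches into batch
    stage_start_tokens = repeat(stage_start_tokens, 'f -> b 1 f', b = stage_tokens.shape[0])

    stage_tokens = torch.cat((stage_start_tokens, stage_tokens), dim = -2) # prepend start token

    if prev_stage_tokens_repr is not None:
        # add the coarser stage's projected output, truncated to this stage's length
        stage_tokens = stage_tokens + prev_stage_tokens_repr[..., :stage_tokens.shape[-2], :]

    attended = transformer(stage_tokens)
    attended = unpack_one(attended, ps, '* n d')

    # drop the last position and project to initialize the next (finer) stage
    prev_stage_tokens_repr = proj(attended[..., :-1, :])

print(attended.shape)  # torch.Size([2, 3, 5, 256]), what to_logits is applied to
```
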

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.0.10',
+  version = '0.1.0',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',
