
Commit ede6166

the paper used multiple start tokens, and also had the previous hierarchy project and un-pixel-shuffle a patch, which is summed onto the entire next patch
1 parent cdfa143 commit ede6166
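
As a rough, illustrative sketch of what the commit message describes (not the repository's exact code): a coarse-stage token is linearly projected up and then "un-pixel-shuffled", i.e. reshaped, into a short patch of finer tokens, which can then be summed onto the next stage's patch. The sizes coarse_dim, fine_dim, patch_size and the module name to_fine below are assumptions.

# Hypothetical sketch of "project and un-pixel-shuffle"; sizes and names are illustrative.
import torch
from torch import nn
from einops.layers.torch import Rearrange

coarse_dim, fine_dim, patch_size = 512, 256, 4

to_fine = nn.Sequential(
    nn.Linear(coarse_dim, fine_dim * patch_size),          # expand each coarse token
    Rearrange('... (n d) -> (...) n d', n = patch_size)    # split it into patch_size finer tokens
)

coarse_tokens = torch.randn(2, 8, coarse_dim)              # (batch, num coarse tokens, coarse_dim)
fine_patch = to_fine(coarse_tokens)                        # (batch * num coarse tokens, patch_size, fine_dim)
assert fine_patch.shape == (2 * 8, patch_size, fine_dim)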


MEGABYTE_pytorch/megabyte.py

Lines changed: 40 additions & 15 deletions
@@ -238,10 +238,10 @@ def __init__(
         coarsest_dim, *_, fine_dim = dim
 
         self.token_emb = nn.Embedding(num_tokens, fine_dim)
-        self.start_tokens = nn.Parameter(torch.randn(coarsest_dim))
 
         self.max_seq_len = max_seq_len
 
+        self.start_tokens = nn.ParameterList([nn.Parameter(torch.randn(h_dim)) for h_dim, seq_len in zip(dim, max_seq_len)])
         self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)])
 
         self.patch_embedders = nn.ModuleList([nn.Sequential(
@@ -254,7 +254,7 @@ def __init__(
         self.transformers = nn.ModuleList([])
         self.to_next_transformer_projections = nn.ModuleList([])
 
-        for h_dim, next_h_dim, stage_depth in zip_longest(dim, dim[1:], depth):
+        for h_dim, next_h_dim, stage_depth, next_seq_len in zip_longest(dim, dim[1:], depth, max_seq_len[1:]):
             self.transformers.append(Transformer(
                 dim = h_dim,
                 layers = stage_depth,
@@ -267,7 +267,14 @@ def __init__(
                 flash_attn = flash_attn
             ))
 
-            proj = nn.Linear(h_dim, next_h_dim) if exists(next_h_dim) and next_h_dim != dim else nn.Identity()
+            proj = nn.Identity()
+
+            if exists(next_h_dim) and next_h_dim != dim:
+                proj = nn.Sequential(
+                    nn.Linear(h_dim, next_h_dim * next_seq_len),
+                    Rearrange('... (n d) -> (...) n d', n = next_seq_len)
+                )
+
             self.to_next_transformer_projections.append(proj)
 
         self.to_logits = nn.Linear(fine_dim, num_tokens)
@@ -295,10 +302,16 @@ def forward_empty(self, batch_size):
         # take care of special case
         # where you sample from input of 0 (start token only)
 
-        tokens = repeat(self.start_tokens, 'd -> b 1 d', b = batch_size)
+        prev_stage_tokens_repr = None
+
+        for stage_start_tokens, transformer, proj in zip(self.start_tokens, self.transformers, self.to_next_transformer_projections):
+            tokens = repeat(stage_start_tokens, 'd -> b 1 d', b = batch_size)
+
+            if exists(prev_stage_tokens_repr):
+                tokens = tokens + prev_stage_tokens_repr[..., :tokens.shape[-2], :]
 
-        for transformer in self.transformers:
             tokens = transformer(tokens)
+            prev_stage_tokens_repr = proj(tokens)
 
         return self.to_logits(tokens)
 
@@ -348,32 +361,44 @@ def forward(self, ids, return_loss = False):
         tokens_with_position = reduced_tokens + positions
         tokens_at_stages.insert(0, tokens_with_position)
 
-        # get start tokens and append to the coarsest stage
+        # the un-pixelshuffled representations of the previous hierarchy, starts with None
 
-        start_tokens = repeat(self.start_tokens, 'f -> b 1 f', b = b)
+        prev_stage_tokens_repr = None
 
         # spatial tokens is tokens with depth pos reduced along depth dimension + spatial positions
 
-        for ind, (stage_tokens, transformer, proj) in enumerate(zip(tokens_at_stages, self.transformers, self.to_next_transformer_projections)):
-            is_last = ind == (self.stages - 1)
+        for stage_start_tokens, stage_tokens, transformer, proj in zip(self.start_tokens, tokens_at_stages, self.transformers, self.to_next_transformer_projections):
+            stage_tokens, ps = pack_one(stage_tokens, '* n d')
+
+            stage_start_tokens = repeat(stage_start_tokens, 'f -> b 1 f', b = stage_tokens.shape[0])
+
+            # concat start token
 
             stage_tokens = torch.cat((
-                start_tokens,
+                stage_start_tokens,
                 stage_tokens,
             ), dim = -2)
 
-            stage_tokens, ps = pack_one(stage_tokens, '* n d')
+            # omit last token
+
+            stage_tokens = stage_tokens[:, :-1]
+
+            # sum the previous hierarchy's representation
+
+            if exists(prev_stage_tokens_repr):
+                stage_tokens = stage_tokens + prev_stage_tokens_repr[..., :stage_tokens.shape[-2], :]
 
             attended = transformer(stage_tokens)
-            attended = proj(attended)
 
             attended = unpack_one(attended, ps, '* n d')
 
-            start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')
+            # project for next stage in the hierarchy
+
+            prev_stage_tokens_repr = proj(attended)
 
-        logits = self.to_logits(attended)
+        # project to logits
 
-        logits = logits[..., 1:, :]
+        logits = self.to_logits(attended)
 
         if not return_loss:
 
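For intuition, here is a minimal, self-contained sketch of the per-stage flow the forward pass now follows (shapes and names below are assumptions, not the library's code): prepend the stage's start token, drop the last token to keep the shift causal, and add the previous hierarchy's un-pixel-shuffled representation before running the stage's transformer.

# Minimal sketch of the new per-stage flow; all shapes and names are illustrative.
import torch

b, n, d = 16, 4, 256                                   # flattened patches, patch length, stage dim
stage_tokens = torch.randn(b, n, d)                    # this stage's token embeddings
stage_start_tokens = torch.randn(d).expand(b, 1, d)    # one learned start token per stage
prev_stage_tokens_repr = torch.randn(b, n, d)          # coarser stage's projected patch representation

# prepend the start token, then omit the last token (shift right by one)
stage_tokens = torch.cat((stage_start_tokens, stage_tokens), dim = -2)[:, :-1]

# sum in the previous hierarchy's representation, truncated to the same length
stage_tokens = stage_tokens + prev_stage_tokens_repr[..., :stage_tokens.shape[-2], :]

assert stage_tokens.shape == (b, n, d)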