@@ -238,10 +238,10 @@ def __init__(
         coarsest_dim, *_, fine_dim = dim
 
         self.token_emb = nn.Embedding(num_tokens, fine_dim)
-        self.start_tokens = nn.Parameter(torch.randn(coarsest_dim))
 
         self.max_seq_len = max_seq_len
 
+        self.start_tokens = nn.ParameterList([nn.Parameter(torch.randn(h_dim)) for h_dim, seq_len in zip(dim, max_seq_len)])
         self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)])
 
         self.patch_embedders = nn.ModuleList([nn.Sequential(
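Note on this hunk: the single start token for the coarsest stage is replaced by one learned start token per stage, so every level of the hierarchy can prepend its own start embedding. A minimal sketch of that idea in isolation (the `dims` and `seq_lens` values below are made up for illustration):

```python
import torch
from torch import nn

# hypothetical coarse-to-fine model dimensions and per-stage lengths
dims = (512, 256)
seq_lens = (128, 4)

# one learned start token per hierarchy stage, mirroring the ParameterList above
start_tokens = nn.ParameterList([
    nn.Parameter(torch.randn(h_dim)) for h_dim, _ in zip(dims, seq_lens)
])

print([p.shape for p in start_tokens])  # [torch.Size([512]), torch.Size([256])]
```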
@@ -254,7 +254,7 @@ def __init__(
         self.transformers = nn.ModuleList([])
         self.to_next_transformer_projections = nn.ModuleList([])
 
-        for h_dim, next_h_dim, stage_depth in zip_longest(dim, dim[1:], depth):
+        for h_dim, next_h_dim, stage_depth, next_seq_len in zip_longest(dim, dim[1:], depth, max_seq_len[1:]):
             self.transformers.append(Transformer(
                 dim = h_dim,
                 layers = stage_depth,
@@ -267,7 +267,14 @@ def __init__(
                 flash_attn = flash_attn
             ))
 
-            proj = nn.Linear(h_dim, next_h_dim) if exists(next_h_dim) and next_h_dim != dim else nn.Identity()
+            proj = nn.Identity()
+
+            if exists(next_h_dim) and next_h_dim != dim:
+                proj = nn.Sequential(
+                    nn.Linear(h_dim, next_h_dim * (next_seq_len + 1)),
+                    Rearrange('... (n d) -> (...) n d', n = next_seq_len + 1)
+                )
+
             self.to_next_transformer_projections.append(proj)
 
         self.to_logits = nn.Linear(fine_dim, num_tokens)
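Note on the projection change: the loop now also carries `next_seq_len`, and the per-stage projection no longer maps a coarse token to a single vector of the next stage's width. Instead each coarse token is expanded into `next_seq_len + 1` vectors of dimension `next_h_dim`, and the `Rearrange` folds all leading axes together so the result is already shaped as a batch of finer-stage sequences. A rough shape check under assumed sizes (the numbers are illustrative only):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange  # needs an einops version that supports '(...)', i.e. >= 0.6

h_dim, next_h_dim, next_seq_len = 512, 256, 4  # assumed sizes

proj = nn.Sequential(
    nn.Linear(h_dim, next_h_dim * (next_seq_len + 1)),
    Rearrange('... (n d) -> (...) n d', n = next_seq_len + 1)
)

coarse = torch.randn(2, 3, h_dim)  # (batch, coarse positions, h_dim)
fine = proj(coarse)
print(fine.shape)  # torch.Size([6, 5, 256]) = (batch * coarse positions, next_seq_len + 1, next_h_dim)
```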
@@ -295,10 +302,16 @@ def forward_empty(self, batch_size):
         # take care of special case
         # where you sample from input of 0 (start token only)
 
-        tokens = repeat(self.start_tokens, 'd -> b 1 d', b = batch_size)
+        prev_stage_tokens_repr = None
+
+        for stage_start_tokens, transformer, proj in zip(self.start_tokens, self.transformers, self.to_next_transformer_projections):
+            tokens = repeat(stage_start_tokens, 'd -> b 1 d', b = batch_size)
+
+            if exists(prev_stage_tokens_repr):
+                tokens = tokens + prev_stage_tokens_repr[..., :tokens.shape[-2], :]
 
-        for transformer in self.transformers:
             tokens = transformer(tokens)
+            prev_stage_tokens_repr = proj(tokens)
 
         return self.to_logits(tokens)
 
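Note on `forward_empty`: sampling from an empty input now walks the whole hierarchy instead of reusing one start token. Each stage begins from its own start token, adds the coarser stage's projected representation (truncated to the current length), runs its transformer, and hands the projected output to the next stage. A stripped-down illustration of that accumulation pattern with toy stand-ins (none of these stand-ins are the repository's modules):

```python
import torch

# toy stand-ins, not the repository's modules
def toy_transformer(x):
    return x  # identity in place of a Transformer stage

def toy_proj(x, next_len, next_dim):
    # widen each coarse token into (next_len + 1) finer-stage vectors
    b, n, _ = x.shape
    return torch.zeros(b * n, next_len + 1, next_dim)

batch_size = 2
start_tokens = [torch.randn(16), torch.randn(8)]  # hypothetical per-stage start tokens
next_dims = [8, None]                             # width of the following stage, if any

prev_stage_tokens_repr = None

for stage_start, next_dim in zip(start_tokens, next_dims):
    tokens = stage_start.expand(batch_size, 1, -1)  # same effect as repeat 'd -> b 1 d'

    if prev_stage_tokens_repr is not None:
        # add only as many coarser-stage vectors as there are current tokens
        tokens = tokens + prev_stage_tokens_repr[..., :tokens.shape[-2], :]

    tokens = toy_transformer(tokens)

    if next_dim is not None:
        prev_stage_tokens_repr = toy_proj(tokens, next_len = 3, next_dim = next_dim)

print(tokens.shape)  # torch.Size([2, 1, 8]): finest stage, conditioned on the coarser one
```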
@@ -348,28 +361,38 @@ def forward(self, ids, return_loss = False):
             tokens_with_position = reduced_tokens + positions
             tokens_at_stages.insert(0, tokens_with_position)
 
-        # get start tokens and append to the coarsest stage
+        # the un-pixelshuffled representations of the previous hierarchy, starts with None
 
-        start_tokens = repeat(self.start_tokens, 'f -> b 1 f', b = b)
+        prev_stage_tokens_repr = None
 
         # spatial tokens is tokens with depth pos reduced along depth dimension + spatial positions
 
-        for ind, (stage_tokens, transformer, proj) in enumerate(zip(tokens_at_stages, self.transformers, self.to_next_transformer_projections)):
-            is_last = ind == (self.stages - 1)
+        for stage_start_tokens, stage_tokens, transformer, proj in zip(self.start_tokens, tokens_at_stages, self.transformers, self.to_next_transformer_projections):
+            stage_tokens, ps = pack_one(stage_tokens, '* n d')
+
+            stage_start_tokens = repeat(stage_start_tokens, 'f -> b 1 f', b = stage_tokens.shape[0])
+
+            # concat start token
 
             stage_tokens = torch.cat((
-                start_tokens,
+                stage_start_tokens,
                 stage_tokens,
             ), dim = -2)
 
-            stage_tokens, ps = pack_one(stage_tokens, '* n d')
+            # sum the previous hierarchy's representation
+
+            if exists(prev_stage_tokens_repr):
+                stage_tokens = stage_tokens + prev_stage_tokens_repr[..., :stage_tokens.shape[-2], :]
 
             attended = transformer(stage_tokens)
-            attended = proj(attended)
 
             attended = unpack_one(attended, ps, '* n d')
 
-            start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')
+            # project for next stage in the hierarchy
+
+            prev_stage_tokens_repr = proj(attended[..., :-1, :])
+
+            # project to logits
 
         logits = self.to_logits(attended)
 
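Note on the `forward` loop: each stage now packs any leading batch-like axes with `pack_one` before prepending its own start token and adding the coarser stage's representation (sliced to the current length), then restores the original shape with `unpack_one` after attention; the projected output, minus the last position, becomes the conditioning for the next, finer stage. The `pack_one` / `unpack_one` helpers used here are presumably thin single-tensor wrappers around `einops.pack` / `einops.unpack`, roughly as sketched below:

```python
import torch
from einops import pack, unpack

def pack_one(t, pattern):
    # pack a single tensor: collapse the axes covered by '*' into one
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    # inverse of pack_one, restoring the collapsed leading axes
    return unpack(t, ps, pattern)[0]

x = torch.randn(2, 3, 5, 16)       # (batch, coarse positions, n, d)
packed, ps = pack_one(x, '* n d')
print(packed.shape)                # torch.Size([6, 5, 16])

restored = unpack_one(packed, ps, '* n d')
print(restored.shape)              # torch.Size([2, 3, 5, 16])
```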