@@ -231,19 +231,22 @@ def __init__(

         coarsest_dim, *_, fine_dim = dim

-        self.token_emb = nn.Embedding(num_tokens, fine_dim)
-
         self.max_seq_len = max_seq_len

         self.start_tokens = nn.ParameterList([nn.Parameter(torch.randn(h_dim)) for h_dim, seq_len in zip(dim, max_seq_len)])
         self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)]) if pos_emb else None

-        self.patch_embedders = nn.ModuleList([nn.Sequential(
-            Rearrange('... r d -> ... (r d)'),
-            nn.LayerNorm(seq_len * dim_in),
-            nn.Linear(seq_len * dim_in, dim_out),
-            nn.LayerNorm(dim_out)
-        ) for dim_in, dim_out, seq_len in zip(dim[1:], dim[:-1], max_seq_len[1:])])
+        self.token_embs = nn.ModuleList([])
+        self.token_embs.append(nn.Embedding(num_tokens, fine_dim))
+
+        for dim_out, seq_len in zip(dim[:-1], max_seq_len[1:]):
+            self.token_embs.append(nn.Sequential(
+                nn.Embedding(num_tokens, fine_dim),
+                Rearrange('... r d -> ... (r d)'),
+                nn.LayerNorm(seq_len * fine_dim),
+                nn.Linear(seq_len * fine_dim, dim_out),
+                nn.LayerNorm(dim_out)
+            ))

         self.transformers = nn.ModuleList([])
         self.to_next_transformer_projections = nn.ModuleList([])
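
A minimal shape sketch (not part of the commit) of one of the new per-stage token embedders above, assuming hypothetical sizes num_tokens = 256, fine_dim = 32, dim_out = 64 and patch length seq_len = 4: each coarser stage now embeds the raw ids directly and folds a patch of fine embeddings into a single coarse vector.

import torch
from torch import nn
from einops.layers.torch import Rearrange

num_tokens, fine_dim, dim_out, seq_len = 256, 32, 64, 4   # hypothetical sizes

# same structure as one entry of self.token_embs above
coarse_token_emb = nn.Sequential(
    nn.Embedding(num_tokens, fine_dim),       # (..., seq_len) ids -> (..., seq_len, fine_dim)
    Rearrange('... r d -> ... (r d)'),        # fold the patch of r token embeddings into one vector
    nn.LayerNorm(seq_len * fine_dim),
    nn.Linear(seq_len * fine_dim, dim_out),   # project to the coarser stage's width
    nn.LayerNorm(dim_out)
)

ids = torch.randint(0, num_tokens, (2, 8, seq_len))   # batch of 2, 8 patches of 4 token ids each
print(coarse_token_emb(ids).shape)                    # torch.Size([2, 8, 64])
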
@@ -266,7 +269,7 @@ def __init__(
             if exists(next_h_dim) and next_h_dim != dim:
                 proj = nn.Sequential(
                     nn.Linear(h_dim, next_h_dim * next_seq_len),
-                    Rearrange('... (n d) -> (...) n d', n = next_seq_len)
+                    Rearrange('b m (n d) -> b (m n) d', n = next_seq_len)
                 )

             self.to_next_transformer_projections.append(proj)
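
A quick sketch (hypothetical sizes, not from the commit) of what the rearrange change in the projection above does: the old pattern folded the coarse positions into the batch dimension, while the new one keeps the batch and lays the n sub-positions out along the sequence.

import torch
from einops import rearrange

b, m, n, d = 2, 3, 4, 5                                 # hypothetical: batch, coarse positions, tokens per position, next stage width
x = torch.randn(b, m, n * d)                            # output of the nn.Linear above

old = rearrange(x, '... (n d) -> (...) n d', n = n)     # -> (6, 4, 5): coarse positions merged into the batch
new = rearrange(x, 'b m (n d) -> b (m n) d', n = n)     # -> (2, 12, 5): batch kept, sub-positions along the sequence

print(old.shape, new.shape)
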
@@ -335,29 +338,26 @@ def forward(self, ids, return_loss = False):
         assert prec_dims[0] <= self.max_seq_len[0], 'the first dimension of your axial autoregressive transformer must be less than the first tuple element of max_seq_len (like any autoregressive transformer)'
         assert tuple(prec_dims[1:]) == tuple(self.max_seq_len[1:]), 'all subsequent dimensions must match exactly'

-        # get token embeddings
-
-        tokens = self.token_emb(ids)
-
         # get tokens for all hierarchical stages, reducing by appropriate dimensions
         # and adding the absolute positional embeddings

         tokens_at_stages = []
-        reduced_tokens = tokens
-
         pos_embs = default(self.pos_embs, (None,))

-        for ind, pos_emb, patch_emb in zip_longest(range(len(prec_dims)), reversed(pos_embs), reversed((*self.patch_embedders, None))):
-            is_first = ind == 0
-
-            if not is_first:
-                reduced_tokens = patch_emb(reduced_tokens)
+        for ind, pos_emb, token_emb in zip_longest(range(len(prec_dims)), reversed(pos_embs), reversed(self.token_embs)):
+            is_last = ind == (len(prec_dims) - 1)
+            tokens = token_emb(ids)

             if exists(pos_emb):
-                positions = pos_emb(torch.arange(reduced_tokens.shape[-2], device = device))
-                reduced_tokens = reduced_tokens + positions
+                positions = pos_emb(torch.arange(tokens.shape[-2], device = device))
+                tokens = tokens + positions
+
+            tokens_at_stages.append(tokens)

-            tokens_at_stages.insert(0, reduced_tokens)
+            if is_last:
+                continue
+
+            ids = rearrange(ids, '... m n -> ... (m n)')

         # the un-pixelshuffled representations of the previous hierarchy, starts with None

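
A hedged sketch (dummy shapes, not from the commit) of the id flattening in the new loop above: each iteration merges the two innermost axial dimensions of ids, so every stage's token_emb sees the raw ids at its own granularity.

import torch
from einops import rearrange

ids = torch.randint(0, 256, (2, 6, 4, 4))      # batch of 2, then three hypothetical axial dimensions
print(ids.shape)                               # torch.Size([2, 6, 4, 4])

ids = rearrange(ids, '... m n -> ... (m n)')   # merge the two innermost axes
print(ids.shape)                               # torch.Size([2, 6, 16])

ids = rearrange(ids, '... m n -> ... (m n)')
print(ids.shape)                               # torch.Size([2, 96])
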
@@ -367,7 +367,6 @@ def forward(self, ids, return_loss = False):

         for stage_start_tokens, stage_tokens, transformer, proj in zip(self.start_tokens, tokens_at_stages, self.transformers, self.to_next_transformer_projections):
             stage_tokens, ps = pack_one(stage_tokens, '* n d')
-
             stage_start_tokens = repeat(stage_start_tokens, 'f -> b 1 f', b = stage_tokens.shape[0])

             # concat start token