@@ -237,14 +237,18 @@ def __init__(
         self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)]) if pos_emb else None
 
         self.token_embs = nn.ModuleList([])
+
+        patch_size = 1
         self.token_embs.append(nn.Embedding(num_tokens, fine_dim))
 
-        for dim_out, seq_len in zip(dim[:-1], max_seq_len[1:]):
+        for dim_out, seq_len in zip(reversed(dim[:-1]), reversed(max_seq_len[1:])):
+            patch_size *= seq_len
+
             self.token_embs.append(nn.Sequential(
                 nn.Embedding(num_tokens, fine_dim),
                 Rearrange('... r d -> ... (r d)'),
-                nn.LayerNorm(seq_len * fine_dim),
-                nn.Linear(seq_len * fine_dim, dim_out),
+                nn.LayerNorm(patch_size * fine_dim),
+                nn.Linear(patch_size * fine_dim, dim_out),
                 nn.LayerNorm(dim_out)
             ))
 
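
The key change in this hunk: `patch_size` accumulates the product of all finer sequence lengths, so each coarser stage embeds its entire patch of raw tokens rather than only the adjacent level's. A minimal shape trace of the new construction, assuming illustrative values for `num_tokens`, `fine_dim`, `dim`, and `max_seq_len` (assumed ordered coarse-to-fine, as in the usual MEGABYTE configuration):

import torch
from torch import nn
from einops.layers.torch import Rearrange

# hypothetical hyperparameters, chosen only to trace shapes
num_tokens, fine_dim = 256, 32
dim = (512, 256, 128)        # coarse -> fine hidden dims
max_seq_len = (4, 4, 8)      # coarse -> fine sequence lengths

token_embs = nn.ModuleList([nn.Embedding(num_tokens, fine_dim)])

patch_size = 1
for dim_out, seq_len in zip(reversed(dim[:-1]), reversed(max_seq_len[1:])):
    patch_size *= seq_len    # 8, then 8 * 4 = 32
    token_embs.append(nn.Sequential(
        nn.Embedding(num_tokens, fine_dim),
        Rearrange('... r d -> ... (r d)'),        # flatten the patch of fine embeddings
        nn.LayerNorm(patch_size * fine_dim),
        nn.Linear(patch_size * fine_dim, dim_out),
        nn.LayerNorm(dim_out)
    ))

# coarsest stage consumes 32 raw tokens per position: (1, 2, 32) -> (1, 2, 512)
ids = torch.randint(0, num_tokens, (1, 2, 32))
print(token_embs[-1](ids).shape)   # torch.Size([1, 2, 512])

Note that `token_embs` ends up fine-first: index 0 is the plain fine embedding, and the last entry is the coarsest patch embedding.
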
@@ -268,8 +272,9 @@ def __init__(
 
             if exists(next_h_dim) and next_h_dim != dim:
                 proj = nn.Sequential(
+                    Rearrange('b ... d -> b (...) d'),
                     nn.Linear(h_dim, next_h_dim * next_seq_len),
-                    Rearrange('b m (n d) -> b (m n) d', n = next_seq_len)
+                    Rearrange('b m (n d) -> (b m) n d', n = next_seq_len)
                 )
 
             self.to_next_transformer_projections.append(proj)
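
The projection to the next, finer stage now flattens any remaining stage axes up front and, more importantly, folds the patch index `m` into the batch rather than concatenating along the sequence, so the finer local transformer attends only within its own patch. A sketch of the shape flow, using hypothetical dimensions:

import torch
from torch import nn
from einops.layers.torch import Rearrange

h_dim, next_h_dim, next_seq_len = 512, 256, 4   # illustrative values

proj = nn.Sequential(
    Rearrange('b ... d -> b (...) d'),                      # flatten any nested stage dims
    nn.Linear(h_dim, next_h_dim * next_seq_len),            # one coarse vector -> a patch of finer vectors
    Rearrange('b m (n d) -> (b m) n d', n = next_seq_len)   # fold patches into batch for the local transformer
)

coarse = torch.randn(2, 3, 8, h_dim)    # (batch, outer, inner, h_dim)
print(proj(coarse).shape)               # torch.Size([48, 4, 256]) == (2 * 3 * 8, next_seq_len, next_h_dim)
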
@@ -344,17 +349,18 @@ def forward(self, ids, return_loss = False):
         tokens_at_stages = []
         pos_embs = default(self.pos_embs, (None,))
 
-        for ind, pos_emb, token_emb in zip_longest(range(len(prec_dims)), reversed(pos_embs), reversed(self.token_embs)):
-            is_last = ind == (len(prec_dims) - 1)
+        for ind, pos_emb, token_emb in zip_longest(range(len(prec_dims)), pos_embs, self.token_embs):
+            is_first = ind == 0
+
             tokens = token_emb(ids)
 
             if exists(pos_emb):
                 positions = pos_emb(torch.arange(tokens.shape[-2], device = device))
                 tokens = tokens + positions
 
-            tokens_at_stages.append(tokens)
+            tokens_at_stages.insert(0, tokens)
 
-            if is_last:
+            if is_first:
                 continue
 
             ids = rearrange(ids, '... m n -> ... (m n)')
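
With the loop no longer reversed, embeddings are computed fine-to-coarse, while `insert(0, tokens)` keeps `tokens_at_stages` ordered coarse-first for the downstream transformers. The `ids` merge is skipped once because the second stage's `token_emb` flattens its innermost patch axis itself via its own `Rearrange`. A toy trace of the reshaping, reusing the illustrative `max_seq_len = (4, 4, 8)` from the first sketch:

import torch
from einops import rearrange

ids = torch.randint(0, 256, (1, 4, 4, 8))   # (batch, *max_seq_len)

for ind in range(3):
    is_first = ind == 0
    # stage 0: fine embedding sees (1, 4, 4, 8)
    # stage 1: also sees (1, 4, 4, 8) -- its own Rearrange flattens the last axis
    # stage 2: sees (1, 4, 32) after the merge below
    print(ind, tuple(ids.shape))
    if is_first:
        continue
    ids = rearrange(ids, '... m n -> ... (m n)')
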