@@ -271,8 +271,8 @@ def __init__(

             if exists(next_h_dim) and next_h_dim != dim:
                 proj = nn.Sequential(
-                    nn.Linear(h_dim, next_h_dim * next_seq_len),
-                    Rearrange('... (n d) -> (...) n d', n = next_seq_len)
+                    nn.Linear(h_dim, next_h_dim * (next_seq_len + 1)),
+                    Rearrange('... (n d) -> (...) n d', n = next_seq_len + 1)
                 )

             self.to_next_transformer_projections.append(proj)
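The first hunk widens the per-stage projection so each pooled token expands into next_seq_len + 1 tokens rather than next_seq_len. A minimal, self-contained sketch of the new shape arithmetic (the sizes below are hypothetical, chosen only for illustration):

import torch
from torch import nn
from einops.layers.torch import Rearrange

h_dim, next_h_dim, next_seq_len = 512, 256, 4

# updated projection: one vector of width h_dim becomes
# next_seq_len + 1 tokens of width next_h_dim
proj = nn.Sequential(
    nn.Linear(h_dim, next_h_dim * (next_seq_len + 1)),
    Rearrange('... (n d) -> (...) n d', n = next_seq_len + 1)
)

pooled = torch.randn(2, 8, h_dim)   # (batch, compressed length, h_dim)
out = proj(pooled)
print(out.shape)                    # torch.Size([16, 5, 256]) = (batch * length, next_seq_len + 1, next_h_dim)

The extra (+ 1) position is what the later hunks slice away again, once on each side.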
@@ -379,10 +379,6 @@ def forward(self, ids, return_loss = False):
                 stage_tokens,
             ), dim = -2)

-            # omit last token
-
-            stage_tokens = stage_tokens[:, :-1]
-
             # sum the previous hierarchy's representation

             if exists(prev_stage_tokens_repr):
@@ -394,11 +390,13 @@ def forward(self, ids, return_loss = False):

             # project for next stage in the hierarchy

-            prev_stage_tokens_repr = proj(attended)
+            prev_stage_tokens_repr = proj(attended[..., :-1, :])

             # project to logits

-            logits = self.to_logits(attended)
+            logits = self.to_logits(attended)
+
+            logits = logits[..., 1:, :]

             if not return_loss:

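Read together, the hunks move the causal shift from the stage inputs to the stage outputs: the stage now attends over all next_seq_len + 1 projected tokens (the old early stage_tokens[:, :-1] trim is removed), then the last position is dropped before projecting down the hierarchy and the first position is dropped from the logits. A shape-only sketch of that bookkeeping, with made-up sizes (reading the extra position as a shifted start token is an assumption, not stated in the diff):

import torch
from torch import nn

batch, seq_len, dim, num_tokens = 2, 4, 256, 1000

# stage output over seq_len + 1 positions
attended = torch.randn(batch, seq_len + 1, dim)

# drop the final position before handing the representation down the hierarchy
prev_stage_tokens_repr = attended[..., :-1, :]   # (batch, seq_len, dim)

# drop the first position from the logits (presumably so prediction i
# lines up with target token i)
to_logits = nn.Linear(dim, num_tokens)
logits = to_logits(attended)[..., 1:, :]         # (batch, seq_len, num_tokens)

Both slices leave tensors of length seq_len, so the next stage's input and the loss targets stay aligned with the original token positions.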