@@ -1,5 +1,6 @@
 import math
 import functools
+from itertools import zip_longest
 
 import torch
 import torch.nn.functional as F
@@ -204,8 +205,8 @@ def __init__(
         *,
         num_tokens,
         dim,
-        depth,
-        max_seq_len,
+        depth: tuple,
+        max_seq_len: tuple,
         dim_head = 64,
         heads = 8,
         attn_dropout = 0.,
@@ -225,26 +226,32 @@ def __init__(
         assert len(depth) == len(max_seq_len)
 
         self.stages = len(depth)
+        dim = cast_tuple(dim, self.stages)
 
-        self.token_emb = nn.Embedding(num_tokens, dim)
-        self.start_tokens = nn.Parameter(torch.randn(dim))
+        assert len(dim) == self.stages
+
+        coarsest_dim, *_, fine_dim = dim
+
+        self.token_emb = nn.Embedding(num_tokens, fine_dim)
+        self.start_tokens = nn.Parameter(torch.randn(coarsest_dim))
 
         self.max_seq_len = max_seq_len
 
-        self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, dim) for seq_len in max_seq_len])
+        self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)])
 
         self.patch_embedders = nn.ModuleList([nn.Sequential(
             Rearrange('... r d -> ... (r d)'),
-            nn.LayerNorm(seq_len * dim),
-            nn.Linear(seq_len * dim, dim),
-            nn.LayerNorm(dim)
-        ) for seq_len in self.max_seq_len[1:]])
+            nn.LayerNorm(seq_len * dim_in),
+            nn.Linear(seq_len * dim_in, dim_out),
+            nn.LayerNorm(dim_out)
+        ) for dim_in, dim_out, seq_len in zip(dim[1:], dim[:-1], max_seq_len[1:])])
 
         self.transformers = nn.ModuleList([])
+        self.to_next_transformer_projections = nn.ModuleList([])
 
-        for stage_depth in depth:
+        for h_dim, next_h_dim, stage_depth in zip_longest(dim, dim[1:], depth):
             self.transformers.append(Transformer(
-                dim = dim,
+                dim = h_dim,
                 layers = stage_depth,
                 dim_head = dim_head,
                 heads = heads,
@@ -255,7 +262,10 @@ def __init__(
                 flash_attn = flash_attn
             ))
 
-        self.to_logits = nn.Linear(dim, num_tokens)
+            proj = nn.Linear(h_dim, next_h_dim) if exists(next_h_dim) and next_h_dim != h_dim else nn.Identity()
+            self.to_next_transformer_projections.append(proj)
+
+        self.to_logits = nn.Linear(fine_dim, num_tokens)
         self.pad_id = pad_id
 
     def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
@@ -339,7 +349,7 @@ def forward(self, ids, return_loss = False):
 
         # spatial tokens is tokens with depth pos reduced along depth dimension + spatial positions
 
-        for ind, (stage_tokens, transformer) in enumerate(zip(tokens_at_stages, self.transformers)):
+        for ind, (stage_tokens, transformer, proj) in enumerate(zip(tokens_at_stages, self.transformers, self.to_next_transformer_projections)):
             is_last = ind == (self.stages - 1)
 
             stage_tokens = torch.cat((
@@ -348,7 +358,10 @@ def forward(self, ids, return_loss = False):
             ), dim = -2)
 
             stage_tokens, ps = pack_one(stage_tokens, '* n d')
+
             attended = transformer(stage_tokens)
+            attended = proj(attended)
+
             attended = unpack_one(attended, ps, '* n d')
 
             start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')
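
Taken together, these changes let `dim` vary per hierarchy stage: it is broadcast (or validated) to one width per stage, the token embedding uses the finest width while the start tokens use the coarsest, and each stage's attended output is projected to the next stage's width before being reused as start tokens. Below is a minimal construction sketch, assuming purely for illustration that the class owning this `__init__` is called `Model`, that the tuples run coarsest-to-finest as implied by `coarsest_dim, *_, fine_dim = dim`, and that `forward` accepts flat token ids whose length is the product of the per-stage `max_seq_len` values:

# illustration only: 'Model' stands in for the class that owns the __init__ above
model = Model(
    num_tokens = 256,
    dim = (768, 512, 256),        # per-stage widths: coarsest_dim = 768, fine_dim = 256
    depth = (6, 4, 2),            # one transformer depth per stage
    max_seq_len = (128, 4, 4)     # one sequence length per stage
)

# dim may also stay a single int; cast_tuple broadcasts it to every stage
ids = torch.randint(0, 256, (1, 128 * 4 * 4))   # assumed flat ids, length = product of max_seq_len
loss = model(ids, return_loss = True)
loss.backward()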