Commit 60995d1
allow for smaller model dimensions for the finer hierarchical stages
1 parent 2355831 commit 60995d1

File tree

4 files changed: +30 -17 lines

MEGABYTE_pytorch/megabyte.py
README.md
setup.py
train.py


MEGABYTE_pytorch/megabyte.py

Lines changed: 26 additions & 13 deletions
@@ -1,5 +1,6 @@
 import math
 import functools
+from itertools import zip_longest
 
 import torch
 import torch.nn.functional as F
@@ -204,8 +205,8 @@ def __init__(
         *,
         num_tokens,
         dim,
-        depth,
-        max_seq_len,
+        depth: tuple,
+        max_seq_len: tuple,
         dim_head = 64,
         heads = 8,
         attn_dropout = 0.,
@@ -225,26 +226,32 @@ def __init__(
         assert len(depth) == len(max_seq_len)
 
         self.stages = len(depth)
+        dim = cast_tuple(dim, self.stages)
 
-        self.token_emb = nn.Embedding(num_tokens, dim)
-        self.start_tokens = nn.Parameter(torch.randn(dim))
+        assert len(dim) == self.stages
+
+        coarsest_dim, *_, fine_dim = dim
+
+        self.token_emb = nn.Embedding(num_tokens, fine_dim)
+        self.start_tokens = nn.Parameter(torch.randn(coarsest_dim))
 
         self.max_seq_len = max_seq_len
 
-        self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, dim) for seq_len in max_seq_len])
+        self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)])
 
         self.patch_embedders = nn.ModuleList([nn.Sequential(
             Rearrange('... r d -> ... (r d)'),
-            nn.LayerNorm(seq_len * dim),
-            nn.Linear(seq_len * dim, dim),
-            nn.LayerNorm(dim)
-        ) for seq_len in self.max_seq_len[1:]])
+            nn.LayerNorm(seq_len * dim_in),
+            nn.Linear(seq_len * dim_in, dim_out),
+            nn.LayerNorm(dim_out)
+        ) for dim_in, dim_out, seq_len in zip(dim[1:], dim[:-1], max_seq_len[1:])])
 
         self.transformers = nn.ModuleList([])
+        self.to_next_transformer_projections = nn.ModuleList([])
 
-        for stage_depth in depth:
+        for h_dim, next_h_dim, stage_depth in zip_longest(dim, dim[1:], depth):
             self.transformers.append(Transformer(
-                dim = dim,
+                dim = h_dim,
                 layers = stage_depth,
                 dim_head = dim_head,
                 heads = heads,
@@ -255,7 +262,10 @@ def __init__(
                 flash_attn = flash_attn
             ))
 
-        self.to_logits = nn.Linear(dim, num_tokens)
+            proj = nn.Linear(h_dim, next_h_dim) if exists(next_h_dim) and next_h_dim != dim else nn.Identity()
+            self.to_next_transformer_projections.append(proj)
+
+        self.to_logits = nn.Linear(fine_dim, num_tokens)
         self.pad_id = pad_id
 
     def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
@@ -339,7 +349,7 @@ def forward(self, ids, return_loss = False):
 
         # spatial tokens is tokens with depth pos reduced along depth dimension + spatial positions
 
-        for ind, (stage_tokens, transformer) in enumerate(zip(tokens_at_stages, self.transformers)):
+        for ind, (stage_tokens, transformer, proj) in enumerate(zip(tokens_at_stages, self.transformers, self.to_next_transformer_projections)):
             is_last = ind == (self.stages - 1)
 
             stage_tokens = torch.cat((
@@ -348,7 +358,10 @@ def forward(self, ids, return_loss = False):
             ), dim = -2)
 
             stage_tokens, ps = pack_one(stage_tokens, '* n d')
+
             attended = transformer(stage_tokens)
+            attended = proj(attended)
+
             attended = unpack_one(attended, ps, '* n d')
 
             start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')
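With this change, `dim` can be given as a tuple holding one model dimension per hierarchical stage (coarsest first); a plain int still works, since `cast_tuple` broadcasts it to every stage, and a learned `nn.Linear` projection now bridges consecutive stages whose dimensions differ. A minimal usage sketch with illustrative hyperparameters (not taken from this commit), assuming the constructor arguments shown in the README and train.py diffs below, and assuming `forward` accepts flat token ids whose length equals the product of `max_seq_len`, as train.py feeds them:

```python
import torch
from MEGABYTE_pytorch import MEGABYTE

# three hierarchical stages, coarsest first; the finer (local) stages get smaller dims
model = MEGABYTE(
    num_tokens = 256,
    dim = (768, 512, 256),      # one dimension per stage; a plain int is still broadcast to all stages
    depth = (6, 4, 2),
    max_seq_len = (512, 4, 4),  # product of these is the flat sequence length
    flash_attn = False
)

# flat byte ids of length 512 * 4 * 4 = 8192
ids = torch.randint(0, 256, (1, 8192))

loss = model(ids, return_loss = True)
loss.backward()
```

Shrinking the width of the finer stages is the point of the commit: the local transformers run over far more positions than the global one, so giving them a smaller dimension keeps their cost down while the coarsest stage keeps the full model dimension.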

README.md

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ from MEGABYTE_pytorch import MEGABYTE
 
 model = MEGABYTE(
     num_tokens = 16000,       # number of tokens
-    dim = 512,                # transformer model dimension
+    dim = (512, 256),         # transformer model dimension (512 for coarsest, 256 for fine in this example)
     max_seq_len = (1024, 4),  # sequence length for global and then local. this can be more than 2
     depth = (6, 4),           # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
     dim_head = 64,            # dimension per head
@@ -49,7 +49,7 @@ sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
 
 ## Test
 
-Train on character-level enwik8 with patches of size 4
+Train on character-level enwik8 with patches of size 4 - length 4096
 
 ```bash
 $ python train.py

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.0.7',
+  version = '0.0.9',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',

train.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def decode_tokens(tokens):
 
 model = MEGABYTE(
     num_tokens = 256,
-    dim = 512,
+    dim = (512, 512),
     depth = (6, 2),
     max_seq_len = (1024, 4),
     flash_attn = True
