
Commit cdfa143

add beartype
1 parent 60995d1 commit cdfa143

4 files changed: 16 additions, 11 deletions

MEGABYTE_pytorch/megabyte.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -9,6 +9,9 @@
 from einops import rearrange, reduce, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
+from beartype import beartype
+from beartype.typing import Tuple, Union
+
 from MEGABYTE_pytorch.attend import Attend
 
 from tqdm import tqdm
@@ -200,13 +203,15 @@ def forward(self, x):
 # main class
 
 class MEGABYTE(nn.Module):
+
+    @beartype
     def __init__(
         self,
         *,
         num_tokens,
-        dim,
-        depth: tuple,
-        max_seq_len: tuple,
+        dim: Union[Tuple, int],
+        depth: Tuple,
+        max_seq_len: Tuple,
         dim_head = 64,
         heads = 8,
         attn_dropout = 0.,
```

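With `@beartype` on the constructor, the tuple arguments are now validated at call time: passing, say, a list for `depth` raises a beartype exception before the model is built, rather than failing somewhere inside `__init__`. A minimal sketch of that behavior, assuming the usual `from MEGABYTE_pytorch import MEGABYTE` entry point; the sizes below are illustrative, and `BeartypeException` is beartype's base exception class:

```python
from beartype.roar import BeartypeException
from MEGABYTE_pytorch import MEGABYTE

# accepted: depth and max_seq_len are Tuples; dim may be a Tuple or a plain int
model = MEGABYTE(
    num_tokens = 256,
    dim = (512, 512),
    depth = (6, 2),
    max_seq_len = (1024, 4)
)

# rejected at call time: a list does not satisfy the Tuple hint on depth
try:
    MEGABYTE(num_tokens = 256, dim = (512, 512), depth = [6, 2], max_seq_len = (1024, 4))
except BeartypeException as err:
    print(f'beartype rejected the call: {err}')
```
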
README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -49,7 +49,7 @@ sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
 
 ## Test
 
-Train on character-level enwik8 with patches of size 4 - length 4096
+Train on character-level enwik8 with patches of size 4 - length 8192
 
 ```bash
 $ python train.py
````

setup.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.0.9',
+  version = '0.0.10',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',
@@ -16,6 +16,7 @@
     'transformers'
   ],
   install_requires=[
+    'beartype',
     'einops>=0.6.1',
     'torch>=1.10',
     'tqdm'
```

train.py

Lines changed: 5 additions & 6 deletions

```diff
@@ -17,9 +17,8 @@
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
-GENERATE_LENGTH = 1024
 PRIME_LEN = 100
-SEQ_LEN = 1024
+SEQ_LEN = 8192
 
 # helpers
 
@@ -38,9 +37,9 @@ def decode_tokens(tokens):
 
 model = MEGABYTE(
     num_tokens = 256,
-    dim = (512, 512),
-    depth = (6, 2),
-    max_seq_len = (1024, 4),
+    dim = (768, 512, 256),
+    depth = (6, 4, 2),
+    max_seq_len = (512, 4, 4),
     flash_attn = True
 ).cuda()
 
@@ -94,7 +93,7 @@ def __len__(self):
             loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')
 
-    if i % GENERATE_EVERY == 0:
+    if i != 0 and i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime_inp = inp[:PRIME_LEN]
```

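The new training configuration is a three-stage hierarchy, and SEQ_LEN rises to match: the total context the model covers is the product of the per-stage `max_seq_len` entries (512 × 4 × 4 = 8192), which is also why the README line above now says length 8192. The `i != 0` guard simply skips generation on the very first step, where `0 % GENERATE_EVERY == 0` would otherwise sample from an untrained model. A small sanity check of the arithmetic, with names mirroring train.py:

```python
import math

# three-stage hierarchy from the updated train.py
dim         = (768, 512, 256)
depth       = (6, 4, 2)
max_seq_len = (512, 4, 4)

# one dim and depth per stage
assert len(dim) == len(depth) == len(max_seq_len)

# total sequence length is the product of the per-stage lengths
SEQ_LEN = math.prod(max_seq_len)
assert SEQ_LEN == 8192   # matches SEQ_LEN = 8192 and the README's "length 8192"
```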