Skip to content

Commit 2889898

Browse files
committed
learn on the very first start token
1 parent 93d8a45 commit 2889898

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

MEGABYTE_pytorch/megabyte.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,23 +397,26 @@ def forward(self, ids, return_loss = False):
 
         logits = self.to_logits(attended)
 
-        logits = logits[..., 1:, :]
+        start_tokens, logits = logits[:, 0, :1, :], logits[..., 1:, :]
 
         if not return_loss:
 
             if flattened_dims:
-                logits = rearrange(logits, 'b ... n -> b (...) n')
+                logits = rearrange(logits, 'b ... c -> b (...) c')
                 logits = logits[:, :seq_len]
 
             return logits
 
-        preds = rearrange(logits, 'b ... c -> b c (...)')
+        logits = rearrange(logits, 'b ... c -> b (...) c')
+        logits = torch.cat((start_tokens, logits), dim = -2)
+
+        preds = rearrange(logits, 'b n c -> b c n')
         labels = rearrange(ids, 'b ... -> b (...)')
 
         loss = F.cross_entropy(
             preds[..., :-1],
-            labels[..., 1:],
+            labels,
             ignore_index = self.pad_id
         )
 
-        return loss
+        return loss

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.1.1',
+  version = '0.1.2',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',

0 commit comments

Comments
 (0)