Skip to content

Commit 0277167

Browse files
committed
fix and release 0.1.0
1 parent ede6166 commit 0277167

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

MEGABYTE_pytorch/megabyte.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -271,8 +271,8 @@ def __init__(
271 271

272 272
if exists(next_h_dim) and next_h_dim != dim:
273 273
proj = nn.Sequential(
274-
nn.Linear(h_dim, next_h_dim * next_seq_len),
275-
Rearrange('... (n d) -> (...) n d', n = next_seq_len)
274+
nn.Linear(h_dim, next_h_dim * (next_seq_len + 1)),
275+
Rearrange('... (n d) -> (...) n d', n = next_seq_len + 1)
276 276
)
277 277

278 278
self.to_next_transformer_projections.append(proj)
@@ -379,10 +379,6 @@ def forward(self, ids, return_loss = False):
379 379
stage_tokens,
380 380
), dim = -2)
381 381

382-
# omit last token
383-
384-
stage_tokens = stage_tokens[:, :-1]
385-
386 382
# sum the previous hierarchy's representation
387 383

388 384
if exists(prev_stage_tokens_repr):
@@ -394,11 +390,13 @@ def forward(self, ids, return_loss = False):
394 390

395 391
# project for next stage in the hierarchy
396 392

397-
prev_stage_tokens_repr = proj(attended)
393+
prev_stage_tokens_repr = proj(attended[..., :-1, :])
398 394

399 395
# project to logits
400 396

401-
logits = self.to_logits(attended)
397+
logits = self.to_logits(attended)
398+
399+
logits = logits[..., 1:, :]
402 400

403 401
if not return_loss:
404 402

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
3 3
setup(
4 4
name = 'MEGABYTE-pytorch',
5 5
packages = find_packages(),
6-
version = '0.0.10',
6+
version = '0.1.0',
7 7
license='MIT',
8 8
description = 'MEGABYTE - Pytorch',
9 9
long_description_content_type = 'text/markdown',

0 commit comments

Comments
 (0)