Skip to content

Commit 23b2b96

Browse files
committed
fix regression and some dimension conditional
1 parent e98b62d commit 23b2b96

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

MEGABYTE_pytorch/megabyte.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def __init__(
275275

276276
proj = nn.Identity()
277277

278-
if exists(next_h_dim) and next_h_dim != dim:
278+
if exists(next_h_dim):
279279
proj = nn.Sequential(
280280
Rearrange('b ... d -> b (...) d'),
281281
nn.Linear(h_dim, next_h_dim * next_seq_len),
@@ -392,7 +392,7 @@ def forward(self, ids, return_loss = False):
392392
# sum the previous hierarchy's representation
393393

394394
if exists(prev_stage_tokens_repr):
395-
prev_stage_tokens_repr = F.pad(prev_stage_tokens_repr, (0, 0, 1, 0), value = self.pad_id)
395+
prev_stage_tokens_repr = F.pad(prev_stage_tokens_repr, (0, 0, 1, 0), value = 0.)
396396
stage_tokens = stage_tokens + prev_stage_tokens_repr
397397

398398
attended = transformer(stage_tokens)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'MEGABYTE-pytorch',
55
packages = find_packages(),
6-
version = '0.3.3',
6+
version = '0.3.4',
77
license='MIT',
88
description = 'MEGABYTE - Pytorch',
99
long_description_content_type = 'text/markdown',

0 commit comments

Comments
 (0)