Skip to content

Commit 3e6180f

Browse files
authored
Backport: Fix activation lookup with Python 3.12.3 (#375) (#377)
* Fix activation lookup with Python 3.12.3 (#375). We used the metaclass `EnumMeta`/`EnumType` to override reporting of missing enum values (to give the full set of supported activations). However, in Python 3.12.3, the default value of the `name` parameter of the `EnumType.__call__` method was changed from `None` to `_not_given`: python/cpython@d771729. Even though this is a public API (which now uses a private default value), it seems too risky to continue using it. So in this change, we implement `Enum._missing_` instead for the improved error reporting.
* Set version to 1.3.2
* Adjust two cross-tests for changes in HF transformers (#367)
  * Fix `test_rotary_embeddings_against_hf` for latest transformers
  * xfail test because HfFileSystem is currently broken
1 parent b192987 commit 3e6180f

File tree

4 files changed

+14
-43
lines changed

4 files changed

+14
-43
lines changed

curated_transformers/layers/activations.py

Lines changed: 10 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -1,52 +1,13 @@
11
import math
2-
from enum import Enum, EnumMeta
2+
from enum import Enum
33
from typing import Type
44

55
import torch
66
from torch import Tensor
77
from torch.nn import Module
88

99

10-
class _ActivationMeta(EnumMeta):
11-
"""
12-
``Enum`` metaclass to override the class ``__call__`` method with a more
13-
fine-grained exception for unknown activation functions.
14-
"""
15-
16-
def __call__(
17-
cls,
18-
value,
19-
names=None,
20-
*,
21-
module=None,
22-
qualname=None,
23-
type=None,
24-
start=1,
25-
):
26-
# Wrap superclass __call__ to give a nicer error message when
27-
# an unknown activation is used.
28-
if names is None:
29-
try:
30-
return EnumMeta.__call__(
31-
cls,
32-
value,
33-
names,
34-
module=module,
35-
qualname=qualname,
36-
type=type,
37-
start=start,
38-
)
39-
except ValueError:
40-
supported_activations = ", ".join(sorted(v.value for v in cls))
41-
raise ValueError(
42-
f"Invalid activation function `{value}`. "
43-
f"Supported functions: {supported_activations}"
44-
)
45-
else:
46-
return EnumMeta.__call__(cls, value, names, module, qualname, type, start)
47-
48-
49-
class Activation(Enum, metaclass=_ActivationMeta):
10+
class Activation(Enum):
5011
"""
5112
Activation functions.
5213
@@ -71,6 +32,14 @@ class Activation(Enum, metaclass=_ActivationMeta):
7132
#: Sigmoid Linear Unit (`Hendrycks et al., 2016`_).
7233
SiLU = "silu"
7334

35+
@classmethod
36+
def _missing_(cls, value):
37+
supported_activations = ", ".join(sorted(v.value for v in cls))
38+
raise ValueError(
39+
f"Invalid activation function `{value}`. "
40+
f"Supported functions: {supported_activations}"
41+
)
42+
7443
@property
7544
def module(self) -> Type[torch.nn.Module]:
7645
"""

curated_transformers/tests/layers/test_embeddings.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,8 @@ def test_rotary_embeddings_against_hf(device):
2424

2525
X = torch.rand(16, 12, 64, 768, device=device)
2626
Y = re(X)
27-
hf_re_cos, hf_re_sin = hf_re(X, seq_len=X.shape[-2])
27+
positions = torch.arange(X.shape[2], device=device).view([1, -1])
28+
hf_re_cos, hf_re_sin = hf_re(X, positions)
2829
Y_hf = hf_re_cos * X + hf_re_sin * rotate_half(X)
2930

3031
torch_assertclose(Y, Y_hf)

curated_transformers/tests/tokenizers/test_hf_hub.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@ def test_from_hf_hub_to_cache_legacy():
5151
)
5252

5353

54+
@pytest.mark.xfail(reason="HfFileSystem calls safetensors with incorrect arguments")
5455
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
5556
def test_fsspec(sample_texts):
5657
# We only test one model, since using fsspec downloads the model

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
[metadata]
2-
version = 1.3.1
2+
version = 1.3.2
33
description = A PyTorch library of transformer models and components
44
url = https://github.com/explosion/curated-transformers
55
author = Explosion

0 commit comments

Comments
 (0)