diff --git a/pyproject.toml b/pyproject.toml index afa53ffb7d..a3725a13e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -222,6 +222,7 @@ warn_unused_configs = true strict = true enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] warn_unreachable = true +plugins = "numpy.typing.mypy_plugin" [[tool.mypy.overrides]] module = [ diff --git a/src/pyhf/tensor/manager.py b/src/pyhf/tensor/manager.py index a4a6ec3541..da4b467bb0 100644 --- a/src/pyhf/tensor/manager.py +++ b/src/pyhf/tensor/manager.py @@ -65,6 +65,11 @@ def set_backend( Example: >>> import pyhf + >>> pyhf.set_backend(b"jax", precision="32b") + >>> pyhf.tensorlib.name + 'jax' + >>> pyhf.tensorlib.precision + '32b' >>> pyhf.set_backend(pyhf.tensor.numpy_backend()) >>> pyhf.tensorlib.name 'numpy' diff --git a/tests/conftest.py b/tests/conftest.py index 7b7919f209..8df1ce5b03 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -89,7 +89,7 @@ def reset_backend(): def backend(request): # a better way to get the id? all the backends we have so far for testing param_ids = request._fixturedef.ids - # the backend we're using: numpy, etc... + # the backend we're using: numpy, jax, etc... param_id = param_ids[request.param_index] # name of function being called (with params), the original name is .originalname func_name = request._pyfuncitem.name