Skip to content

Commit 56dde1c

Browse files
committed
drop torch
1 parent a76b4be commit 56dde1c

File tree

13 files changed

+10
-743
lines changed

13 files changed

+10
-743
lines changed

.zenodo.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"fitting",
2929
"scipy",
3030
"numpy",
31-
"pytorch",
3231
"jax",
3332
"auto-differentiation"
3433
],

CITATION.cff

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ keywords:
2626
- fitting
2727
- scipy
2828
- numpy
29-
- pytorch
3029
- jax
3130
- auto-differentiation
3231
license: "Apache-2.0"
@@ -37,7 +36,7 @@ abstract: |
3736
of that statistical model for multi-bin histogram-based analysis and its
3837
interval estimation is based on the asymptotic formulas of "Asymptotic
3938
formulae for likelihood-based tests of new physics". pyhf supports modern
40-
computational graph libraries such as PyTorch and JAX in order
39+
computational graph libraries such as JAX in order
4140
to make use of features such as autodifferentiation and GPU acceleration.
4241
references:
4342
- type: article

docker/gpu/install_backend.sh

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@ function get_JAXLIB_GPU_WHEEL {
1818
function install_backend() {
1919
# 1: the backend option name in setup.py
2020
local backend="${1}"
21-
if [[ "${backend}" == "torch" ]]; then
22-
# shellcheck disable=SC2102
23-
python3 -m pip install --no-cache-dir .[xmlio,torch]
24-
elif [[ "${backend}" == "jax" ]]; then
21+
if [[ "${backend}" == "jax" ]]; then
2522
python3 -m pip install --no-cache-dir .[xmlio]
2623
python3 -m pip install --no-cache-dir "$(get_JAXLIB_GPU_WHEEL)"
2724
python3 -m pip install --no-cache-dir jax

docs/api.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ The computational backends that :code:`pyhf` provides interfacing for the vector
6565
:nosignatures:
6666

6767
numpy_backend.numpy_backend
68-
pytorch_backend.pytorch_backend
6968
jax_backend.jax_backend
7069

7170
Optimizers

docs/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ def setup(app):
146146
# today_fmt = '%B %d, %Y'
147147

148148
autodoc_mock_imports = [
149-
'torch',
150149
'jax',
151150
'iminuit',
152151
]

docs/installation.rst

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,6 @@ Install latest stable release from `PyPI <https://pypi.org/project/pyhf/>`__...
2727
2828
python -m pip install pyhf
2929
30-
... with PyTorch backend
31-
++++++++++++++++++++++++
32-
33-
.. code-block:: console
34-
35-
python -m pip install 'pyhf[torch]'
36-
3730
... with JAX backend
3831
++++++++++++++++++++
3932

@@ -67,13 +60,6 @@ Install latest development version from `GitHub <https://github.com/scikit-hep/p
6760
6861
python -m pip install --upgrade 'pyhf@git+https://github.com/scikit-hep/pyhf.git'
6962
70-
... with PyTorch backend
71-
++++++++++++++++++++++++
72-
73-
.. code-block:: console
74-
75-
python -m pip install --upgrade 'pyhf[torch]@git+https://github.com/scikit-hep/pyhf.git'
76-
7763
... with JAX backend
7864
++++++++++++++++++++++
7965

pyproject.toml

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ keywords = [
2323
"jax",
2424
"numpy",
2525
"physics",
26-
"pytorch",
2726
"scipy",
2827
]
2928
classifiers = [
@@ -67,10 +66,6 @@ Homepage = "https://github.com/scikit-hep/pyhf"
6766

6867
[project.optional-dependencies]
6968
shellcomplete = ["click_completion"]
70-
torch = [
71-
"torch>=1.10.0", # c.f. PR #1657
72-
"numpy<2.0" # c.f. https://github.com/pytorch/pytorch/issues/157973
73-
]
7469
jax = [
7570
"jax>=0.4.1", # c.f. PR #2079
7671
"jaxlib>=0.4.1", # c.f. PR #2079
@@ -181,18 +176,12 @@ markers = [
181176
"fail_jax",
182177
"fail_numpy",
183178
"fail_numpy_minuit",
184-
"fail_pytorch",
185-
"fail_pytorch64",
186179
"only_jax",
187180
"only_numpy",
188181
"only_numpy_minuit",
189-
"only_pytorch",
190-
"only_pytorch64",
191182
"skip_jax",
192183
"skip_numpy",
193184
"skip_numpy_minuit",
194-
"skip_pytorch",
195-
"skip_pytorch64",
196185
]
197186
filterwarnings = [
198187
"error",
@@ -201,8 +190,6 @@ filterwarnings = [
201190
'ignore: Exception ignored in:pytest.PytestUnraisableExceptionWarning', #FIXME: Exception ignored in: <_io.FileIO [closed]>
202191
'ignore:invalid value encountered in (true_)?divide:RuntimeWarning', #FIXME
203192
'ignore:invalid value encountered in add:RuntimeWarning', #FIXME
204-
"ignore:In future, it will be an error for 'np.bool_' scalars to be interpreted as an index:DeprecationWarning", #FIXME: tests/test_tensor.py::test_pdf_eval[pytorch]
205-
'ignore:Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with:UserWarning', #FIXME: tests/test_optim.py::test_minimize[no_grad-scipy-pytorch-no_stitch]
206193
'ignore:divide by zero encountered in (true_)?divide:RuntimeWarning', #FIXME: pytest tests/test_tensor.py::test_pdf_calculations[numpy]
207194
'ignore:[A-Z]+ is deprecated and will be removed in Pillow 10:DeprecationWarning', # keras
208195
"ignore:ml_dtypes.float8_e4m3b11 is deprecated.", #FIXME: Can remove when jaxlib>=0.4.12
@@ -212,8 +199,6 @@ filterwarnings = [
212199
"ignore:'MultiCommand' is deprecated and will be removed in Click 9.0. Use 'Group' instead.:DeprecationWarning", # Click
213200
"ignore:Jupyter is migrating its paths to use standard platformdirs:DeprecationWarning", # papermill
214201
"ignore:datetime.datetime.utcnow\\(\\) is deprecated:DeprecationWarning", # papermill
215-
"ignore:In future, it will be an error for 'np.bool' scalars to be interpreted as an index:DeprecationWarning", # PyTorch
216-
"ignore:__array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments.:DeprecationWarning", # PyTorch interacting with NumPy
217202
]
218203

219204
[tool.coverage.run]
@@ -243,7 +228,6 @@ module = [
243228
'jax.*',
244229
'matplotlib.*',
245230
'scipy.*',
246-
'torch.*',
247231
'uproot.*',
248232
]
249233
ignore_missing_imports = true
@@ -273,7 +257,6 @@ module = [
273257
'pyhf.tensor.common.*',
274258
'pyhf.tensor',
275259
'pyhf.tensor.jax_backend.*',
276-
'pyhf.tensor.pytorch_backend.*',
277260
]
278261
ignore_errors = true
279262

src/pyhf/cli/infer.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def cli():
3737
)
3838
@click.option(
3939
"--backend",
40-
type=click.Choice(["numpy", "pytorch", "jax", "np", "torch"]),
40+
type=click.Choice(["numpy", "jax", "np"]),
4141
help="The tensor backend used for the calculation.",
4242
default="numpy",
4343
)
@@ -82,9 +82,7 @@ def fit(
8282
}
8383
"""
8484
# set the backend if not NumPy
85-
if backend in ["pytorch", "torch"]:
86-
set_backend("pytorch", precision="64b")
87-
elif backend in ["jax"]:
85+
if backend in ["jax"]:
8886
set_backend("jax")
8987
tensorlib, _ = get_backend()
9088

@@ -149,7 +147,7 @@ def fit(
149147
)
150148
@click.option(
151149
'--backend',
152-
type=click.Choice(['numpy', 'pytorch', 'jax', 'np', 'torch']),
150+
type=click.Choice(['numpy', 'jax', 'np']),
153151
help='The tensor backend used for the calculation.',
154152
default='numpy',
155153
)
@@ -212,9 +210,7 @@ def cls(
212210
)
213211

214212
# set the backend if not NumPy
215-
if backend in ['pytorch', 'torch']:
216-
set_backend("pytorch", precision="64b")
217-
elif backend in ['jax']:
213+
if backend in ['jax']:
218214
set_backend("jax")
219215
tensorlib, _ = get_backend()
220216

src/pyhf/optimize/common.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,6 @@ def _get_tensor_shim():
4242

4343
return numpy_shim
4444

45-
if tensorlib.name == 'pytorch':
46-
from pyhf.optimize.opt_pytorch import wrap_objective as pytorch_shim
47-
48-
return pytorch_shim
49-
5045
if tensorlib.name == 'jax':
5146
from pyhf.optimize.opt_jax import wrap_objective as jax_shim
5247

src/pyhf/optimize/opt_pytorch.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

0 commit comments

Comments
 (0)