
Commit 549d9db

Devops/bump min lightning 200 (#2888)

* bump minimum lightning version to >=2.0.0
* update changelog

1 parent: db930a8

4 files changed: +9 -23 lines


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
 
 **Dependencies**
 
+- We raised the minimum pytorch-lightning version to `pytorch-lightning>=2.0.0`. [#2888](https://github.com/unit8co/darts/pull/2888) by [Dennis Bader](https://github.com/dennisbader).
+
 ### For developers of the library:
 
 ## [0.37.1](https://github.com/unit8co/darts/tree/0.37.1) (2025-08-18)
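
The changelog entry above raises the pytorch-lightning floor to 2.0.0. As a quick sanity check before upgrading darts, the installed version can be compared against that floor; a minimal sketch (not part of this commit) assuming the `packaging` helpers are available:

```python
# Minimal sketch (not part of this commit): confirm the installed
# pytorch-lightning satisfies the new minimum required by darts.
from packaging.version import Version

import pytorch_lightning as pl

MIN_PL_VERSION = Version("2.0.0")  # floor introduced by this change

installed = Version(pl.__version__)
if installed < MIN_PL_VERSION:
    raise RuntimeError(
        f"pytorch-lightning {installed} is too old; darts now requires >= {MIN_PL_VERSION}"
    )
print(f"pytorch-lightning {installed} satisfies >= {MIN_PL_VERSION}")
```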

darts/models/forecasting/torch_forecasting_model.py

Lines changed: 5 additions & 17 deletions
@@ -38,6 +38,8 @@
 import pytorch_lightning as pl
 import torch
 from pytorch_lightning import loggers as pl_loggers
+from pytorch_lightning.callbacks import ProgressBar
+from pytorch_lightning.tuner import Tuner
 from torch.utils.data import DataLoader
 
 from darts import TimeSeries
@@ -78,18 +80,6 @@
 from darts.utils.ts_utils import get_single_series, seq2series, series2seq
 from darts.utils.utils import _build_tqdm_iterator, _parallel_apply
 
-# Check whether we are running pytorch-lightning >= 2.0.0 or not:
-tokens = pl.__version__.split(".")
-pl_200_or_above = int(tokens[0]) >= 2
-
-if pl_200_or_above:
-    from pytorch_lightning.callbacks import ProgressBar
-    from pytorch_lightning.tuner import Tuner
-else:
-    from pytorch_lightning.callbacks import ProgressBarBase as ProgressBar
-    from pytorch_lightning.tuner.tuning import Tuner
-
-
 DEFAULT_DARTS_FOLDER = "darts_logs"
 CHECKPOINTS_FOLDER = "checkpoints"
 RUNS_FOLDER = "runs"
@@ -450,10 +440,10 @@ def _init_model(self, trainer: Optional[pl.Trainer] = None) -> PLForecastingModu
         dtype = self.train_sample[0].dtype
         if np.issubdtype(dtype, np.float32):
             logger.info("Time series values are 32-bits; casting model to float32.")
-            precision = "32" if not pl_200_or_above else "32-true"
+            precision = "32-true"
         elif np.issubdtype(dtype, np.float64):
             logger.info("Time series values are 64-bits; casting model to float64.")
-            precision = "64" if not pl_200_or_above else "64-true"
+            precision = "64-true"
         else:
             raise_log(
                 ValueError(
@@ -471,9 +461,7 @@ def _init_model(self, trainer: Optional[pl.Trainer] = None) -> PLForecastingModu
         )
         if precision_user is not None:
             # currently, we only support float 64 and 32
-            valid_precisions = (
-                ["64", "32"] if not pl_200_or_above else ["64-true", "32-true"]
-            )
+            valid_precisions = ["64-true", "32-true"]
             if str(precision_user) not in valid_precisions:
                 raise_log(
                     ValueError(
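
With the 1.x fallback removed, the module relies on import paths and precision names that only the Lightning 2.x API guarantees. A small standalone sketch (not darts code) illustrating the surface the simplified code now assumes:

```python
# Standalone sketch (not darts code): the Lightning 2.x surface the
# simplified module depends on after dropping the pl_200_or_above gate.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ProgressBar  # 2.x name; 1.x exposed ProgressBarBase
from pytorch_lightning.tuner import Tuner            # 2.x top-level tuner import

# Lightning 2.x reports full-precision training canonically as
# "32-true" / "64-true", which is what darts now passes and validates.
trainer = pl.Trainer(
    precision="32-true",
    max_epochs=1,
    logger=False,
    enable_checkpointing=False,
    enable_progress_bar=False,
)
print(trainer.precision)  # -> "32-true"

tuner = Tuner(trainer)  # lr_find / scale_batch_size live on this wrapper in 2.x
```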

darts/tests/models/forecasting/test_ptl_trainer.py

Lines changed: 1 addition & 5 deletions
@@ -21,11 +21,7 @@ class TestPTLTrainer:
         "enable_checkpointing": False,
     }
     series = linear_timeseries(length=100).astype(np.float32)
-    pl_200_or_above = int(pl.__version__.split(".")[0]) >= 2
-    precisions = {
-        32: "32" if not pl_200_or_above else "32-true",
-        64: "64" if not pl_200_or_above else "64-true",
-    }
+    precisions = {32: "32-true", 64: "64-true"}
 
     def test_prediction_loaded_custom_trainer(self, tmpdir_module):
         """validate manual save with automatic save files by comparing output between the two"""

requirements/torch.txt

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-pytorch-lightning>=1.5.0,<2.5.3
+pytorch-lightning>=2.0.0,<2.5.3
 tensorboardX>=2.1
 torch>=1.8.0
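
To check an existing environment against these tightened pins, the installed distributions can be compared to the specifiers above; a minimal sketch (not shipped with darts) assuming `packaging` is installed:

```python
# Minimal sketch (not shipped with darts): compare installed packages
# against the ranges pinned in requirements/torch.txt.
from importlib.metadata import version

from packaging.specifiers import SpecifierSet

PINS = {
    "pytorch-lightning": ">=2.0.0,<2.5.3",
    "tensorboardX": ">=2.1",
    "torch": ">=1.8.0",
}

for name, spec in PINS.items():
    installed = version(name)
    ok = SpecifierSet(spec).contains(installed)
    print(f"{name} {installed}: {'ok' if ok else 'outside ' + spec}")
```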
