Skip to content

Commit f1850d6

Browse files
Feat/test optional dep (onnx, ray, optuna) (#2702)
* feat: adding tests for the optional dependencies * fix: typos * fix: typo * fix: tests, extended to more models * fix: also cover exporting to onnx after laoding ckpt * fix: use_X_covariate attribute is correctly updated after loading weights from checkpoint * update the test of the onnx optional dep, add a util file * reduced code redundancy in the tests * feat: adding tests for optuna * feat: adding tests for ray * fix: simplified test * fix: github actions * tmp fix: remove ray test for regression model * fix: improve test coverage * fix: further simply the tests * address review comments * minor update --------- Co-authored-by: dennisbader <[email protected]>
1 parent 19b17d2 commit f1850d6

File tree

14 files changed

+727
-61
lines changed

14 files changed

+727
-61
lines changed

.github/workflows/develop.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
if [ "${{ matrix.os }}" == "macos-13" ]; then
5252
source $HOME/.local/bin/env
5353
fi
54-
uv pip compile requirements/dev-all.txt > requirements-latest.txt
54+
uv pip compile requirements/dev-all.txt requirements/optional.txt -o requirements-latest.txt
5555
5656
- name: "Cache python environment"
5757
uses: actions/cache@v4
@@ -67,7 +67,7 @@ jobs:
6767
- name: "Install Dependencies"
6868
run: |
6969
# install latest dependencies (potentially updating cached dependencies)
70-
pip install -U -r requirements/dev-all.txt
70+
pip install -U -r requirements/dev-all.txt -r requirements/optional.txt
7171
7272
- name: "Install libomp (for LightGBM)"
7373
run: |
@@ -99,7 +99,7 @@ jobs:
9999
- name: "Compile Dependency Versions"
100100
run: |
101101
curl -LsSf https://astral.sh/uv/install.sh | sh
102-
uv pip compile requirements/dev-all.txt > requirements-latest.txt
102+
uv pip compile requirements/dev-all.txt requirements/optional.txt > requirements-latest.txt
103103
104104
# only restore cache but do not upload
105105
- name: "Restore cached python environment"
@@ -120,7 +120,7 @@ jobs:
120120
- name: "Install Dependencies"
121121
run: |
122122
# install latest dependencies (potentially updating cached dependencies)
123-
pip install -U -r requirements/dev-all.txt
123+
pip install -U -r requirements/dev-all.txt -r requirements/optional.txt
124124
125125
- name: "Install libomp (for LightGBM)"
126126
run: |
@@ -152,7 +152,7 @@ jobs:
152152
- name: "Compile Dependency Versions"
153153
run: |
154154
curl -LsSf https://astral.sh/uv/install.sh | sh
155-
uv pip compile requirements/dev-all.txt > requirements-latest.txt
155+
uv pip compile requirements/dev-all.txt requirements/optional.txt > requirements-latest.txt
156156
157157
# only restore cache but do not upload
158158
- name: "Restore cached python environment"
@@ -169,7 +169,7 @@ jobs:
169169
- name: "Install Dependencies"
170170
run: |
171171
# install latest dependencies (potentially updating cached dependencies)
172-
pip install -U -r requirements/dev-all.txt
172+
pip install -U -r requirements/dev-all.txt -r requirements/optional.txt
173173
174174
- name: "Install libomp (for LightGBM)"
175175
run: |

.github/workflows/merge.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ jobs:
5454
elif [ "${{ matrix.flavour }}" == "torch" ]; then
5555
pip install -r requirements/core.txt -r requirements/torch.txt -r requirements/dev.txt
5656
elif [ "${{ matrix.flavour }}" == "all" ]; then
57-
pip install -r requirements/core.txt -r requirements/torch.txt -r requirements/notorch.txt -r requirements/dev.txt
57+
pip install -r requirements/core.txt -r requirements/torch.txt -r requirements/notorch.txt -r requirements/optional.txt -r requirements/dev.txt
5858
fi
5959
6060
- name: "Install libomp (for LightGBM)"
@@ -94,7 +94,7 @@ jobs:
9494
- name: "Compile Dependency Versions"
9595
run: |
9696
curl -LsSf https://astral.sh/uv/install.sh | sh
97-
uv pip compile requirements/dev-all.txt > requirements-latest.txt
97+
uv pip compile requirements/dev-all.txt requirements/optional.txt -o requirements-latest.txt
9898
9999
# only restore cache but do not upload
100100
- name: "Restore cached python environment"
@@ -111,7 +111,7 @@ jobs:
111111
- name: "Install Dependencies"
112112
run: |
113113
# install latest dependencies (potentially updating cached dependencies)
114-
pip install -U -r requirements/dev-all.txt
114+
pip install -U -r requirements/dev-all.txt -r requirements/optional.txt
115115
116116
- name: "Install libomp (for LightGBM)"
117117
run: |
@@ -141,7 +141,7 @@ jobs:
141141
- name: "Compile Dependency Versions"
142142
run: |
143143
curl -LsSf https://astral.sh/uv/install.sh | sh
144-
uv pip compile requirements/dev-all.txt > requirements-latest.txt
144+
uv pip compile requirements/dev-all.txt requirements/optional.txt -o requirements-latest.txt
145145
146146
# only restore cache but do not upload
147147
- name: "Restore cached python environment"
@@ -162,7 +162,7 @@ jobs:
162162
- name: "Install Dependencies"
163163
run: |
164164
# install latest dependencies (potentially updating cached dependencies)
165-
pip install -U -r requirements/dev-all.txt
165+
pip install -U -r requirements/dev-all.txt -r requirements/optional.txt
166166
167167
- name: "Install libomp (for LightGBM)"
168168
run: |

.github/workflows/update-cache.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
if [ "${{ matrix.os }}" == "macos-13" ]; then
3232
source $HOME/.local/bin/env
3333
fi
34-
uv pip compile requirements/dev-all.txt > requirements-latest.txt
34+
uv pip compile requirements/dev-all.txt -r requirements/optional.txt -o requirements-latest.txt
3535
3636
- name: "Cache python environment"
3737
uses: actions/cache@v4
@@ -47,4 +47,4 @@ jobs:
4747
- name: "Install Latest Dependencies"
4848
run: |
4949
# install latest dependencies (potentially updating cached dependencies)
50-
pip install -U -r requirements/dev-all.txt
50+
pip install -U -r requirements/dev-all.txt -r requirements/optional.txt

darts/models/forecasting/tft_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ def __init__(
726726
If ``False``, only attends to previous time steps in the decoder. If ``True`` attends to previous,
727727
current, and future time steps. Defaults to ``False``.
728728
feed_forward
729-
A feedforward network is a fully-connected layer with an activation. TFT Can be one of the glu variant's
729+
A feedforward network is a fully-connected layer with an activation. Can be one of the glu variant's
730730
FeedForward Network (FFN)[2]. The glu variant's FeedForward Network are a series of FFNs designed to work
731731
better with Transformer based models. Defaults to ``"GatedResidualNetwork"``. ["GLU", "Bilinear", "ReGLU",
732732
"GEGLU", "SwiGLU", "ReLU", "GELU"] or the TFT original FeedForward Network ["GatedResidualNetwork"].

darts/models/forecasting/torch_forecasting_model.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,12 @@ def _verify_past_future_covariates(self, past_covariates, future_covariates):
646646
logger=logger,
647647
)
648648

649+
@abstractmethod
650+
def _update_covariates_use(self):
651+
"""Based on the Forecasting class and the training_sample attribute, update the
652+
uses_[past/future/static]_covariates attributes."""
653+
pass
654+
649655
def to_onnx(self, path: Optional[str] = None, **kwargs):
650656
"""Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's
651657
:func:`torch.onnx.export` method (`official documentation <https://lightning.ai/docs/pytorch/
@@ -677,6 +683,8 @@ def to_onnx(self, path: Optional[str] = None, **kwargs):
677683
``input_sample``, ``input_name``). For more information, read the `official documentation
678684
<https://pytorch.org/docs/master/onnx.html#torch.onnx.export>`_.
679685
"""
686+
# TODO: LSTM model should be exported with a batch size of 1
687+
# TODO: predictions with TFT and TCN models is incorrect, might be caused by helper function to process inputs
680688
if not self._fit_called:
681689
raise_log(
682690
ValueError("`fit()` needs to be called before `to_onnx()`."), logger
@@ -2133,6 +2141,9 @@ def load_weights_from_checkpoint(
21332141
self.model.load_state_dict(ckpt["state_dict"], strict=strict)
21342142
# update the fit_called attribute to allow for direct inference
21352143
self._fit_called = True
2144+
# based on the shape of train_sample, figure out which covariates are used by the model
2145+
# (usually set in the Darts model prior to fitting it)
2146+
self._update_covariates_use()
21362147

21372148
def load_weights(
21382149
self, path: str, load_encoders: bool = True, skip_checks: bool = False, **kwargs
@@ -2683,6 +2694,13 @@ def extreme_lags(
26832694
None,
26842695
)
26852696

2697+
def _update_covariates_use(self):
2698+
"""The model is expected to rely on the `PastCovariatesTrainingDataset`"""
2699+
_, past_covs, static_covs, _ = self.train_sample
2700+
self._uses_past_covariates = past_covs is not None
2701+
self._uses_future_covariates = False
2702+
self._uses_static_covariates = static_covs is not None
2703+
26862704

26872705
class FutureCovariatesTorchModel(TorchForecastingModel, ABC):
26882706
supports_past_covariates = False
@@ -2776,6 +2794,13 @@ def extreme_lags(
27762794
None,
27772795
)
27782796

2797+
def _update_covariates_use(self):
2798+
"""The model is expected to rely on the `FutureCovariatesTrainingDataset`"""
2799+
_, future_covs, static_covs, _ = self.train_sample
2800+
self._uses_past_covariates = False
2801+
self._uses_future_covariates = future_covs is not None
2802+
self._uses_static_covariates = static_covs is not None
2803+
27792804

27802805
class DualCovariatesTorchModel(TorchForecastingModel, ABC):
27812806
supports_past_covariates = False
@@ -2870,6 +2895,15 @@ def extreme_lags(
28702895
None,
28712896
)
28722897

2898+
def _update_covariates_use(self):
2899+
"""The model is expected to rely on the `DualCovariatesTrainingDataset`"""
2900+
_, historic_future_covs, future_covs, static_covs, _ = self.train_sample
2901+
self._uses_past_covariates = False
2902+
self._uses_future_covariates = (
2903+
historic_future_covs is not None or future_covs is not None
2904+
)
2905+
self._uses_static_covariates = static_covs is not None
2906+
28732907

28742908
class MixedCovariatesTorchModel(TorchForecastingModel, ABC):
28752909
def _build_train_dataset(
@@ -2964,6 +2998,17 @@ def extreme_lags(
29642998
None,
29652999
)
29663000

3001+
def _update_covariates_use(self):
3002+
"""The model is expected to rely on the `MixedCovariatesTrainingDataset`"""
3003+
_, past_covs, historic_future_covs, future_covs, static_covs, _ = (
3004+
self.train_sample
3005+
)
3006+
self._uses_past_covariates = past_covs is not None
3007+
self._uses_future_covariates = (
3008+
historic_future_covs is not None or future_covs is not None
3009+
)
3010+
self._uses_static_covariates = static_covs is not None
3011+
29673012

29683013
class SplitCovariatesTorchModel(TorchForecastingModel, ABC):
29693014
def _build_train_dataset(
@@ -3058,3 +3103,12 @@ def extreme_lags(
30583103
self.output_chunk_shift,
30593104
None,
30603105
)
3106+
3107+
def _update_covariates_use(self):
3108+
"""The model is expected to rely on the `SplitCovariatesTrainingDataset`"""
3109+
_, past_covs, historic_future_covs, future_covs, static_covs, _ = (
3110+
self.train_sample
3111+
)
3112+
self._uses_past_covariates = past_covs is not None
3113+
self._uses_future_covariates = future_covs is not None
3114+
self._uses_static_covariates = static_covs is not None

darts/tests/conftest.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,31 @@
1717
logger.warning("Torch not installed - Some tests will be skipped.")
1818
TORCH_AVAILABLE = False
1919

20+
try:
21+
import onnx # noqa: F401
22+
import onnxruntime # noqa: F401
23+
24+
ONNX_AVAILABLE = True
25+
except ImportError:
26+
logger.warning("Onnx not installed - Some tests will be skipped.")
27+
ONNX_AVAILABLE = False
28+
29+
try:
30+
import optuna # noqa: F401
31+
32+
OPTUNA_AVAILABLE = True
33+
except ImportError:
34+
logger.warning("Optuna not installed - Some tests will be skipped.")
35+
OPTUNA_AVAILABLE = False
36+
37+
try:
38+
import ray # noqa: F401
39+
40+
RAY_AVAILABLE = True
41+
except ImportError:
42+
logger.warning("Ray not installed - Some tests will be skipped.")
43+
RAY_AVAILABLE = False
44+
2045
tfm_kwargs = {
2146
"pl_trainer_kwargs": {
2247
"accelerator": "cpu",
@@ -25,6 +50,15 @@
2550
}
2651
}
2752

53+
tfm_kwargs_dev = {
54+
"pl_trainer_kwargs": {
55+
"accelerator": "cpu",
56+
"enable_progress_bar": False,
57+
"enable_model_summary": False,
58+
"fast_dev_run": True,
59+
}
60+
}
61+
2862

2963
@pytest.fixture(scope="session", autouse=True)
3064
def set_up_tests(request):

darts/tests/models/forecasting/test_torch_forecasting_model.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from darts.dataprocessing.encoders import SequentialEncoder
1414
from darts.dataprocessing.transformers import BoxCox, Scaler
1515
from darts.metrics import mape
16-
from darts.tests.conftest import TORCH_AVAILABLE, tfm_kwargs
16+
from darts.tests.conftest import TORCH_AVAILABLE, tfm_kwargs, tfm_kwargs_dev
1717

1818
if not TORCH_AVAILABLE:
1919
pytest.skip(
@@ -429,6 +429,63 @@ def create_model(**kwargs):
429429
model_new = create_model(**kwargs_)
430430
model_new.load_weights(model_path_manual)
431431

432+
@pytest.mark.parametrize(
433+
"params",
434+
itertools.product(
435+
[DLinearModel, NBEATSModel, RNNModel], # model_cls
436+
[True, False], # past_covs
437+
[True, False], # future_covs
438+
[True, False], # static covs
439+
),
440+
)
441+
def test_save_and_load_weights_covs_usage_attributes(self, tmpdir_fn, params):
442+
"""
443+
Verify that save/load correctly preserve the use_[past/future/static]_covariates attribute.
444+
"""
445+
model_cls, use_pc, use_fc, use_sc = params
446+
model = model_cls(
447+
input_chunk_length=4,
448+
output_chunk_length=1,
449+
n_epochs=1,
450+
**tfm_kwargs_dev,
451+
)
452+
# skip test if the combination of covariates is not supported by the model
453+
if (
454+
(use_pc and not model.supports_past_covariates)
455+
or (use_fc and not model.supports_future_covariates)
456+
or (use_sc and not model.supports_static_covariates)
457+
):
458+
return
459+
460+
model.fit(
461+
series=self.series
462+
if not use_sc
463+
else self.series.with_static_covariates(pd.Series([12], ["loc"])),
464+
past_covariates=self.series + 10 if use_pc else None,
465+
future_covariates=self.series - 5 if use_fc else None,
466+
)
467+
# save and load the model
468+
filename_ckpt = f"{model.model_name}.pt"
469+
model.save(filename_ckpt)
470+
model_loaded = model_cls(
471+
input_chunk_length=4,
472+
output_chunk_length=1,
473+
**tfm_kwargs_dev,
474+
)
475+
model_loaded.load_weights(filename_ckpt)
476+
477+
assert model.uses_past_covariates == model_loaded.uses_past_covariates == use_pc
478+
assert (
479+
model.uses_future_covariates
480+
== model_loaded.uses_future_covariates
481+
== use_fc
482+
)
483+
assert (
484+
model.uses_static_covariates
485+
== model_loaded.uses_static_covariates
486+
== use_sc
487+
)
488+
432489
def test_save_and_load_weights_w_encoders(self, tmpdir_fn):
433490
"""
434491
Verify that save/load does not break encoders.

0 commit comments

Comments
 (0)