Skip to content

Commit 67ea666

Browse files
authored
Fix/tfm predict after train from ds (#2860)
* fix tfm predict methods after being trained with fit_from_dataset using static covariates * update changelog * fix failing test
1 parent c93f512 commit 67ea666

File tree

4 files changed

+207
-8
lines changed

4 files changed

+207
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
2323
- Fixed a bug in `LightGBMModel` and `CatBoostModel` when using component-specific lags and categorical features, where certain lag scenarios could result in incorrect categorical feature declaration. [#2852](https://github.com/unit8co/darts/pull/2852) by [Dennis Bader](https://github.com/dennisbader).
2424
- Fixed a bug in `darts.utils.timeseries_generation.sine_timeseries()`, where the returned series ignored the specified `dtype`. [#2856](https://github.com/unit8co/darts/pull/2856) by [Dennis Bader](https://github.com/dennisbader).
2525
- Fixed a bug in `TimeSeries.__getitem__()`, where indexing with a list of integers of `length <= 2` resulted in an error. [#2857](https://github.com/unit8co/darts/pull/2857) by [Dennis Bader](https://github.com/dennisbader).
26+
- Fixed a bug in `TorchForecastingModel` which raised an error when calling any predict method after training the model with `fit_from_dataset()` on a dataset that uses static covariates. [#2860](https://github.com/unit8co/darts/pull/2860) by [Dennis Bader](https://github.com/dennisbader).
2627
- Removed `darts/tests` and `examples` from the Darts package distribution. These are only required for internal testing. [#2854](https://github.com/unit8co/darts/pull/2854) by [Dennis Bader](https://github.com/dennisbader).
2728

2829
**Dependencies**

darts/models/forecasting/torch_forecasting_model.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -737,8 +737,17 @@ def _update_covariates_use(self):
737737
_, past_cov, historic_future_cov, future_cov, static_cov, _ = self.train_sample
738738

739739
self._uses_past_covariates = past_cov is not None
740+
self._expect_past_covariates = (
741+
self.uses_past_covariates and self.past_covariate_series is None
742+
)
740743
self._uses_future_covariates = future_cov is not None
744+
self._expect_future_covariates = (
745+
self.uses_future_covariates and self.future_covariate_series is None
746+
)
741747
self._uses_static_covariates = static_cov is not None
748+
self._expect_static_covariates = (
749+
self.uses_static_covariates and self.static_covariates is None
750+
)
742751

743752
def to_onnx(self, path: Optional[str] = None, **kwargs):
744753
"""Export model to ONNX format for optimized inference, wrapping around PyTorch Lightning's
@@ -1204,6 +1213,10 @@ def _setup_for_train(
12041213
f" provided input/output dimensions = {sample_shapes}",
12051214
)
12061215

1216+
# update the covariates usage based on the training sample (required if model training was called
1217+
# with `fit_from_dataset()`)
1218+
self._update_covariates_use()
1219+
12071220
# loss must not reduce the output when using sample weight
12081221
train_sample_weight = train_sample[-2]
12091222
val_sample_weight = val_dataset[0][-2] if val_dataset is not None else None
@@ -1581,8 +1594,10 @@ def predict(
15811594
if self.training_series is None:
15821595
raise_log(
15831596
ValueError(
1584-
"Input `series` must be provided. This is the result either from fitting on multiple series, "
1585-
"from not having fit the model yet, or from loading a model saved with `clean=True`."
1597+
"Input `series` must be provided. This is the result either from "
1598+
"fitting on multiple series, from fitting with `fit_from_dataset()`, "
1599+
"from not having fit the model yet, or from loading a model saved with "
1600+
"`clean=True`."
15861601
),
15871602
logger,
15881603
)

darts/tests/models/forecasting/test_global_forecasting_models.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,18 @@ def test_save_load_model(self, tmpdir_fn, model):
358358
# The series to predict needs to be provided at prediction time
359359
with pytest.raises(ValueError) as err:
360360
loaded_model_clean_str.predict(self.forecasting_horizon)
361-
assert str(err.value) == (
362-
"Input `series` must be provided. This is the result either from fitting on multiple series, "
363-
"from not having fit the model yet, or from loading a model saved with `clean=True`."
364-
)
361+
if isinstance(model, TorchForecastingModel):
362+
assert str(err.value) == (
363+
"Input `series` must be provided. This is the result either from fitting on multiple series, "
364+
"from fitting with `fit_from_dataset()`, from not having fit the model yet, or from loading a "
365+
"model saved with `clean=True`."
366+
)
367+
else:
368+
assert str(err.value) == (
369+
"Input `series` must be provided. This is the result either from fitting on multiple series, "
370+
"from not having fit the model yet, or from loading a "
371+
"model saved with `clean=True`."
372+
)
365373

366374
# When the series to predict is provided, the prediction is the same
367375
assert model_prediction == loaded_model_clean_str.predict(

darts/tests/models/forecasting/test_torch_forecasting_model.py

Lines changed: 177 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
from darts.dataprocessing.transformers import BoxCox, Scaler
1616
from darts.metrics import mape
1717
from darts.tests.conftest import TORCH_AVAILABLE, tfm_kwargs, tfm_kwargs_dev
18+
from darts.utils.data.torch_datasets.inference_dataset import (
19+
SequentialTorchInferenceDataset,
20+
)
21+
from darts.utils.data.torch_datasets.training_dataset import (
22+
SequentialTorchTrainingDataset,
23+
)
1824

1925
if not TORCH_AVAILABLE:
2026
pytest.skip(
@@ -232,7 +238,8 @@ def on_train_epoch_end(self, trainer, pl_module):
232238
no_train_model.predict(n=4)
233239
assert str(err.value) == (
234240
"Input `series` must be provided. This is the result either from fitting on multiple series, "
235-
"from not having fit the model yet, or from loading a model saved with `clean=True`."
241+
"from fitting with `fit_from_dataset()`, from not having fit the model yet, or from loading a "
242+
"model saved with `clean=True`."
236243
)
237244

238245
model_manual_save.fit(self.series, epochs=1)
@@ -284,7 +291,8 @@ def on_train_epoch_end(self, trainer, pl_module):
284291
model_manual_save.predict(n=4)
285292
assert str(err.value) == (
286293
"Input `series` must be provided. This is the result either from fitting on multiple series, "
287-
"from not having fit the model yet, or from loading a model saved with `clean=True`."
294+
"from fitting with `fit_from_dataset()`, from not having fit the model yet, or from loading a "
295+
"model saved with `clean=True`."
288296
)
289297
# Predicting while giving the training series in args should yield same prediction
290298
assert model_manual_save.predict(
@@ -2374,6 +2382,173 @@ def test_fit_with_stride(self, stride):
23742382
assert len(train_set) == len(val_set) == math.ceil(3 / stride)
23752383
assert train_set.stride == val_set.stride == stride
23762384

2385+
def test_predict_after_fit_from_dataset(self):
2386+
"""Test that the model can predict after being trained with `fit_from_dataset` using all covariates."""
2387+
icl, ocl = kwargs["input_chunk_length"], kwargs["output_chunk_length"]
2388+
n = 1
2389+
series = [
2390+
self.series[: icl + ocl].with_static_covariates(pd.DataFrame({"sc": [0.0]}))
2391+
]
2392+
pc = [self.series[: icl + ocl]]
2393+
fc = [self.series[: icl + ocl + n]]
2394+
2395+
model = TiDEModel(**kwargs)
2396+
2397+
train_dataset = SequentialTorchTrainingDataset(
2398+
series=series,
2399+
past_covariates=pc,
2400+
future_covariates=fc,
2401+
input_chunk_length=icl,
2402+
output_chunk_length=ocl,
2403+
use_static_covariates=True,
2404+
)
2405+
2406+
# check training works and covariates are used
2407+
model.fit_from_dataset(train_dataset=train_dataset)
2408+
assert model.uses_past_covariates
2409+
assert model.uses_future_covariates
2410+
assert model.uses_static_covariates
2411+
assert model._expect_past_covariates
2412+
assert model._expect_future_covariates
2413+
assert model._expect_static_covariates
2414+
2415+
with pytest.raises(ValueError) as exc:
2416+
_ = model.predict(n=n)
2417+
assert str(exc.value).startswith("Input `series` must be provided.")
2418+
2419+
hfc_kwargs = {"forecast_horizon": n, "retrain": False, "overlap_end": True}
2420+
self.helper_predict_raise_on_missing_input(
2421+
model, "predict", series, pc, fc, n=n
2422+
)
2423+
self.helper_predict_raise_on_missing_input(
2424+
model, "historical_forecasts", series, pc, fc, **hfc_kwargs
2425+
)
2426+
self.helper_predict_from_ds_raise_on_missing_input(
2427+
model,
2428+
series,
2429+
pc,
2430+
fc,
2431+
n=n,
2432+
input_chunk_length=icl,
2433+
output_chunk_length=ocl,
2434+
)
2435+
2436+
# check predict methods
2437+
inference_dataset = SequentialTorchInferenceDataset(
2438+
series=series,
2439+
past_covariates=pc,
2440+
future_covariates=fc,
2441+
n=n,
2442+
input_chunk_length=icl,
2443+
output_chunk_length=ocl,
2444+
use_static_covariates=True,
2445+
)
2446+
pred1 = model.predict(
2447+
n=n, series=series, past_covariates=pc, future_covariates=fc
2448+
)[0]
2449+
pred2 = model.predict_from_dataset(n=n, dataset=inference_dataset)[0]
2450+
pred3 = model.historical_forecasts(
2451+
forecast_horizon=n,
2452+
series=series,
2453+
past_covariates=pc,
2454+
future_covariates=fc,
2455+
retrain=False,
2456+
overlap_end=True,
2457+
)[0]
2458+
# extract only the last historical forecast, which should be the same as the regular prediction
2459+
pred3 = pred3[-1]
2460+
assert pred1 == pred2
2461+
np.testing.assert_array_almost_equal(pred3.all_values(), pred1.all_values())
2462+
assert pred3.time_index.equals(pred1.time_index)
2463+
assert pred3.static_covariates.equals(series[0].static_covariates)
2464+
2465+
def helper_predict_raise_on_missing_input(
2466+
self, model, fn: str, series, pc, fc, **kwargs
2467+
):
2468+
"""Helper function to test that the model raises an error when calling `predict()` or `historical_forecasts()`
2469+
after `fit_from_dataset()` with missing inputs."""
2470+
with pytest.raises(ValueError) as exc:
2471+
_ = getattr(model, fn)(series=series, **kwargs)
2472+
assert str(exc.value).startswith("The model was trained with past covariates.")
2473+
with pytest.raises(ValueError) as exc:
2474+
_ = getattr(model, fn)(series=series, past_covariates=pc, **kwargs)
2475+
assert str(exc.value).startswith(
2476+
"The model was trained with future covariates."
2477+
)
2478+
with pytest.raises(ValueError) as exc:
2479+
_ = getattr(model, fn)(series=series, future_covariates=fc, **kwargs)
2480+
assert str(exc.value).startswith("The model was trained with past covariates.")
2481+
with pytest.raises(ValueError) as exc:
2482+
_ = getattr(model, fn)(
2483+
series=[series[0].with_static_covariates(None)],
2484+
past_covariates=pc,
2485+
future_covariates=fc,
2486+
**kwargs,
2487+
)
2488+
assert str(exc.value).startswith(
2489+
"The model was trained with static covariates."
2490+
)
2491+
2492+
def helper_predict_from_ds_raise_on_missing_input(
2493+
self,
2494+
model,
2495+
series,
2496+
pc,
2497+
fc,
2498+
n,
2499+
**kwargs,
2500+
):
2501+
"""Helper function to test that the model raises an error when calling `predict_from_dataset()` after
2502+
`fit_from_dataset()` with missing inputs."""
2503+
inf_dataset = SequentialTorchInferenceDataset(
2504+
n=n,
2505+
series=series,
2506+
**kwargs,
2507+
)
2508+
with pytest.raises(ValueError) as exc:
2509+
_ = model.predict_from_dataset(n=n, dataset=inf_dataset)
2510+
assert str(exc.value).startswith(
2511+
"This model has been trained with `past_covariates`"
2512+
)
2513+
2514+
inf_dataset = SequentialTorchInferenceDataset(
2515+
n=n,
2516+
series=series,
2517+
past_covariates=pc,
2518+
**kwargs,
2519+
)
2520+
with pytest.raises(ValueError) as exc:
2521+
_ = model.predict_from_dataset(n=n, dataset=inf_dataset)
2522+
assert str(exc.value).startswith(
2523+
"This model has been trained with `historic_future_covariates`"
2524+
)
2525+
2526+
inf_dataset = SequentialTorchInferenceDataset(
2527+
n=n,
2528+
series=series,
2529+
future_covariates=fc,
2530+
**kwargs,
2531+
)
2532+
with pytest.raises(ValueError) as exc:
2533+
_ = model.predict_from_dataset(n=n, dataset=inf_dataset)
2534+
assert str(exc.value).startswith(
2535+
"This model has been trained with `past_covariates`"
2536+
)
2537+
2538+
inf_dataset = SequentialTorchInferenceDataset(
2539+
n=n,
2540+
series=series,
2541+
past_covariates=pc,
2542+
future_covariates=fc,
2543+
use_static_covariates=False,
2544+
**kwargs,
2545+
)
2546+
with pytest.raises(ValueError) as exc:
2547+
_ = model.predict_from_dataset(n=n, dataset=inf_dataset)
2548+
assert str(exc.value).startswith(
2549+
"This model has been trained with `static_covariates`"
2550+
)
2551+
23772552
def helper_equality_encoders(
23782553
self, first_encoders: dict[str, Any], second_encoders: dict[str, Any]
23792554
):

0 commit comments

Comments
 (0)