Skip to content

Commit 8523628

Browse files
authored
Feat/global hist fc retrain (#2916)
* add global hfc mode * update docs * add val length to global hfc * extend datatransformer support for hfc * integrate global hfc into regular hfc * fix logic * update docs * add first unit tests * add data transformer tests * extend tests * fix failing backtest * add tests for new drop before/after * add more tests * clean up for pr * update changelog * update changelog * update docs
1 parent 2274e94 commit 8523628

File tree

11 files changed

+960
-266
lines changed

11 files changed

+960
-266
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
2323
- 🔴 Renamed the `RegressionEnsembleModel` ensemble model attribute from `regression_model` to `ensemble_model` to make it more clear that this model is used to combine the predictions of the base models. [#2894](https://github.com/unit8co/darts/pull/2894) by [Dennis Bader](https://github.com/dennisbader).
2424
- Added parameter `verbose` to `ForecastingModel.fit()` and `predict()` that allows to control the verbosity for model fitting and prediction. Ignored if the underlying model does not support it. [#2805](https://github.com/unit8co/darts/pull/2805) by [Timon Erhart](https://github.com/turbotimon) and [Dennis Bader](https://github.com/dennisbader).
2525
- It is now possible to control the fit and predict verbosity in `ForecastingModel.historical_forecasts()` by passing `verbose` in parameters `fit_kwargs` and `predict_kwargs`. [#2805](https://github.com/unit8co/darts/pull/2805) by [Timon Erhart](https://github.com/turbotimon) and [Dennis Bader](https://github.com/dennisbader).
26+
- Added support for applying historical forecast, backtest and residuals globally on all series with parameter `apply_globally: bool`. If `True`, computes the output only the time intersection of all series. Additionally, with `retrain=True`, activates global model- and data transformer fitting (for global forecasting models). If `False` (default), computes the output on the entire extent of each individual series and performs local fitting. [#2916](https://github.com/unit8co/darts/pull/2916) by [Dennis Bader](https://github.com/dennisbader).
27+
- `TimeSeries.drop_before()` and `drop_after()` now support keeping the split point in the returned series by passing parameter `keep=True`. [#2916](https://github.com/unit8co/darts/pull/2916) by [Dennis Bader](https://github.com/dennisbader).
2628

2729
**Fixed**
2830

darts/models/forecasting/conformal_models.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from darts.utils import _build_tqdm_iterator, _with_sanity_checks
4242
from darts.utils.historical_forecasts.utils import (
4343
_adjust_historical_forecasts_time_index,
44+
_slice_intersect_series,
4445
)
4546
from darts.utils.timeseries_generation import _build_forecast_series
4647
from darts.utils.ts_utils import (
@@ -410,6 +411,7 @@ def historical_forecasts(
410411
start_format: Literal["position", "value"] = "value",
411412
stride: int = 1,
412413
retrain: Union[bool, int, Callable[..., bool]] = True,
414+
apply_globally: bool = False,
413415
overlap_end: bool = False,
414416
last_points_only: bool = True,
415417
verbose: bool = False,
@@ -503,6 +505,12 @@ def historical_forecasts(
503505
(set at model creation) and `>=cal_stride`.
504506
retrain
505507
Currently ignored by conformal models.
508+
apply_globally
509+
Whether to apply historical forecasts globally on the time intersection all series, or independently on
510+
each series. This includes global model- and data transformer fitting. Only really effective for global
511+
forecasting models, but can also be used with local models to generate forecasts on the same time frame. If
512+
`True`, considers only the time intersection of all series for historical forecasting. If `False`,
513+
considers the entire extent of each individual series for historical forecasting.
506514
overlap_end
507515
Whether the returned forecasts can go beyond the series' end or not.
508516
last_points_only
@@ -562,6 +570,15 @@ def historical_forecasts(
562570
past_covariates = series2seq(past_covariates)
563571
future_covariates = series2seq(future_covariates)
564572

573+
if apply_globally:
574+
# for global hfc, we have to slice intersect already here to compute the correct start points
575+
series, past_covariates, future_covariates, _ = _slice_intersect_series(
576+
series=series,
577+
past_covariates=past_covariates,
578+
future_covariates=future_covariates,
579+
sample_weight=None,
580+
)
581+
565582
# generate only the required forecasts (if `start` is given, we have to start earlier to satisfy the
566583
# calibration set requirements)
567584
cal_start, cal_start_format = _get_calibration_hfc_start(
@@ -583,6 +600,7 @@ def historical_forecasts(
583600
start_format=cal_start_format,
584601
stride=self.cal_stride,
585602
retrain=False,
603+
apply_globally=apply_globally,
586604
overlap_end=overlap_end,
587605
last_points_only=last_points_only,
588606
verbose=verbose,
@@ -631,6 +649,7 @@ def backtest(
631649
start_format: Literal["position", "value"] = "value",
632650
stride: int = 1,
633651
retrain: Union[bool, int, Callable[..., bool]] = True,
652+
apply_globally: bool = False,
634653
overlap_end: bool = False,
635654
last_points_only: bool = False,
636655
metric: Union[METRIC_TYPE, list[METRIC_TYPE]] = metrics.mape,
@@ -731,6 +750,12 @@ def backtest(
731750
The number of time steps between two consecutive predictions.
732751
retrain
733752
Currently ignored by conformal models.
753+
apply_globally
754+
Whether to apply historical forecasts globally on the time intersection all series, or independently on
755+
each series. This includes global model- and data transformer fitting. Only really effective for global
756+
forecasting models, but can also be used with local models to generate forecasts on the same time frame. If
757+
`True`, considers only the time intersection of all series for historical forecasting. If `False`,
758+
considers the entire extent of each individual series for historical forecasting.
734759
overlap_end
735760
Whether the returned forecasts can go beyond the series' end or not.
736761
last_points_only
@@ -821,6 +846,7 @@ def backtest(
821846
start_format=start_format,
822847
stride=stride,
823848
retrain=retrain,
849+
apply_globally=apply_globally,
824850
overlap_end=overlap_end,
825851
last_points_only=last_points_only,
826852
metric=metric,
@@ -853,6 +879,7 @@ def residuals(
853879
start_format: Literal["position", "value"] = "value",
854880
stride: int = 1,
855881
retrain: Union[bool, int, Callable[..., bool]] = True,
882+
apply_globally: bool = False,
856883
overlap_end: bool = False,
857884
last_points_only: bool = True,
858885
metric: METRIC_TYPE = metrics.err,
@@ -963,6 +990,12 @@ def residuals(
963990
The number of time steps between two consecutive predictions.
964991
retrain
965992
Currently ignored by conformal models.
993+
apply_globally
994+
Whether to apply historical forecasts globally on the time intersection all series, or independently on
995+
each series. This includes global model- and data transformer fitting. Only really effective for global
996+
forecasting models, but can also be used with local models to generate forecasts on the same time frame. If
997+
`True`, considers only the time intersection of all series for historical forecasting. If `False`,
998+
considers the entire extent of each individual series for historical forecasting.
966999
overlap_end
9671000
Whether the returned forecasts can go beyond the series' end or not.
9681001
last_points_only
@@ -1039,6 +1072,7 @@ def residuals(
10391072
start_format=start_format,
10401073
stride=stride,
10411074
retrain=retrain,
1075+
apply_globally=apply_globally,
10421076
overlap_end=overlap_end,
10431077
last_points_only=last_points_only,
10441078
metric=metric,

0 commit comments

Comments
 (0)