Skip to content

Commit ad93612

Browse files
authored
Feat/hist fc start stride (#2560)
* improve hist fc start point * add tests * update documentation * update changelog * clean up code * fix tests * fix missed lines * improve codecov
1 parent c116405 commit ad93612

File tree

10 files changed

+723
-177
lines changed

10 files changed

+723
-177
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1111

1212
**Improved**
1313

14+
- Improvements to `ForecastingModel`: Improved `start` handling for historical forecasts, backtest, residuals, and gridsearch. If `start` is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of `stride` ahead of `start`. Raises a `ValueError` if no valid start point exists. This guarantees that all historical forecasts are `n * stride` points away from `start`, and will simplify many downstream tasks. [#2560](https://github.com/unit8co/darts/pull/2560) by [Dennis Bader](https://github.com/dennisbader).
15+
1416
**Fixed**
1517

1618
- Fixed a bug when using `darts.utils.data.tabularization.create_lagged_component_names()` with target `lags=None`, that did not return any lagged target label component names. [#2576](https://github.com/unit8co/darts/pull/2576) by [Dennis Bader](https://github.com/dennisbader).

darts/models/forecasting/forecasting_model.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -706,11 +706,12 @@ def historical_forecasts(
706706
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
707707
- the first trainable point (given `train_length`) otherwise
708708
709+
Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
710+
is a round multiple of `stride` ahead of `start`. Raises a `ValueError` if no valid start point exists.
709711
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
710-
shifted by `output_chunk_shift` points into the future.
711-
Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
712+
shifted by `output_chunk_shift` points into the future.
712713
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
713-
(default behavior with ``None``) and start at the first trainable/predictable point.
714+
(default behavior with ``None``) and start at the first trainable/predictable point.
714715
start_format
715716
Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a
716717
`pd.RangeIndex`.
@@ -1018,6 +1019,7 @@ def retrain_func(
10181019
historical_forecasts_time_index=historical_forecasts_time_index,
10191020
start=start,
10201021
start_format=start_format,
1022+
stride=stride,
10211023
show_warnings=show_warnings,
10221024
)
10231025

@@ -1267,9 +1269,12 @@ def backtest(
12671269
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
12681270
- the first trainable point (given `train_length`) otherwise
12691271
1270-
Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
1272+
Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
1273+
is a round multiple of `stride` ahead of `start`. Raises a `ValueError` if no valid start point exists.
1274+
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
1275+
shifted by `output_chunk_shift` points into the future.
12711276
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
1272-
(default behavior with ``None``) and start at the first trainable/predictable point.
1277+
(default behavior with ``None``) and start at the first trainable/predictable point.
12731278
start_format
12741279
Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a
12751280
`pd.RangeIndex`.
@@ -1628,9 +1633,12 @@ def gridsearch(
16281633
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
16291634
- the first trainable point (given `train_length`) otherwise
16301635
1631-
Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
1636+
Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
1637+
is a round multiple of `stride` ahead of `start`. Raises a `ValueError` if no valid start point exists.
1638+
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
1639+
shifted by `output_chunk_shift` points into the future.
16321640
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
1633-
(default behavior with ``None``) and start at the first trainable/predictable point.
1641+
(default behavior with ``None``) and start at the first trainable/predictable point.
16341642
start_format
16351643
Only used in expanding window mode. Defines the `start` format. Only effective when `start` is an integer
16361644
and `series` is indexed with a `pd.RangeIndex`.
@@ -1924,9 +1932,12 @@ def residuals(
19241932
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
19251933
- the first trainable point (given `train_length`) otherwise
19261934
1927-
Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
1935+
Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
1936+
is a round multiple of `stride` ahead of `start`. Raises a `ValueError` if no valid start point exists.
1937+
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
1938+
shifted by `output_chunk_shift` points into the future.
19281939
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
1929-
(default behavior with ``None``) and start at the first trainable/predictable point.
1940+
(default behavior with ``None``) and start at the first trainable/predictable point.
19301941
start_format
19311942
Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a
19321943
`pd.RangeIndex`.

darts/tests/models/forecasting/test_backtesting.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
import logging
23
import random
34
from itertools import product
45

@@ -733,8 +734,7 @@ def test_backtest_multiple_series(self):
733734
assert round(abs(error[0] - expected[0]), 4) == 0
734735
assert round(abs(error[1] - expected[1]), 4) == 0
735736

736-
@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
737-
def test_backtest_regression(self):
737+
def test_backtest_regression(self, caplog):
738738
np.random.seed(4)
739739

740740
gaussian_series = gt(mean=2, length=50)
@@ -804,13 +804,26 @@ def test_backtest_regression(self):
804804
assert score > 0.9
805805

806806
# Using a too small start value
807-
with pytest.raises(ValueError):
808-
RandomForest(lags=12).backtest(series=target, start=0, forecast_horizon=3)
807+
warning_expected = (
808+
"`start` position `{0}` corresponding to time `{1}` is before the first "
809+
"predictable/trainable historical forecasting point for series at index: 0. Using the first historical "
810+
"forecasting point `2000-01-15 00:00:00` that lies a round-multiple of `stride=1` ahead of `start`. "
811+
"To hide these warnings, set `show_warnings=False`."
812+
)
813+
caplog.clear()
814+
with caplog.at_level(logging.WARNING):
815+
_ = RandomForest(lags=12).backtest(
816+
series=target, start=0, forecast_horizon=3
817+
)
818+
assert warning_expected.format(0, target.start_time()) in caplog.text
819+
caplog.clear()
809820

810-
with pytest.raises(ValueError):
811-
RandomForest(lags=12).backtest(
821+
with caplog.at_level(logging.WARNING):
822+
_ = RandomForest(lags=12).backtest(
812823
series=target, start=0.01, forecast_horizon=3
813824
)
825+
assert warning_expected.format(0.01, target.start_time()) in caplog.text
826+
caplog.clear()
814827

815828
# Using RandomForest's start default value
816829
score = RandomForest(lags=12, random_state=0).backtest(
@@ -939,7 +952,6 @@ def test_gridsearch_metric_score(self):
939952

940953
assert score == recalculated_score, "The metric scores should match"
941954

942-
@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
943955
def test_gridsearch_random_search(self):
944956
np.random.seed(1)
945957

@@ -958,7 +970,6 @@ def test_gridsearch_random_search(self):
958970
assert isinstance(result[2], float)
959971
assert min(param_range) <= result[1]["lags"] <= max(param_range)
960972

961-
@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
962973
def test_gridsearch_n_random_samples_bad_arguments(self):
963974
dummy_series = get_dummy_series(ts_length=50)
964975

@@ -981,7 +992,6 @@ def test_gridsearch_n_random_samples_bad_arguments(self):
981992
params, dummy_series, forecast_horizon=1, n_random_samples=1.5
982993
)
983994

984-
@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
985995
def test_gridsearch_n_random_samples(self):
986996
np.random.seed(1)
987997

0 commit comments

Comments
 (0)