Skip to content

Commit 08bcfb4

Browse files
Implementation of StatsForecastAutoMFLES (#2747)
* Implementation of StatsForecastAutoMFLES * Added StatsForecastAutoMFLES entry to CHANGELOG * Added StatsForecastAutoMFLES entry to README * Added StatsForecastAutoMFLES entry to covariates * minor updates --------- Co-authored-by: dennisbader <[email protected]>
1 parent d946745 commit 08bcfb4

File tree

7 files changed

+144
-1
lines changed

7 files changed

+144
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1212
**Improved**
1313

1414
- Added support for categorical covariate to `CatBoostModel`. You can now define categorical components at model construction with parameters `categorical_*_covariates: List[str]` for past, future, and static covariates. [#2733](https://github.com/unit8co/darts/pull/2750) by [Jonas Blanc](https://github.com/jonasblanc).
15+
- Added new forecasting model: `StatsForecastAutoMFLES`, a simple time series method based on gradient boosting time series decomposition as proposed in [this repository](https://github.com/tblume1992/MFLES). This implementation is based on [AutoMFLES](https://nixtlaverse.nixtla.io/statsforecast/docs/models/mfles.html) from Nixtla's `statsforecasts` library. [#2747](https://github.com/unit8co/darts/pull/2747) by [Che Hang Ng](https://github.com/CheHangNg).
1516

1617
**Removed / moved**
1718

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ on bringing more models and features.
237237
| [ExponentialSmoothing](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.exponential_smoothing.html#darts.models.forecasting.exponential_smoothing.ExponentialSmoothing) | | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
238238
| [StatsforecastAutoETS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ets.html#darts.models.forecasting.sf_auto_ets.StatsForecastAutoETS) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 ✅ 🔴 | ✅ 🔴 | 🔴 |
239239
| [StatsforecastAutoCES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ces.html#darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | 🔴 🔴 | 🔴 |
240+
| [StatsForecastAutoMFLES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_mfles.html#darts.models.forecasting.sf_auto_mfles.StatsForecastAutoMFLES) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 ✅ 🔴 | 🔴 🔴 | 🔴 |
240241
| [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | [TBATS paper](https://robjhyndman.com/papers/ComplexSeasonality.pdf) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
241242
| [StatsForecastAutoTBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_tbats.html#darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
242243
| [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | [Theta](https://robjhyndman.com/papers/Theta.pdf) & [4 Theta](https://github.com/Mcompetitions/M4-methods/blob/master/4Theta%20method.R) | ✅ 🔴 | 🔴 🔴 🔴 | 🔴 🔴 | 🔴 |

darts/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
from darts.models.forecasting.sf_auto_arima import StatsForecastAutoARIMA
9494
from darts.models.forecasting.sf_auto_ces import StatsForecastAutoCES
9595
from darts.models.forecasting.sf_auto_ets import StatsForecastAutoETS
96+
from darts.models.forecasting.sf_auto_mfles import StatsForecastAutoMFLES
9697
from darts.models.forecasting.sf_auto_tbats import StatsForecastAutoTBATS
9798
from darts.models.forecasting.sf_auto_theta import StatsForecastAutoTheta
9899

@@ -107,6 +108,7 @@
107108
StatsForecastAutoARIMA = NotImportedModule(module_name="StatsForecast", warn=False)
108109
StatsForecastAutoCES = NotImportedModule(module_name="StatsForecast", warn=False)
109110
StatsForecastAutoETS = NotImportedModule(module_name="StatsForecast", warn=False)
111+
StatsForecastAutoMFLES = NotImportedModule(module_name="StatsForecast", warn=False)
110112
StatsForecastAutoTheta = NotImportedModule(module_name="StatsForecast", warn=False)
111113
StatsForecastAutoTBATS = NotImportedModule(module_name="StatsForecast", warn=False)
112114

@@ -159,6 +161,7 @@
159161
"StatsForecastAutoARIMA",
160162
"StatsForecastAutoCES",
161163
"StatsForecastAutoETS",
164+
"StatsForecastAutoMFLES",
162165
"StatsForecastAutoTheta",
163166
"StatsForecastAutoTBATS",
164167
"XGBModel",

darts/models/forecasting/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
- :class:`~darts.models.forecasting.exponential_smoothing.ExponentialSmoothing`
1919
- :class:`~darts.models.forecasting.sf_auto_ets.StatsForecastAutoETS`
2020
- :class:`~darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES`
21+
- :class:`~darts.models.forecasting.sf_auto_mfles.StatsForecastAutoMFLES`
2122
- :class:`~darts.models.forecasting.tbats_model.BATS`
2223
- :class:`~darts.models.forecasting.tbats_model.TBATS`
2324
- :class:`~darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS`
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
StatsForecastAutoMFLES
3+
-----------
4+
"""
5+
6+
from typing import Optional
7+
8+
from statsforecast.models import AutoMFLES as SFAutoMFLES
9+
10+
from darts import TimeSeries
11+
from darts.logging import get_logger
12+
from darts.models.forecasting.forecasting_model import (
13+
FutureCovariatesLocalForecastingModel,
14+
)
15+
16+
logger = get_logger(__name__)
17+
18+
19+
class StatsForecastAutoMFLES(FutureCovariatesLocalForecastingModel):
20+
def __init__(
21+
self, *autoMFLES_args, add_encoders: Optional[dict] = None, **autoMFLES_kwargs
22+
):
23+
"""Auto-MFLES based on `Statsforecasts package
24+
<https://github.com/Nixtla/statsforecast>`_.
25+
26+
Automatically selects the best MFLES model from all feasible combinations of the parameters
27+
`seasonality_weights`, `smoother`, `ma`, and `seasonal_period`. Selection is made using the sMAPE by default.
28+
29+
We refer to the `statsforecast AutoMFLES documentation
30+
<https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#mfles>`_
31+
for the exhaustive documentation of the arguments.
32+
33+
Parameters
34+
----------
35+
autoMFLES_args
36+
Positional arguments for ``statsforecasts.models.AutoMFLES``.
37+
add_encoders
38+
A large number of future covariates can be automatically generated with `add_encoders`.
39+
This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that
40+
will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to
41+
transform the generated covariates. This happens all under one hood and only needs to be specified at
42+
model creation.
43+
Read :meth:`SequentialEncoder <darts.dataprocessing.encoders.SequentialEncoder>` to find out more about
44+
``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features:
45+
46+
.. highlight:: python
47+
.. code-block:: python
48+
49+
def encode_year(idx):
50+
return (idx.year - 1950) / 50
51+
52+
add_encoders={
53+
'cyclic': {'future': ['month']},
54+
'datetime_attribute': {'future': ['hour', 'dayofweek']},
55+
'position': {'future': ['relative']},
56+
'custom': {'future': [encode_year]},
57+
'transformer': Scaler(),
58+
'tz': 'CET'
59+
}
60+
..
61+
autoMFLES_kwargs
62+
Keyword arguments for ``statsforecasts.models.AutoMFLES``.
63+
64+
Examples
65+
--------
66+
>>> from darts.datasets import AirPassengersDataset
67+
>>> from darts.models import StatsForecastAutoMFLES
68+
>>> from darts.utils.timeseries_generation import datetime_attribute_timeseries
69+
>>> series = AirPassengersDataset().load()
70+
>>> # optionally, use some future covariates; e.g. the value of the month encoded as a sine and cosine series
71+
>>> future_cov = datetime_attribute_timeseries(series, "month", cyclic=True, add_length=6)
72+
>>> # define StatsForecastAutoMFLES parameters
73+
>>> model = StatsForecastAutoMFLES(season_length=12, test_size=12)
74+
>>> model.fit(series, future_covariates=future_cov)
75+
>>> pred = model.predict(6, future_covariates=future_cov)
76+
>>> pred.values()
77+
array([[466.03298745],
78+
[450.76192105],
79+
[517.6342497 ],
80+
[511.62988828],
81+
[520.15305998],
82+
[593.38690019]])
83+
"""
84+
if "prediction_intervals" in autoMFLES_kwargs:
85+
logger.warning(
86+
"StatsForecastAutoMFLES does not support probabilistic forecasting. "
87+
"`prediction_intervals` will be ignored."
88+
)
89+
90+
super().__init__(add_encoders=add_encoders)
91+
self.model = SFAutoMFLES(*autoMFLES_args, **autoMFLES_kwargs)
92+
93+
def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
94+
super()._fit(series, future_covariates)
95+
self._assert_univariate(series)
96+
series = self.training_series
97+
self.model.fit(
98+
series.values(copy=False).flatten(),
99+
X=future_covariates.values(copy=False) if future_covariates else None,
100+
)
101+
return self
102+
103+
def _predict(
104+
self,
105+
n: int,
106+
future_covariates: Optional[TimeSeries] = None,
107+
num_samples: int = 1,
108+
verbose: bool = False,
109+
):
110+
super()._predict(n, future_covariates, num_samples)
111+
forecast_dict = self.model.predict(
112+
h=n,
113+
X=future_covariates.values(copy=False) if future_covariates else None,
114+
level=None,
115+
)
116+
117+
return self._build_forecast_series(forecast_dict["mean"])
118+
119+
@property
120+
def supports_multivariate(self) -> bool:
121+
return False
122+
123+
@property
124+
def min_train_series_length(self) -> int:
125+
return 10
126+
127+
@property
128+
def _supports_range_index(self) -> bool:
129+
return True
130+
131+
@property
132+
def supports_probabilistic_prediction(self) -> bool:
133+
return False

darts/tests/models/forecasting/test_local_forecasting_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
StatsForecastAutoARIMA,
3434
StatsForecastAutoCES,
3535
StatsForecastAutoETS,
36+
StatsForecastAutoMFLES,
3637
StatsForecastAutoTBATS,
3738
StatsForecastAutoTheta,
3839
Theta,
@@ -62,6 +63,7 @@
6263
(StatsForecastAutoTheta(season_length=12), 5.5),
6364
(StatsForecastAutoCES(season_length=12, model="Z"), 7.3),
6465
(StatsForecastAutoETS(season_length=12, model="AAZ"), 7.3),
66+
(StatsForecastAutoMFLES(season_length=12, test_size=12), 9.8),
6567
(StatsForecastAutoTBATS(season_length=12), 10),
6668
(Croston(version="classic"), 23),
6769
(Croston(version="tsb", alpha_d=0.1, alpha_p=0.1), 23),
@@ -95,6 +97,7 @@
9597
dual_models = [
9698
ARIMA(),
9799
StatsForecastAutoARIMA(season_length=12),
100+
StatsForecastAutoMFLES(season_length=12, test_size=12),
98101
StatsForecastAutoETS(season_length=12),
99102
]
100103

docs/userguide/covariates.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,9 @@ GFMs are models that can be trained on multiple target (and covariate) time seri
131131
| [ExponentialSmoothing](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.exponential_smoothing.html#darts.models.forecasting.exponential_smoothing.ExponentialSmoothing) | | | |
132132
| [StatsforecastAutoETS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ets.html#darts.models.forecasting.sf_auto_ets.StatsForecastAutoETS) | || |
133133
| [StatsforecastAutoCES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ces.html#darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES) | | | |
134+
| [StatsForecastAutoMFLES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_mfles.html#darts.models.forecasting.sf_auto_mfles.StatsForecastAutoMFLES) | || |
134135
| [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | | | |
135-
| [StatsForecastAutoTBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_tbats.html#darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS) | | | |
136+
| [StatsForecastAutoTBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_tbats.html#darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS) | | | |
136137
| [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | | | |
137138
| [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | | | |
138139
| [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) | || |

0 commit comments

Comments
 (0)