Skip to content

Commit 193bb3b

Browse files
authored
Feat/exp smoothing add argument (#2904)
* feat: add extra arguments for Exponential Smoothing model * changelog update * doc changes and PR suggestions
1 parent b870cb6 commit 193bb3b

File tree

4 files changed

+52
-2
lines changed

4 files changed

+52
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1616
- `from_group_dataframe()` now supports creating `TimeSeries` from **additional DataFrame backends** (Polars, PyArrow, ...). We leverage `narwhals` as the compatibility layer between DataFrame libraries. See their [documentation](https://narwhals-dev.github.io/narwhals/) for all supported backends. [#2766](https://github.com/unit8co/darts/pull/2766) by [He Weilin](https://github.com/cnhwl).
1717
- Added `add_regressor_configs` parameter to the `Prophet` model, enabling component-specific control over `prior_scale`, `mode`, and `standardize` for the future covariates. [#2882](https://github.com/unit8co/darts/issues/2882) by [Ramsay Davis](https://github.com/RamsayDavisWL).
1818
- 🔴 Increased the decimal places for quantile component names from 2 to 3 for more precise quantiles. (e.g. `component_name_q0.500` for quantile 0.5). This affects quantile forecasts as well as quantiles computed with `TimeSeries.quantile()`. [#2887](https://github.com/unit8co/darts/pull/2786) by [He Weilin](https://github.com/cnhwl).
19+
- Added model creation parameters `random_errors` and `error` to `ExponentialSmoothing` that give control over how probabilistic forecasts are generated. [#2290491](https://github.com/unit8co/darts/pull/2904) by [Jakub Chłapek](https://github.com/jakubchlapek)
1920

2021
**Fixed**
2122

darts/models/forecasting/exponential_smoothing.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def __init__(
2525
damped: Optional[bool] = False,
2626
seasonal: Optional[SeasonalityMode] = SeasonalityMode.ADDITIVE,
2727
seasonal_periods: Optional[int] = None,
28+
error: Optional[str] = "add",
29+
random_errors: Optional[Any] = None,
2830
random_state: Optional[int] = None,
2931
kwargs: Optional[dict[str, Any]] = None,
3032
**fit_kwargs,
@@ -63,6 +65,17 @@ def __init__(
6365
seasonal_periods
6466
The number of periods in a complete seasonal cycle, e.g., 4 for quarterly data or 7 for daily
6567
data with a weekly cycle. If not set, inferred from frequency of the series.
68+
error
69+
Specifies the type of error model for state space formulation to use when using predict()
70+
with ``num_samples > 1``. Default is `"add"`.
71+
Will be passed to statsmodels' :func:`simulate()` method. See the documentation `here
72+
<https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.HoltWintersResults.simulate.html>`_
73+
for more information.
74+
random_errors
75+
Specifies how the random errors should be obtained, when using predict() with ``num_samples > 1``.
76+
Will be passed to statsmodels' :func:`simulate()` method. See the documentation `here
77+
<https://www.statsmodels.org/stable/generated/statsmodels.tsa.holtwinters.HoltWintersResults.simulate.html>`_
78+
for more information.
6679
random_state
6780
Controls the randomness for reproducible forecasting.
6881
kwargs
@@ -100,6 +113,8 @@ def __init__(
100113
self.seasonal = seasonal
101114
self.infer_seasonal_periods = seasonal_periods is None
102115
self.seasonal_periods = seasonal_periods
116+
self.error = error
117+
self.random_errors = random_errors
103118
self.constructor_kwargs = dict() if kwargs is None else kwargs
104119
self.fit_kwargs = fit_kwargs
105120
self.model = None
@@ -156,7 +171,13 @@ def predict(
156171
rng = check_random_state(random_state)
157172

158173
forecast = np.expand_dims(
159-
self.model.simulate(n, repetitions=num_samples, random_state=rng),
174+
self.model.simulate(
175+
n,
176+
repetitions=num_samples,
177+
random_state=rng,
178+
random_errors=self.random_errors,
179+
error=self.error,
180+
),
160181
axis=1,
161182
)
162183

darts/tests/models/forecasting/test_exponential_smoothing.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,31 @@ def test_fit_kwargs(self):
9191
# forecasts should be slightly different
9292
assert pred.time_index.equals(pred_ls.time_index)
9393
assert all(np.not_equal(pred.values(), pred_ls.values()))
94+
95+
def test_random_errors(self):
96+
"""Test whether random_errors parameter is correctly passed to simulate()"""
97+
series = tg.sine_timeseries(length=100, freq="H")
98+
model = ExponentialSmoothing(random_state=42)
99+
model.fit(series)
100+
pred = model.predict(n=10, num_samples=10, random_state=42)
101+
102+
model_boot = ExponentialSmoothing(random_errors="bootstrap", random_state=42)
103+
model_boot.fit(series)
104+
pred_boot = model_boot.predict(n=10, num_samples=10, random_state=42)
105+
106+
# methods with different random_errors set should yield different forecasts
107+
assert not np.allclose(pred.values(), pred_boot.values(), atol=1e-5)
108+
109+
def test_error(self):
110+
"""Test whether error parameter is correctly passed to simulate()"""
111+
series = tg.sine_timeseries(length=100, freq="H")
112+
model = ExponentialSmoothing(random_state=42)
113+
model.fit(series)
114+
pred = model.predict(n=10, num_samples=10, random_state=42)
115+
116+
model_boot = ExponentialSmoothing(error="mul", random_state=42)
117+
model_boot.fit(series)
118+
pred_boot = model_boot.predict(n=10, num_samples=10, random_state=42)
119+
120+
# methods with different error set should yield different forecasts
121+
assert not np.allclose(pred.values(), pred_boot.values(), atol=1e-5)

darts/tests/models/forecasting/test_local_forecasting_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def test_model_str_call(self, config):
665665
(
666666
ExponentialSmoothing(),
667667
"ExponentialSmoothing(trend=ModelMode.ADDITIVE, damped=False, seasonal=SeasonalityMode.ADDITIVE, "
668-
+ "seasonal_periods=None, random_state=None, kwargs=None)",
668+
+ "seasonal_periods=None, error=add, random_errors=None, random_state=None, kwargs=None)",
669669
), # no params changed
670670
(
671671
ARIMA(1, 1, 1),

0 commit comments

Comments
 (0)