Skip to content

Commit 2391d8b

Browse files
authored
Fix/sklearn model quantile reg (#2838)
* improve likelihood handling for sklearn models * update tests * update changelog
1 parent bd53fe0 commit 2391d8b

File tree

8 files changed

+157
-79
lines changed

8 files changed

+157
-79
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1313

1414
**Fixed**
1515

16+
- Fixed a bug in `SKLearnModel.get_estimator()` for univariate quantile models that use `multi_models=False` , where using `quantile` did not return the correct fitted quantile model / estimator. [#2838](https://github.com/unit8co/darts/pull/2838) by [Dennis Bader](https://github.com/dennisbader).
17+
1618
**Dependencies**
1719

1820
### For developers of the library:

darts/models/forecasting/catboost_model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ def encode_year(idx):
192192
"""
193193
kwargs["random_state"] = random_state # seed for tree learner
194194
self.kwargs = kwargs
195-
self._model_container = None
196195

197196
# parse likelihood
198197
if likelihood is not None:
@@ -207,12 +206,10 @@ def encode_year(idx):
207206
# RMSEWithUncertainty returns mean and variance which is equivalent to gaussian
208207
likelihood = "gaussian"
209208

210-
if likelihood == "quantile":
211-
self._model_container = _QuantileModelContainer()
212-
else:
209+
if likelihood != "quantile":
213210
self.kwargs["loss_function"] = likelihood_map[likelihood]
214211

215-
self._likelihood = _get_likelihood(
212+
likelihood = _get_likelihood(
216213
likelihood=likelihood,
217214
n_outputs=output_chunk_length if multi_models else 1,
218215
quantiles=quantiles,
@@ -241,6 +238,10 @@ def encode_year(idx):
241238
# if no loss provided, get the default loss from the model
242239
self.kwargs["loss_function"] = self.model.get_params().get("loss_function")
243240

241+
self._likelihood = likelihood
242+
if isinstance(likelihood, QuantileRegression):
243+
self._model_container = _QuantileModelContainer()
244+
244245
def fit(
245246
self,
246247
series: Union[TimeSeries, Sequence[TimeSeries]],
@@ -325,6 +326,7 @@ def fit(
325326
verbose=verbose,
326327
**kwargs,
327328
)
329+
# store the trained model in the container as it might have been wrapped by MultiOutputRegressor
328330
self._model_container[quantile] = self.model
329331
return self
330332

darts/models/forecasting/lgbm.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,13 @@ def encode_year(idx):
190190
"""
191191
kwargs["random_state"] = random_state # seed for tree learner
192192
self.kwargs = kwargs
193-
self._model_container = None
194193

195194
# parse likelihood
196195
if likelihood is not None:
197196
_check_likelihood(likelihood, ["quantile", "poisson"])
198197
self.kwargs["objective"] = likelihood
199-
if likelihood == "quantile":
200-
self._model_container = _QuantileModelContainer()
201198

202-
self._likelihood = _get_likelihood(
199+
likelihood = _get_likelihood(
203200
likelihood=likelihood,
204201
n_outputs=output_chunk_length if multi_models else 1,
205202
quantiles=quantiles,
@@ -221,6 +218,10 @@ def encode_year(idx):
221218
random_state=random_state,
222219
)
223220

221+
self._likelihood = likelihood
222+
if isinstance(likelihood, QuantileRegression):
223+
self._model_container = _QuantileModelContainer()
224+
224225
def fit(
225226
self,
226227
series: Union[TimeSeries, Sequence[TimeSeries]],
@@ -299,6 +300,7 @@ def fit(
299300
val_sample_weight=val_sample_weight,
300301
**kwargs,
301302
)
303+
# store the trained model in the container as it might have been wrapped by MultiOutputRegressor
302304
self._model_container[quantile] = self.model
303305
return self
304306

darts/models/forecasting/linear_regression_model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,6 @@ def encode_year(idx):
173173
[1005.81830675]])
174174
"""
175175
self.kwargs = kwargs
176-
self._model_container = None
177176

178177
# parse likelihood
179178
if likelihood is not None:
@@ -182,11 +181,10 @@ def encode_year(idx):
182181
model = PoissonRegressor(**kwargs)
183182
if likelihood == "quantile":
184183
model = QuantileRegressor(**kwargs)
185-
self._model_container = _QuantileModelContainer()
186184
else:
187185
model = LinearRegression(**kwargs)
188186

189-
self._likelihood = _get_likelihood(
187+
likelihood = _get_likelihood(
190188
likelihood=likelihood,
191189
n_outputs=output_chunk_length if multi_models else 1,
192190
quantiles=quantiles,
@@ -205,6 +203,10 @@ def encode_year(idx):
205203
random_state=random_state,
206204
)
207205

206+
self._likelihood = likelihood
207+
if isinstance(likelihood, QuantileRegression):
208+
self._model_container = _QuantileModelContainer()
209+
208210
def fit(
209211
self,
210212
series: Union[TimeSeries, Sequence[TimeSeries]],
@@ -249,7 +251,7 @@ def fit(
249251
sample_weight=sample_weight,
250252
**kwargs,
251253
)
252-
254+
# store the trained model in the container as it might have been wrapped by MultiOutputRegressor
253255
self._model_container[quantile] = self.model
254256

255257
# replace the last trained QuantileRegressor with the dictionary of Regressors.

darts/models/forecasting/sklearn_model.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ def encode_year(idx):
227227
self._lagged_feature_names: Optional[list[str]] = None
228228
self._lagged_label_names: Optional[list[str]] = None
229229

230+
# optionally, the model can be wrapped in a likelihood model
231+
self._likelihood: Optional[SKLearnLikelihood] = None
232+
# for quantile likelihood models, the model container is a dict of quantile -> model
233+
self._model_container: Optional[dict[float, Any]] = None
234+
230235
# check and set output_chunk_length
231236
raise_if_not(
232237
isinstance(output_chunk_length, int) and output_chunk_length > 0,
@@ -525,9 +530,11 @@ def get_estimator(
525530
):
526531
"""Returns the estimator that forecasts the `horizon`th step of the `target_dim`th target component.
527532
528-
For probabilistic models fitting quantiles, it is possible to also specify the quantile.
533+
For probabilistic models fitting quantiles, a desired `quantile` can also be passed. If not passed, it will
534+
return the model predicting the median (quantile=0.5).
529535
530-
The model is returned directly if it supports multi-output natively.
536+
If the (quantile) model supports multi-output natively, it will return the model that can predict the entire
537+
horizon and all target components jointly.
531538
532539
Note: Internally, estimators are grouped by `output_chunk_length` position, then by component. For probabilistic
533540
models fitting quantiles, there is an additional abstraction layer, grouping the estimators by `quantile`.
@@ -539,13 +546,38 @@ def get_estimator(
539546
target_dim
540547
The index of the target component.
541548
quantile
542-
Optionally, for probabilistic model with `likelihood="quantile"`, a quantile value.
549+
Optionally, for probabilistic model with `likelihood="quantile"`, the desired quantile value. If `None` and
550+
`likelihood="quantile"`, returns the model predicting the median (quantile=0.5).
543551
"""
544-
if not isinstance(self.model, MultiOutputRegressor):
552+
likelihood = self.likelihood
553+
if isinstance(likelihood, QuantileRegression):
554+
# for quantile-models, the estimators are grouped by quantiles
555+
if quantile is None:
556+
quantile = likelihood.quantiles[likelihood._median_idx]
557+
elif quantile not in self._model_container:
558+
raise_log(
559+
ValueError(
560+
f"Invalid `quantile={quantile}`. Must be one of the fitted quantiles "
561+
f"`{list(self._model_container.keys())}`."
562+
),
563+
logger,
564+
)
565+
model = self._model_container[quantile]
566+
elif quantile is not None:
567+
raise_log(
568+
ValueError(
569+
"`quantile` is only supported for probabilistic models that use `likelihood='quantile'`."
570+
),
571+
logger=logger,
572+
)
573+
else:
574+
model = self.model
575+
576+
if not isinstance(model, MultiOutputRegressor):
545577
logger.warning(
546578
"Model supports multi-output; a single estimator forecasts all the horizons and components."
547579
)
548-
return self.model
580+
return model
549581

550582
if not 0 <= horizon < self.output_chunk_length:
551583
raise_log(
@@ -566,27 +598,7 @@ def get_estimator(
566598
idx_estimator = (
567599
self.multi_models * self.input_dim["target"] * horizon + target_dim
568600
)
569-
if quantile is None:
570-
return self.model.estimators_[idx_estimator]
571-
572-
# for quantile-models, the estimators are also grouped by quantiles
573-
if not isinstance(self.likelihood, QuantileRegression):
574-
raise_log(
575-
ValueError(
576-
"`quantile` is only supported for probabilistic models that "
577-
"use `likelihood='quantile'`."
578-
),
579-
logger,
580-
)
581-
if quantile not in self._model_container:
582-
raise_log(
583-
ValueError(
584-
f"Invalid `quantile={quantile}`. Must be one of the fitted quantiles "
585-
f"`{list(self._model_container.keys())}`."
586-
),
587-
logger,
588-
)
589-
return self._model_container[quantile].estimators_[idx_estimator]
601+
return model.estimators_[idx_estimator]
590602

591603
def _add_val_set_to_kwargs(
592604
self,

darts/models/forecasting/xgboost.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ def encode_year(idx):
191191
"""
192192
kwargs["random_state"] = random_state # seed for tree learner
193193
self.kwargs = kwargs
194-
self._model_container = None
195194

196195
# parse likelihood
197196
if likelihood is not None:
@@ -200,9 +199,8 @@ def encode_year(idx):
200199
self.kwargs["objective"] = f"count:{likelihood}"
201200
elif likelihood == "quantile":
202201
self.kwargs["objective"] = "reg:quantileerror"
203-
self._model_container = _QuantileModelContainer()
204202

205-
self._likelihood = _get_likelihood(
203+
likelihood = _get_likelihood(
206204
likelihood=likelihood,
207205
n_outputs=output_chunk_length if multi_models else 1,
208206
quantiles=quantiles,
@@ -221,6 +219,10 @@ def encode_year(idx):
221219
random_state=random_state,
222220
)
223221

222+
self._likelihood = likelihood
223+
if isinstance(likelihood, QuantileRegression):
224+
self._model_container = _QuantileModelContainer()
225+
224226
def fit(
225227
self,
226228
series: Union[TimeSeries, Sequence[TimeSeries]],
@@ -301,6 +303,7 @@ def fit(
301303
val_sample_weight=val_sample_weight,
302304
**kwargs,
303305
)
306+
# store the trained model in the container as it might have been wrapped by MultiOutputRegressor
304307
self._model_container[quantile] = self.model
305308
return self
306309

0 commit comments

Comments
 (0)