@@ -227,6 +227,11 @@ def encode_year(idx):
227227 self ._lagged_feature_names : Optional [list [str ]] = None
228228 self ._lagged_label_names : Optional [list [str ]] = None
229229
230+ # optionally, the model can be wrapped in a likelihood model
231+ self ._likelihood : Optional [SKLearnLikelihood ] = None
232+ # for quantile likelihood models, the model container is a dict of quantile -> model
233+ self ._model_container : Optional [dict [float , Any ]] = None
234+
230235 # check and set output_chunk_length
231236 raise_if_not (
232237 isinstance (output_chunk_length , int ) and output_chunk_length > 0 ,
@@ -525,9 +530,11 @@ def get_estimator(
525530 ):
526531 """Returns the estimator that forecasts the `horizon`th step of the `target_dim`th target component.
527532
528- For probabilistic models fitting quantiles, it is possible to also specify the quantile.
533+ For probabilistic models fitting quantiles, a desired `quantile` can also be passed. If not passed, it will
534+ return the model predicting the median (quantile=0.5).
529535
530- The model is returned directly if it supports multi-output natively.
536+ If the (quantile) model supports multi-output natively, it will return the model that can predict the entire
537+ horizon and all target components jointly.
531538
532539 Note: Internally, estimators are grouped by `output_chunk_length` position, then by component. For probabilistic
533540 models fitting quantiles, there is an additional abstraction layer, grouping the estimators by `quantile`.
@@ -539,13 +546,38 @@ def get_estimator(
539546 target_dim
540547 The index of the target component.
541548 quantile
542- Optionally, for probabilistic model with `likelihood="quantile"`, a quantile value.
549+ Optionally, for probabilistic model with `likelihood="quantile"`, the desired quantile value. If `None` and
550+ `likelihood="quantile"`, returns the model predicting the median (quantile=0.5).
543551 """
544- if not isinstance (self .model , MultiOutputRegressor ):
552+ likelihood = self .likelihood
553+ if isinstance (likelihood , QuantileRegression ):
554+ # for quantile-models, the estimators are grouped by quantiles
555+ if quantile is None :
556+ quantile = likelihood .quantiles [likelihood ._median_idx ]
557+ elif quantile not in self ._model_container :
558+ raise_log (
559+ ValueError (
560+ f"Invalid `quantile={ quantile } `. Must be one of the fitted quantiles "
561+ f"`{ list (self ._model_container .keys ())} `."
562+ ),
563+ logger ,
564+ )
565+ model = self ._model_container [quantile ]
566+ elif quantile is not None :
567+ raise_log (
568+ ValueError (
569+ "`quantile` is only supported for probabilistic models that use `likelihood='quantile'`."
570+ ),
571+ logger = logger ,
572+ )
573+ else :
574+ model = self .model
575+
576+ if not isinstance (model , MultiOutputRegressor ):
545577 logger .warning (
546578 "Model supports multi-output; a single estimator forecasts all the horizons and components."
547579 )
548- return self . model
580+ return model
549581
550582 if not 0 <= horizon < self .output_chunk_length :
551583 raise_log (
@@ -566,27 +598,7 @@ def get_estimator(
566598 idx_estimator = (
567599 self .multi_models * self .input_dim ["target" ] * horizon + target_dim
568600 )
569- if quantile is None :
570- return self .model .estimators_ [idx_estimator ]
571-
572- # for quantile-models, the estimators are also grouped by quantiles
573- if not isinstance (self .likelihood , QuantileRegression ):
574- raise_log (
575- ValueError (
576- "`quantile` is only supported for probabilistic models that "
577- "use `likelihood='quantile'`."
578- ),
579- logger ,
580- )
581- if quantile not in self ._model_container :
582- raise_log (
583- ValueError (
584- f"Invalid `quantile={ quantile } `. Must be one of the fitted quantiles "
585- f"`{ list (self ._model_container .keys ())} `."
586- ),
587- logger ,
588- )
589- return self ._model_container [quantile ].estimators_ [idx_estimator ]
601+ return model .estimators_ [idx_estimator ]
590602
591603 def _add_val_set_to_kwargs (
592604 self ,
0 commit comments