Skip to content

Commit e687fb7

Browse files
authored
Refactor / unify likelihood models (#2742)
* make base likelihood model * move likelihoods to dedicated module * refactor regression model likelihoods * fix failing tests * remove remaining old likelihood references * rename likelihood.likelihood_models to likelihood_models.torch * update changelog * fix imports * address some missed lines * make torch likelihoods backwards compatible * add sklearn likelihood tests * remane likelihood.py to base.py * apply suggestions from PR review * update gaussian sklearn likelihood * rename torch base likelihood * update docs
1 parent 4277679 commit e687fb7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+1174
-956
lines changed

CHANGELOG.md

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,28 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1111

1212
**Improved**
1313

14-
**Removed**
14+
**Removed / moved**
15+
16+
- 🔴 Removed model `AutoARIMA`. To support `numpy>=2.0.0`, we unfortunately had to remove the `pmdarima` dependency. Use `StatsForecastAutoARIMA` instead. [#2734](https://github.com/unit8co/darts/pull/2734) by [Dennis Bader](https://github.com/dennisbader).
1517
- 🔴 Removed deprecated method `TimeSeries.pd_dataframe()`. Use `TimeSeries.to_dataframe()` instead. [#2733](https://github.com/unit8co/darts/pull/2733) by [Dennis Bader](https://github.com/dennisbader).
1618
- 🔴 Removed deprecated method `TimeSeries.pd_serise()`. Use `TimeSeries.to_series()` instead. [#2733](https://github.com/unit8co/darts/pull/2733) by [Dennis Bader](https://github.com/dennisbader).
17-
- 🔴 Removed model `AutoARIMA`. To support `numpy>=2.0.0`, we unfortunately had to remove the `pmdarima` dependency. Use `StatsForecastAutoARIMA` instead. [#2734](https://github.com/unit8co/darts/pull/2734) by [Dennis Bader](https://github.com/dennisbader).
1819

1920
**Fixed**
2021

22+
- Fixed a bug in `CatBoostModel` with `likelihood="gaussian"`, where predicting with `predict_likelihood_parameters=True` resulted in wrong ordering of the predicted parameters. [#2742](https://github.com/unit8co/darts/pull/2742) by [Dennis Bader](https://github.com/dennisbader).
23+
2124
**Dependencies**
2225

2326
### For developers of the library:
2427

28+
- Refactored likelihoods: [#2742](https://github.com/unit8co/darts/pull/2742) by [Dennis Bader](https://github.com/dennisbader).
29+
- Moved all likelihood related files into `darts/utils/likelihood_models/`.
30+
- Added `BaseLikelihood` as a base class for all likelihood models to `darts.utils.likelihood_models.base`.
31+
- Added `RegressionModel` likelihoods to `darts.utils.likelihood_models.sklearn`.
32+
- Moved `TorchForecastingModel` likelihoods to `darts.utils.likelihood_models.torch`. They can still be imported through `darts.utils.likelihood_models`.
33+
- Removed `darts.models.forecasting.regression_model._LikelihoodMixin`. Use dedicated Likelihood models instead.
34+
35+
2536
## [0.34.0](https://github.com/unit8co/darts/tree/0.34.0) (2025-03-09)
2637

2738
### For users of the library:

darts/metrics/metrics.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
from darts import TimeSeries
1818
from darts.dataprocessing import dtw
1919
from darts.logging import get_logger, raise_log
20+
from darts.utils.likelihood_models.base import (
21+
likelihood_component_names,
22+
quantile_names,
23+
)
2024
from darts.utils.ts_utils import SeriesType, get_series_seq_type, series2seq
2125
from darts.utils.utils import (
2226
_build_tqdm_iterator,
2327
_parallel_apply,
24-
likelihood_component_names,
2528
n_steps_between,
26-
quantile_names,
2729
)
2830

2931
logger = get_logger(__name__)

darts/models/forecasting/baselines.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -414,12 +414,8 @@ def _target_average(self, prediction: TimeSeries, series: TimeSeries) -> TimeSer
414414

415415
def _params_average(self, prediction: TimeSeries, series: TimeSeries) -> TimeSeries:
416416
"""Average across the components after grouping by likelihood parameter, rename components"""
417-
# str or torch Likelihood
418-
likelihood = getattr(self.forecasting_models[0], "likelihood")
419-
if isinstance(likelihood, str):
420-
likelihood_n_params = self.forecasting_models[0].num_parameters
421-
else: # Likelihood
422-
likelihood_n_params = likelihood.num_parameters
417+
likelihood = self.forecasting_models[0].likelihood
418+
likelihood_n_params = likelihood.num_parameters
423419
n_forecasting_models = len(self.forecasting_models)
424420
n_components = series.n_components
425421
# aggregate across predictions [model1_param0, model1_param1, ..., modeln_param0, modeln_param1]

darts/models/forecasting/block_rnn_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def __init__(
270270
This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
271271
Default: ``torch.nn.MSELoss()``.
272272
likelihood
273-
One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
273+
One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.torch.TorchLikelihood>` models to be used for
274274
probabilistic forecasts. Default: ``None``.
275275
torch_metrics
276276
A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found

darts/models/forecasting/catboost_model.py

Lines changed: 33 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,24 @@
1010
from collections.abc import Sequence
1111
from typing import Optional, Union
1212

13-
import numpy as np
1413
from catboost import CatBoostRegressor, Pool
1514

1615
from darts.logging import get_logger
17-
from darts.models.forecasting.regression_model import RegressionModel, _LikelihoodMixin
16+
from darts.models.forecasting.regression_model import (
17+
RegressionModel,
18+
_QuantileModelContainer,
19+
)
1820
from darts.timeseries import TimeSeries
21+
from darts.utils.likelihood_models.sklearn import (
22+
QuantileRegression,
23+
_check_likelihood,
24+
_get_likelihood,
25+
)
1926

2027
logger = get_logger(__name__)
2128

2229

23-
class CatBoostModel(RegressionModel, _LikelihoodMixin):
30+
class CatBoostModel(RegressionModel):
2431
def __init__(
2532
self,
2633
lags: Union[int, list] = None,
@@ -29,7 +36,7 @@ def __init__(
2936
output_chunk_length: int = 1,
3037
output_chunk_shift: int = 0,
3138
add_encoders: Optional[dict] = None,
32-
likelihood: str = None,
39+
likelihood: Optional[str] = None,
3340
quantiles: list = None,
3441
random_state: Optional[int] = None,
3542
multi_models: Optional[bool] = True,
@@ -167,34 +174,33 @@ def encode_year(idx):
167174
"""
168175
kwargs["random_state"] = random_state # seed for tree learner
169176
self.kwargs = kwargs
170-
self._median_idx = None
171177
self._model_container = None
172-
self._rng = None
173-
self._likelihood = likelihood
174-
self.quantiles = None
175-
176-
self._output_chunk_length = output_chunk_length
177-
178-
likelihood_map = {
179-
"quantile": None,
180-
"poisson": "Poisson",
181-
"gaussian": "RMSEWithUncertainty",
182-
"RMSEWithUncertainty": "RMSEWithUncertainty",
183-
}
184-
185-
available_likelihoods = list(likelihood_map.keys())
186178

179+
# parse likelihood
187180
if likelihood is not None:
188-
self._check_likelihood(likelihood, available_likelihoods)
189-
self._rng = np.random.default_rng(seed=random_state) # seed for sampling
181+
likelihood_map = {
182+
"quantile": None,
183+
"poisson": "Poisson",
184+
"gaussian": "RMSEWithUncertainty",
185+
"RMSEWithUncertainty": "RMSEWithUncertainty",
186+
}
187+
_check_likelihood(likelihood, list(likelihood_map.keys()))
188+
if likelihood == "RMSEWithUncertainty":
189+
# RMSEWithUncertainty returns mean and variance which is equivalent to gaussian
190+
likelihood = "gaussian"
190191

191192
if likelihood == "quantile":
192-
self.quantiles, self._median_idx = self._prepare_quantiles(quantiles)
193-
self._model_container = self._get_model_container()
194-
193+
self._model_container = _QuantileModelContainer()
195194
else:
196195
self.kwargs["loss_function"] = likelihood_map[likelihood]
197196

197+
self._likelihood = _get_likelihood(
198+
likelihood=likelihood,
199+
n_outputs=output_chunk_length if multi_models else 1,
200+
random_state=random_state,
201+
quantiles=quantiles,
202+
)
203+
198204
# suppress writing catboost info files when user does not specifically ask to
199205
if "allow_writing_files" not in kwargs:
200206
kwargs["allow_writing_files"] = False
@@ -275,10 +281,11 @@ def fit(
275281
**kwargs
276282
Additional kwargs passed to `catboost.CatboostRegressor.fit()`
277283
"""
278-
if self.likelihood == "quantile":
284+
likelihood = self.likelihood
285+
if isinstance(likelihood, QuantileRegression):
279286
# empty model container in case of multiple calls to fit, e.g. when backtesting
280287
self._model_container.clear()
281-
for quantile in self.quantiles:
288+
for quantile in likelihood.quantiles:
282289
this_quantile = str(quantile)
283290
# translating to catboost argument
284291
self.kwargs["loss_function"] = f"Quantile:alpha={this_quantile}"
@@ -316,42 +323,6 @@ def fit(
316323
)
317324
return self
318325

319-
def _predict_and_sample(
320-
self,
321-
x: np.ndarray,
322-
num_samples: int,
323-
predict_likelihood_parameters: bool,
324-
**kwargs,
325-
) -> np.ndarray:
326-
"""Override of RegressionModel's method to allow for the probabilistic case"""
327-
if self.likelihood in ["gaussian", "RMSEWithUncertainty"]:
328-
return self._predict_and_sample_likelihood(
329-
x, num_samples, "normal", predict_likelihood_parameters, **kwargs
330-
)
331-
elif self.likelihood is not None:
332-
return self._predict_and_sample_likelihood(
333-
x, num_samples, self.likelihood, predict_likelihood_parameters, **kwargs
334-
)
335-
else:
336-
return super()._predict_and_sample(
337-
x, num_samples, predict_likelihood_parameters, **kwargs
338-
)
339-
340-
def _likelihood_components_names(
341-
self, input_series: TimeSeries
342-
) -> Optional[list[str]]:
343-
"""Override of RegressionModel's method to support the gaussian/normal likelihood"""
344-
if self.likelihood == "quantile":
345-
return self._quantiles_generate_components_names(input_series)
346-
elif self.likelihood == "poisson":
347-
return self._likelihood_generate_components_names(input_series, ["lamba"])
348-
elif self.likelihood in ["gaussian", "RMSEWithUncertainty"]:
349-
return self._likelihood_generate_components_names(
350-
input_series, ["mu", "sigma"]
351-
)
352-
else:
353-
return None
354-
355326
def _add_val_set_to_kwargs(
356327
self,
357328
kwargs: dict,
@@ -386,10 +357,6 @@ def _add_val_set_to_kwargs(
386357
kwargs[val_set_name] = val_pools
387358
return kwargs
388359

389-
@property
390-
def supports_probabilistic_prediction(self) -> bool:
391-
return self.likelihood is not None
392-
393360
@property
394361
def supports_val_set(self) -> bool:
395362
return True

darts/models/forecasting/conformal_models.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
from collections.abc import Sequence
1414
from typing import Any, BinaryIO, Callable, Optional, Union
1515

16+
from darts.utils.likelihood_models.base import (
17+
Likelihood,
18+
LikelihoodType,
19+
quantile_names,
20+
)
21+
1622
try:
1723
from typing import Literal
1824
except ImportError:
@@ -46,9 +52,7 @@
4652
from darts.utils.utils import (
4753
_check_quantiles,
4854
generate_index,
49-
likelihood_component_names,
5055
n_steps_between,
51-
quantile_names,
5256
random_method,
5357
sample_from_quantiles,
5458
)
@@ -180,7 +184,10 @@ def __init__(
180184
self.cal_num_samples = (
181185
cal_num_samples if model.supports_probabilistic_prediction else 1
182186
)
183-
self._likelihood = "quantile"
187+
self._likelihood = Likelihood(
188+
likelihood_type=LikelihoodType.Quantile,
189+
parameter_names=quantile_names(quantiles),
190+
)
184191
self._fit_called = True
185192

186193
def fit(
@@ -1251,7 +1258,7 @@ def conformal_predict(idx_, pred_vals_):
12511258
inner_iterator = enumerate(s_hfcs[first_fc_idx:last_fc_idx:rel_stride])
12521259

12531260
comp_names_out = (
1254-
self._cp_component_names(series_)
1261+
self.likelihood.component_names(series_)
12551262
if predict_likelihood_parameters
12561263
else None
12571264
)
@@ -1426,12 +1433,6 @@ def _residuals_metric(self) -> tuple[METRIC_TYPE, Optional[dict]]:
14261433
"""Gives the "per time step" metric and optional metric kwargs used to compute residuals /
14271434
non-conformity scores."""
14281435

1429-
def _cp_component_names(self, input_series) -> list[str]:
1430-
"""Gives the component names for generated forecasts."""
1431-
return likelihood_component_names(
1432-
input_series.components, quantile_names(self.quantiles)
1433-
)
1434-
14351436
def _historical_forecasts_sanity_checks(self, *args: Any, **kwargs: Any) -> None:
14361437
super()._historical_forecasts_sanity_checks(*args, **kwargs, is_conformal=True)
14371438

@@ -1516,7 +1517,7 @@ def considers_static_covariates(self) -> bool:
15161517
return self.model.considers_static_covariates
15171518

15181519
@property
1519-
def likelihood(self) -> str:
1520+
def likelihood(self) -> Likelihood:
15201521
return self._likelihood
15211522

15221523

darts/models/forecasting/dlinear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def __init__(
294294
This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
295295
Default: ``torch.nn.MSELoss()``.
296296
likelihood
297-
One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
297+
One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.torch.TorchLikelihood>` models to be used for
298298
probabilistic forecasts. Default: ``None``.
299299
torch_metrics
300300
A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found

darts/models/forecasting/ensemble_model.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from collections.abc import Sequence
99
from typing import BinaryIO, Optional, Union
1010

11+
from darts.utils.likelihood_models.base import LikelihoodType
12+
1113
if sys.version_info >= (3, 11):
1214
from typing import Self
1315
else:
@@ -563,19 +565,13 @@ def _models_same_likelihood(self) -> bool:
563565
lkl_same_params = True
564566
tmp_quantiles = None
565567
for m in self.forecasting_models:
566-
# regression model likelihood is a string, torch-based model likelihoods is an object
567-
likelihood = getattr(m, "likelihood")
568-
is_obj_lkl = not isinstance(likelihood, str)
569-
lkl_simplified_name = (
570-
likelihood.simplified_name() if is_obj_lkl else likelihood
571-
)
572-
models_likelihood.add(lkl_simplified_name)
568+
likelihood = m.likelihood
569+
lkl_type = likelihood.type
570+
models_likelihood.add(lkl_type)
573571

574572
# check the quantiles
575-
if lkl_simplified_name == "quantile":
576-
quantiles: list[str] = (
577-
likelihood.quantiles if is_obj_lkl else m.quantiles
578-
)
573+
if lkl_type is LikelihoodType.Quantile:
574+
quantiles: list[str] = likelihood.quantiles
579575
if tmp_quantiles is None:
580576
tmp_quantiles = quantiles
581577
elif tmp_quantiles != quantiles:

darts/models/forecasting/forecasting_model.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@
2727
from random import sample
2828
from typing import Any, BinaryIO, Callable, Literal, Optional, Union
2929

30+
from darts.utils.likelihood_models.base import (
31+
Likelihood,
32+
likelihood_component_names,
33+
quantile_interval_names,
34+
quantile_names,
35+
)
36+
3037
if sys.version_info >= (3, 11):
3138
from typing import Self
3239
else:
@@ -68,9 +75,6 @@
6875
)
6976
from darts.utils.utils import (
7077
generate_index,
71-
likelihood_component_names,
72-
quantile_interval_names,
73-
quantile_names,
7478
)
7579

7680
logger = get_logger(__name__)
@@ -206,6 +210,11 @@ def _supports_range_index(self) -> bool:
206210
"""
207211
return True
208212

213+
@property
214+
def likelihood(self) -> Optional[Likelihood]:
215+
"""Returns the likelihood (if any) that the model uses for probabilistic forecasts."""
216+
return None
217+
209218
@property
210219
def supports_probabilistic_prediction(self) -> bool:
211220
"""
@@ -267,7 +276,7 @@ def supports_likelihood_parameter_prediction(self) -> bool:
267276
"""
268277
Whether model instance supports direct prediction of likelihood parameters
269278
"""
270-
return getattr(self, "likelihood", None) is not None
279+
return self.likelihood is not None
271280

272281
@property
273282
@abstractmethod

0 commit comments

Comments
 (0)