Skip to content

Commit cf5fa83

Browse files
jonasblancjonasblancdennisbader
authored
Classification forecasting for regression models (#2765)
* Add categorical cov support to XGBoost, CatBoost * Add type check for cat features, refactor cat indices logic * Split cat. comp. validation logic and test it * Support categorical cov created via an encoder * Validate categorical features * Support categorical features for HistGradientBoostingRegressor * Fix typos * Apply suggestions, limit cat cov support to LightGBM and CatBoost * Update changelog and doc * Fix typo in TS doc * Rebase cat forecasting PR, on cat covariates PR * Speed up tests by limiting lgbm and catboost depth and iterations * Extend test categorical target * Add categorical cov support to XGBoost, CatBoost * Fix typo in TS doc * Rebase cat forecasting PR, on cat covariates PR * Speed up tests by limiting lgbm and catboost depth and iterations * Extend test categorical target * Add classification accuracy metric * Fix master rebase * Fix typo rebase * Keep categorical metrics for separate PR * Add categorical forecasting models to module __init__ * Refactor MutliOutput to support MultiOutputClassifier, wip * Move _forecasting_type into CategoricalForecastingMixin * Further refactoring of multioutput wrapper * Implement ClassProbabilityLikelihood to forecast categorical probabilities * Add support for ClassProbabilityLikelihood in XGB, CatBoost and LGB * Reorder functions * Create categorical forecasting likelihood specific tests * Add docstring to CatBoostCategoricalModel * Add LightGBMCategoricalModel to model module init * Allow likelihoodType in _check_likelihood * Update doc * Update ClassProbability name from class_probability to classprobability * Remove default model, update doc * Rename categorical forecasting to classification forecasting * Set ClassProbabilityLikelihood as default for all classifiers models * Update darts/models/forecasting/regression_model.py Co-authored-by: Dennis Bader <[email protected]> * Update darts/utils/multioutput.py Co-authored-by: Dennis Bader <[email protected]> * Addresses review suggestions * Move ClassProbabilityLikelihood to sklearn likelihood * Extends classes_ test to multi-output * Expose likelihood in classifiers constructor * Bump test env to macos-14 * Fix test * Add multioutput validation test * Fix categorical validation features * Extend multi-ouput tests * fix merge conflicts * Extend likelihood tests * Merge _check_likelihood into _get_likelihood * Address suggestions * Rename .classes_ to .class_labels, fix tests * Check estimators for same component have same labels * Improve ClassProbabilityLikelihood robustness to input format * Add input format to tests * Extend test case for multioutput wrapper * Improve ClassProbabilitiy robustness to TS formats * Move and refactor classes in CLassifierMixin * Refactor internal class proba representation * Extend test to component names and warnings * Fix test randomness * Extend class probabilites checks to multivariate/mulitseries * Unify CatBoostClassifier prediction shape * Improve robustness of multi sample prediction * Refactor probabilistic tests * Return self on fit * Test ClassProbability for reproducible output * Update changelog * Test edge case multioutput/likelihood * Address small suggestions from review * Optimize likelihood sampling * update class probability likelihood component names * Fix sample, add tests * Apply minor suggestions * Fix lint * Fix merge * remove random state params * udpate tests * minor fixes * add example notebook * add first backtesting tests * remove examples * last updates --------- Co-authored-by: jonasblanc <[email protected]> Co-authored-by: Dennis Bader <[email protected]>
1 parent 8fe6985 commit cf5fa83

File tree

18 files changed

+3214
-272
lines changed

18 files changed

+3214
-272
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1111

1212
**Improved**
1313

14+
- Added support for classification forecasting with SKLearn-like forecasting models. [#2765](https://github.com/unit8co/darts/pull/2765) by [Jonas Blanc](https://github.com/jonasblanc) and [Dennis Bader](https://github.com/dennisbader).
15+
- Added `SklearnClassifierModel` which can take any sklearn-like classifier model.
16+
- Added `LightGBMClassifierModel`, `XGBClassifierModel` and `CatBoostClassifierModel` which use the classifier models of the respective libraries.
17+
- Added `ClassProbabilityLikelihood` and set it as the default likelihood for classifiers to predict class probabilities with `predict_likelihood_parameters=True` when calling `predict()`.
1418
- Added classification metrics `accuracy()`, `f1()`, `precision()`, and `recall()`, `confusion_matrix()` to the `metrics` module. Use these metrics to evaluate the performance of classification models. [#2767](https://github.com/unit8co/darts/pull/2767) by [Jonas Blanc](https://github.com/jonasblanc) and [Dennis Bader](https://github.com/dennisbader).
1519

1620
**Fixed**
@@ -21,6 +25,9 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
2125

2226
### For developers of the library:
2327

28+
- Renamed `RegressionModelWithCategoricalCovariates` to `RegressionModelWithCategoricalFeatures` which now also supports categorical target features. [#2765](https://github.com/unit8co/darts/pull/2765) by [Jonas Blanc](https://github.com/jonasblanc)
29+
- Added `MultiOutputClassifier` for handling multi-output classification tasks. [#2765](https://github.com/unit8co/darts/pull/2765) by [Jonas Blanc](https://github.com/jonasblanc)
30+
2431
## [0.36.0](https://github.com/unit8co/darts/tree/0.36.0) (2025-06-29)
2532

2633
### For users of the library:

darts/metrics/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,14 @@
154154
incs_qr,
155155
}
156156

157+
CLASSIFICATION_METRICS = {
158+
accuracy,
159+
precision,
160+
recall,
161+
f1,
162+
confusion_matrix,
163+
}
164+
157165
__all__ = [
158166
"ae",
159167
"ape",

darts/metrics/utils.py

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@
3232
_FN_IDX = (slice(None), slice(None), 1, 0)
3333
_TP_IDX = (slice(None), slice(None), 1, 1)
3434

35+
# parameter names
36+
_PARAM_Q = "q"
37+
_PARAM_Q_INTERVAL = "q_interval"
38+
_PARAM_LABELS = "labels"
39+
_PARAM_LABEL_REDUCTION = "label_reduction"
40+
_PARAM_SERIES_REDUCTION = "series_reduction"
41+
_PARAM_TIME_REDUCTION = "time_reduction"
42+
_PARAM_COMPONENT_REDUCTION = "component_reduction"
3543

3644
# class probabilities suffix
3745
PROBA_SUFFIX = "_p"
@@ -68,38 +76,40 @@ def interval_support(func) -> Callable[..., METRIC_OUTPUT_TYPE]:
6876

6977
@wraps(func)
7078
def wrapper_interval_support(*args, **kwargs):
71-
q = kwargs.get("q")
79+
q = kwargs.get(_PARAM_Q)
7280
if q is not None:
7381
raise_log(
7482
ValueError(
75-
"`q` is not supported for quantile interval metrics; use `q_interval` instead."
83+
f"`{_PARAM_Q}` is not supported for quantile interval metrics; use `{_PARAM_Q_INTERVAL}` instead."
7684
)
7785
)
78-
q_interval = kwargs.get("q_interval")
86+
q_interval = kwargs.get(_PARAM_Q_INTERVAL)
7987
if q_interval is None:
8088
raise_log(
81-
ValueError("Quantile interval metrics require setting `q_interval`.")
89+
ValueError(
90+
f"Quantile interval metrics require setting `{_PARAM_Q_INTERVAL}`."
91+
)
8292
)
8393
if isinstance(q_interval, tuple):
8494
q_interval = [q_interval]
8595
q_interval = np.array(q_interval)
8696
if not q_interval.ndim == 2 or q_interval.shape[1] != 2:
8797
raise_log(
8898
ValueError(
89-
"`q_interval` must be a tuple (float, float) or a sequence of tuples (float, float)."
99+
f"`{_PARAM_Q_INTERVAL}` must be a tuple (float, float) or a sequence of tuples (float, float)."
90100
),
91101
logger=logger,
92102
)
93103
if not np.all(q_interval[:, 1] - q_interval[:, 0] > 0):
94104
raise_log(
95105
ValueError(
96-
"all intervals in `q_interval` must be tuples of (lower q, upper q) with `lower q > upper q`. "
97-
f"Received `q_interval={q_interval}`"
106+
f"all intervals in `{_PARAM_Q_INTERVAL}` must be tuples of (lower q, upper q) with "
107+
f"`lower q > upper q`. Received `{_PARAM_Q_INTERVAL}={q_interval}`"
98108
),
99109
logger=logger,
100110
)
101-
kwargs["q_interval"] = q_interval
102-
kwargs["q"] = np.sort(np.unique(q_interval))
111+
kwargs[_PARAM_Q_INTERVAL] = q_interval
112+
kwargs[_PARAM_Q] = np.sort(np.unique(q_interval))
103113
return func(*args, **kwargs)
104114

105115
return wrapper_interval_support
@@ -113,29 +123,29 @@ def classification_support(func) -> Callable[..., METRIC_OUTPUT_TYPE]:
113123

114124
@wraps(func)
115125
def wrapper_classification_support(*args, **kwargs):
116-
labels = kwargs.get("labels")
126+
labels = kwargs.get(_PARAM_LABELS)
117127
if labels is not None:
118128
if isinstance(labels, int):
119129
labels = np.array([labels])
120130
else:
121131
labels = np.array(labels)
122-
kwargs["labels"] = labels
132+
kwargs[_PARAM_LABELS] = labels
123133

124134
params = signature(func).parameters
125-
if "label_reduction" in params:
135+
if _PARAM_LABEL_REDUCTION in params:
126136
label_reduction = kwargs.get(
127-
"label_reduction", params["label_reduction"].default
137+
_PARAM_LABEL_REDUCTION, params[_PARAM_LABEL_REDUCTION].default
128138
)
129139
if not isinstance(label_reduction, _LabelReduction):
130140
if not _LabelReduction.has_value(label_reduction):
131141
raise_log(
132142
ValueError(
133-
f"Invalid `label_reduction` value: `{label_reduction}`. "
143+
f"Invalid `{_PARAM_LABEL_REDUCTION}` value: `{label_reduction}`. "
134144
f"Must be one of `{list(_LabelReduction._value2member_map_)}`."
135145
),
136146
logger=logger,
137147
)
138-
kwargs["label_reduction"] = _LabelReduction(label_reduction)
148+
kwargs[_PARAM_LABEL_REDUCTION] = _LabelReduction(label_reduction)
139149

140150
kwargs["is_classification"] = True
141151
return func(*args, **kwargs)
@@ -176,21 +186,21 @@ def wrapper_multi_ts_support(*args, **kwargs):
176186
_ = _get_reduction(
177187
kwargs=kwargs,
178188
params=params,
179-
red_name="time_reduction",
189+
red_name=_PARAM_TIME_REDUCTION,
180190
axis=TIME_AX,
181191
sanity_check=True,
182192
)
183193
_ = _get_reduction(
184194
kwargs=kwargs,
185195
params=params,
186-
red_name="component_reduction",
196+
red_name=_PARAM_COMPONENT_REDUCTION,
187197
axis=COMP_AX,
188198
sanity_check=True,
189199
)
190200
series_reduction = _get_reduction(
191201
kwargs=kwargs,
192202
params=params,
193-
red_name="series_reduction",
203+
red_name=_PARAM_SERIES_REDUCTION,
194204
axis=0,
195205
sanity_check=True,
196206
)
@@ -237,12 +247,12 @@ def wrapper_multi_ts_support(*args, **kwargs):
237247
kwargs.pop("insample", 0)
238248

239249
# handle `q` (quantile) parameter for probabilistic (or quantile) forecasts
240-
if "q" in params:
250+
if _PARAM_Q in params:
241251
# convert `q` to tuple of (quantile values, optional quantile component names)
242-
q = kwargs.get("q", params["q"].default)
252+
q = kwargs.get(_PARAM_Q, params[_PARAM_Q].default)
243253
q_comp_names = None
244254
if q is None:
245-
kwargs["q"] = None
255+
kwargs[_PARAM_Q] = None
246256
else:
247257
if isinstance(q, tuple):
248258
q, q_comp_names = q
@@ -254,19 +264,19 @@ def wrapper_multi_ts_support(*args, **kwargs):
254264
if not np.all(q[1:] - q[:-1] > 0.0):
255265
raise_log(
256266
ValueError(
257-
"`q` must be of type `float`, or a sequence of increasing order with unique values only. "
258-
f"Received `q={q}`."
267+
f"`{_PARAM_Q}` must be of type `float`, or a sequence of increasing order with unique "
268+
f"values only. Received `{_PARAM_Q}={q}`."
259269
),
260270
logger=logger,
261271
)
262272
if not np.all(q >= 0.0) & np.all(q <= 1.0):
263273
raise_log(
264274
ValueError(
265-
f"All `q` values must be in the range `(>=0,<=1)`. Received `q={q}`."
275+
f"All `{_PARAM_Q}` values must be in the range `(>=0,<=1)`. Received `{_PARAM_Q}={q}`."
266276
),
267277
logger=logger,
268278
)
269-
kwargs["q"] = (q, q_comp_names)
279+
kwargs[_PARAM_Q] = (q, q_comp_names)
270280

271281
iterator = _build_tqdm_iterator(
272282
iterable=zip(*input_series),
@@ -299,7 +309,7 @@ def wrapper_multi_ts_support(*args, **kwargs):
299309

300310
# reduce metrics along series axis
301311
if series_reduction is not None:
302-
vals = kwargs["series_reduction"](vals, axis=0)
312+
vals = kwargs[_PARAM_SERIES_REDUCTION](vals, axis=0)
303313
elif series_seq_type == SeriesType.SINGLE:
304314
vals = vals[0]
305315

@@ -374,7 +384,7 @@ def wrapper_multivariate_support(*args, **kwargs) -> METRIC_OUTPUT_TYPE:
374384
time_reduction = _get_reduction(
375385
kwargs=kwargs,
376386
params=params,
377-
red_name="time_reduction",
387+
red_name=_PARAM_TIME_REDUCTION,
378388
axis=TIME_AX,
379389
sanity_check=False,
380390
)
@@ -385,7 +395,7 @@ def wrapper_multivariate_support(*args, **kwargs) -> METRIC_OUTPUT_TYPE:
385395
component_reduction = _get_reduction(
386396
kwargs=kwargs,
387397
params=params,
388-
red_name="component_reduction",
398+
red_name=_PARAM_COMPONENT_REDUCTION,
389399
axis=COMP_AX,
390400
sanity_check=False,
391401
)
@@ -407,7 +417,7 @@ def wrapper_multivariate_support(*args, **kwargs) -> METRIC_OUTPUT_TYPE:
407417

408418
def _regression_handling(actual_series, pred_series, params, kwargs):
409419
"""Handles the regression metrics input parameters and checks."""
410-
q, q_comp_names = kwargs.get("q"), None
420+
q, q_comp_names = kwargs.get(_PARAM_Q), None
411421
if q is None:
412422
# without quantiles, the number of components must match
413423
if actual_series.n_components != pred_series.n_components:
@@ -426,9 +436,9 @@ def _regression_handling(actual_series, pred_series, params, kwargs):
426436
if not isinstance(q, tuple) or not len(q) == 2:
427437
raise_log(
428438
ValueError(
429-
"`q` must be of tuple of `(np.ndarray, Optional[pd.Index])` "
439+
f"`{_PARAM_Q}` must be of tuple of `(np.ndarray, Optional[pd.Index])` "
430440
"where the (quantile values, optional quantile component names). "
431-
f"Received `q={q}`."
441+
f"Received `{_PARAM_Q}={q}`."
432442
),
433443
logger=logger,
434444
)
@@ -446,16 +456,16 @@ def _regression_handling(actual_series, pred_series, params, kwargs):
446456
if not q_comp_names.isin(pred_series.components).all():
447457
raise_log(
448458
ValueError(
449-
f"Computing a metric with quantile(s) `q={q}` is only supported for probabilistic "
459+
f"Computing a metric with quantile(s) `{_PARAM_Q}={q}` is only supported for probabilistic "
450460
f"`pred_series` (num samples > 1) or `pred_series` containing the predicted "
451461
f"quantiles as columns / components. Either pass a probabilistic `pred_series` or "
452462
f"a series containing the expected quantile components: {q_comp_names.tolist()} "
453463
),
454464
logger=logger,
455465
)
456466

457-
if "q" in params:
458-
kwargs["q"] = (q, q_comp_names)
467+
if _PARAM_Q in params:
468+
kwargs[_PARAM_Q] = (q, q_comp_names)
459469
return kwargs
460470

461471

darts/models/__init__.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111

1212
try:
1313
# `lightgbm` needs to be imported first to avoid segmentation fault
14-
from darts.models.forecasting.lgbm import LightGBMModel
14+
from darts.models.forecasting.lgbm import LightGBMClassifierModel, LightGBMModel
1515
except ModuleNotFoundError:
1616
LightGBMModel = NotImportedModule(module_name="LightGBM", warn=False)
17+
LightGBMClassifierModel = NotImportedModule(module_name="LightGBM", warn=False)
1718

1819
# Forecasting
1920
from darts.models.forecasting.arima import ARIMA
@@ -35,7 +36,11 @@
3536
from darts.models.forecasting.linear_regression_model import LinearRegressionModel
3637
from darts.models.forecasting.random_forest import RandomForest, RandomForestModel
3738
from darts.models.forecasting.regression_ensemble_model import RegressionEnsembleModel
38-
from darts.models.forecasting.sklearn_model import RegressionModel, SKLearnModel
39+
from darts.models.forecasting.sklearn_model import (
40+
RegressionModel,
41+
SKLearnClassifierModel,
42+
SKLearnModel,
43+
)
3944
from darts.models.forecasting.theta import FourTheta, Theta
4045
from darts.models.forecasting.varima import VARIMA
4146

@@ -83,9 +88,14 @@
8388
Prophet = NotImportedModule(module_name="Prophet", warn=False)
8489

8590
try:
86-
from darts.models.forecasting.catboost_model import CatBoostModel
91+
from darts.models.forecasting.catboost_model import (
92+
CatBoostClassifierModel,
93+
CatBoostModel,
94+
)
8795
except ModuleNotFoundError:
8896
CatBoostModel = NotImportedModule(module_name="CatBoost", warn=False)
97+
CatBoostClassifierModel = NotImportedModule(module_name="CatBoost", warn=False)
98+
8999

90100
try:
91101
from darts.models.forecasting.sf_auto_arima import AutoARIMA
@@ -116,9 +126,10 @@
116126
AutoTBATS = NotImportedModule(module_name="StatsForecast", warn=False)
117127

118128
try:
119-
from darts.models.forecasting.xgboost import XGBModel
129+
from darts.models.forecasting.xgboost import XGBClassifierModel, XGBModel
120130
except ImportError:
121131
XGBModel = NotImportedModule(module_name="XGBoost")
132+
XGBClassifierModel = NotImportedModule(module_name="XGBoost")
122133

123134
# Filtering
124135
from darts.models.filtering.gaussian_process_filter import GaussianProcessFilter
@@ -127,6 +138,7 @@
127138

128139
__all__ = [
129140
"LightGBMModel",
141+
"LightGBMClassifierModel",
130142
"ARIMA",
131143
"NaiveDrift",
132144
"NaiveMean",
@@ -140,6 +152,7 @@
140152
"RandomForestModel",
141153
"RegressionEnsembleModel",
142154
"SKLearnModel",
155+
"SKLearnClassifierModel",
143156
"RegressionModel",
144157
"TBATS",
145158
"FourTheta",
@@ -161,6 +174,7 @@
161174
"TSMixerModel",
162175
"Prophet",
163176
"CatBoostModel",
177+
"CatBoostClassifierModel",
164178
"Croston",
165179
"AutoARIMA",
166180
"AutoCES",
@@ -170,6 +184,7 @@
170184
"AutoTBATS",
171185
"StatsForecastModel",
172186
"XGBModel",
187+
"XGBClassifierModel",
173188
"GaussianProcessFilter",
174189
"KalmanFilter",
175190
"MovingAverageFilter",

0 commit comments

Comments
 (0)