Skip to content

Commit 6a1f80e

Browse files
KylinSchmidtmadtoinoudennisbader
authored
Fix/val_sample_weight error for models inherited from RegressionModel (#2626)
* Fix/val_sample_weight error for models inherited from RegressionModel * modify change log * modify the statement * update changelog * add unit tests * Update CHANGELOG.md --------- Co-authored-by: madtoinou <[email protected]> Co-authored-by: dennisbader <[email protected]>
1 parent bf2373a commit 6a1f80e

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
2020
- Improvements to `ForecastingModel`:
2121
- Added parameter `clean: bool` to `ForecastingModel.save()` to store a cleaned version of the model (removes training data from global models, and Lightning Trainer-related parameters from torch models). [#2649](https://github.com/unit8co/darts/pull/2649) by [Jonas Blanc](https://github.com/jonasblanc).
2222
- Added parameter `pl_trainer_kwargs` to `TorchForecastingModel.load()` to setup a new Lightning Trainer used to configure the model for downstream tasks (e.g. prediction). [#2649](https://github.com/unit8co/darts/pull/2649) by [Jonas Blanc](https://github.com/jonasblanc).
23+
- Fixed a bug in `LightGBMModel`, `XGBModel`, and `CatBoostModel` which raised an error when calling `fit()` with `val_sample_weight`. [#2626](https://github.com/unit8co/darts/pull/2626) by [Kylin Schmidt](https://github.com/kylinschmidt).
2324
- Improved the documentation of how `WindowedAnomalyScorer` extract the training data from the input series. [#2674](https://github.com/unit8co/darts/pull/2674) by [Dennis Bader](https://github.com/dennisbader).
2425

2526
**Fixed**

darts/models/forecasting/regression_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def _add_val_set_to_kwargs(
591591
val_weights = val_weights or None
592592
else:
593593
val_sets = [(val_samples, val_labels)]
594-
val_weights = val_weight
594+
val_weights = [val_weight]
595595

596596
val_set_name, val_weight_name = self.val_set_params
597597
return dict(kwargs, **{val_set_name: val_sets, val_weight_name: val_weights})

darts/tests/models/forecasting/test_regression_models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,6 +1867,34 @@ def test_not_enough_covariates(self, config):
18671867
future_covariates=future_covariates[: -26 + req_future_offset],
18681868
)
18691869

1870+
@pytest.mark.parametrize(
1871+
"config",
1872+
product(
1873+
[(XGBModel, xgb_test_params)]
1874+
+ ([(LightGBMModel, lgbm_test_params)] if lgbm_available else [])
1875+
+ ([(CatBoostModel, cb_test_params)] if cb_available else []),
1876+
[True, False],
1877+
),
1878+
)
1879+
def test_val_set_weights_runnability_trees(self, config):
1880+
"""Tests using weights in val set for single and multi series."""
1881+
(model_cls, model_kwargs), single_series = config
1882+
model = model_cls(lags=10, **model_kwargs)
1883+
1884+
series = tg.sine_timeseries(length=20)
1885+
weights = tg.linear_timeseries(length=20)
1886+
if not single_series:
1887+
series = [series] * 2
1888+
weights = [weights] * 2
1889+
1890+
model.fit(
1891+
series=series,
1892+
val_series=series,
1893+
sample_weight=weights,
1894+
val_sample_weight=weights,
1895+
)
1896+
_ = model.predict(1, series=series)
1897+
18701898
@pytest.mark.parametrize(
18711899
"config",
18721900
product(

0 commit comments

Comments
 (0)