Skip to content

Commit 837ebc8

Browse files
Fix/categorical validation features (#2776)
* Format validation set
* Add cat_feature to validation set
* Test categorical validation set
* Update darts/tests/models/forecasting/test_regression_models.py

Co-authored-by: madtoinou <[email protected]>

---------

Co-authored-by: madtoinou <[email protected]>
1 parent 836436d commit 837ebc8

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

darts/models/forecasting/catboost_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ def _add_val_set_to_kwargs(
374374
data=val_set[0],
375375
label=val_set[1],
376376
weight=val_weights[i] if val_weights is not None else None,
377+
cat_features=self._categorical_indices,
377378
)
378379
)
379380
kwargs[val_set_name] = val_pools

darts/models/forecasting/regression_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,8 @@ def _create_lagged_data(
677677
):
678678
sample_weights = sample_weights.ravel()
679679

680+
features, labels = self._format_samples(features, labels)
681+
680682
return features, labels, sample_weights
681683

682684
def _format_samples(
@@ -733,9 +735,7 @@ def _fit_model(
733735
"`sample_weight` was ignored since underlying regression model's "
734736
"`fit()` method does not support it."
735737
)
736-
training_samples, training_labels = self._format_samples(
737-
training_samples, training_labels
738-
)
738+
739739
self.model.fit(
740740
training_samples, training_labels, **sample_weight_kwargs, **kwargs
741741
)

darts/tests/models/forecasting/test_regression_models.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3596,6 +3596,10 @@ def test_categorical_features_passed_to_fit_correctly(self, model_cls_and_module
35963596
future_covariates,
35973597
) = self.inputs_for_tests_categorical_covariates()
35983598

3599+
series_train, series_val = series.split_after(0.6)
3600+
past_cov_train, past_cov_val = past_covariates.split_after(0.6)
3601+
future_cov_train, future_cov_val = future_covariates.split_after(0.6)
3602+
35993603
original_fit = model.model.fit
36003604
intercepted_args = {}
36013605

@@ -3610,24 +3614,39 @@ def intercept_fit_args(*args, **kwargs):
36103614
side_effect=intercept_fit_args,
36113615
):
36123616
model.fit(
3613-
series=series,
3614-
past_covariates=past_covariates,
3615-
future_covariates=future_covariates,
3617+
series=series_train,
3618+
past_covariates=past_cov_train,
3619+
future_covariates=future_cov_train,
3620+
val_series=series_val,
3621+
val_past_covariates=past_cov_val,
3622+
val_future_covariates=future_cov_val,
36163623
)
36173624

36183625
expected_cat_indices = [2, 3, 5]
36193626
cat_param_name = model._categorical_fit_param
3627+
eval_set_param_name, _ = model.val_set_params
36203628
if model_cls == CatBoostModel:
36213629
model_cat_indices = model.model.get_cat_feature_indices()
36223630
kwargs_cat_indices = intercepted_args["kwargs"][cat_param_name]
3631+
36233632
assert model_cat_indices == kwargs_cat_indices == expected_cat_indices
36243633

3634+
# all eval sets have the correct cat feature indices
3635+
eval_set_indices = [
3636+
pool.get_cat_feature_indices()
3637+
for pool in intercepted_args["kwargs"][eval_set_param_name]
3638+
]
3639+
assert np.array([
3640+
indices == model_cat_indices for indices in eval_set_indices
3641+
]).all()
3642+
36253643
# catboost requires pd.DataFrame with categorical features
36263644
X, y = intercepted_args["args"]
36273645
assert isinstance(X, pd.DataFrame)
36283646
# all categorical features should be encoded as integers
36293647
for col in X[model_cat_indices].columns:
36303648
assert X[col].dtype == int
3649+
36313650
elif model_cls == LightGBMModel:
36323651
assert (
36333652
intercepted_args["kwargs"][cat_param_name] == expected_cat_indices

0 commit comments

Comments
 (0)