@@ -3596,6 +3596,10 @@ def test_categorical_features_passed_to_fit_correctly(self, model_cls_and_module
35963596 future_covariates ,
35973597 ) = self .inputs_for_tests_categorical_covariates ()
35983598
3599+ series_train , series_val = series .split_after (0.6 )
3600+ past_cov_train , past_cov_val = past_covariates .split_after (0.6 )
3601+ future_cov_train , future_cov_val = future_covariates .split_after (0.6 )
3602+
35993603 original_fit = model .model .fit
36003604 intercepted_args = {}
36013605
@@ -3610,24 +3614,39 @@ def intercept_fit_args(*args, **kwargs):
36103614 side_effect = intercept_fit_args ,
36113615 ):
36123616 model .fit (
3613- series = series ,
3614- past_covariates = past_covariates ,
3615- future_covariates = future_covariates ,
3617+ series = series_train ,
3618+ past_covariates = past_cov_train ,
3619+ future_covariates = future_cov_train ,
3620+ val_series = series_val ,
3621+ val_past_covariates = past_cov_val ,
3622+ val_future_covariates = future_cov_val ,
36163623 )
36173624
36183625 expected_cat_indices = [2 , 3 , 5 ]
36193626 cat_param_name = model ._categorical_fit_param
3627+ eval_set_param_name , _ = model .val_set_params
36203628 if model_cls == CatBoostModel :
36213629 model_cat_indices = model .model .get_cat_feature_indices ()
36223630 kwargs_cat_indices = intercepted_args ["kwargs" ][cat_param_name ]
3631+
36233632 assert model_cat_indices == kwargs_cat_indices == expected_cat_indices
36243633
3634+ # all evals set have correct cat feature indices
3635+ eval_set_indices = [
3636+ pool .get_cat_feature_indices ()
3637+ for pool in intercepted_args ["kwargs" ][eval_set_param_name ]
3638+ ]
3639+ assert np .array ([
3640+ indices == model_cat_indices for indices in eval_set_indices
3641+ ]).all ()
3642+
36253643 # catboost requires pd.DataFrame with categorical features
36263644 X , y = intercepted_args ["args" ]
36273645 assert isinstance (X , pd .DataFrame )
36283646 # all categorical features should be encoded as integers
36293647 for col in X [model_cat_indices ].columns :
36303648 assert X [col ].dtype == int
3649+
36313650 elif model_cls == LightGBMModel :
36323651 assert (
36333652 intercepted_args ["kwargs" ][cat_param_name ] == expected_cat_indices
0 commit comments