22import pytest
33from sklearn .datasets import make_classification
44from sklearn .linear_model import LogisticRegression
5+ from sklearn .metrics import make_scorer , precision_score
56from sklearn .model_selection import cross_validate
67from sklearn .utils ._testing import assert_allclose
78
1112@pytest .fixture
1213def data ():
1314 return make_classification (
14- weights = [0.9 , 0.1 ],
15- class_sep = 2 ,
15+ weights = [0.5 , 0.5 ],
16+ class_sep = 0.5 ,
1617 n_informative = 3 ,
1718 n_redundant = 1 ,
1819 flip_y = 0.05 ,
19- n_samples = 1000 ,
20+ n_samples = 50 ,
2021 random_state = 10 ,
2122 )
2223
2324
2425def test_groups_parameter_warning (data ):
2526 """Test that a warning is raised when groups parameter is provided."""
2627 X , y = data
27- ih_cv = InstanceHardnessCV (estimator = LogisticRegression ())
28+ ih_cv = InstanceHardnessCV (estimator = LogisticRegression (), n_splits = 3 )
2829
2930 warning_msg = "The groups parameter is ignored by InstanceHardnessCV"
3031 with pytest .warns (UserWarning , match = warning_msg ):
@@ -42,9 +43,11 @@ def test_error_on_multiclass():
4243def test_default_params (data ):
4344 """Test that the default parameters are used."""
4445 X , y = data
45- ih_cv = InstanceHardnessCV (estimator = LogisticRegression ())
46- cv_result = cross_validate (LogisticRegression (), X , y , cv = ih_cv )
47- assert_allclose (cv_result ["test_score" ], [0.975 , 0.965 , 0.96 , 0.955 , 0.965 ])
46+ ih_cv = InstanceHardnessCV (estimator = LogisticRegression (), n_splits = 3 )
47+ cv_result = cross_validate (
48+ LogisticRegression (), X , y , cv = ih_cv , scoring = "precision"
49+ )
50+ assert_allclose (cv_result ["test_score" ], [0.625 , 0.6 , 0.625 ], atol = 1e-6 , rtol = 1e-6 )
4851
4952
5053@pytest .mark .parametrize ("dtype_target" , [None , object ])
@@ -53,9 +56,15 @@ def test_target_string_labels(data, dtype_target):
5356 X , y = data
5457 labels = np .array (["a" , "b" ], dtype = dtype_target )
5558 y = labels [y ]
56- ih_cv = InstanceHardnessCV (estimator = LogisticRegression ())
57- cv_result = cross_validate (LogisticRegression (), X , y , cv = ih_cv )
58- assert_allclose (cv_result ["test_score" ], [0.975 , 0.965 , 0.96 , 0.955 , 0.965 ])
59+ ih_cv = InstanceHardnessCV (estimator = LogisticRegression (), n_splits = 3 )
60+ cv_result = cross_validate (
61+ LogisticRegression (),
62+ X ,
63+ y ,
64+ cv = ih_cv ,
65+ scoring = make_scorer (precision_score , pos_label = "b" ),
66+ )
67+ assert_allclose (cv_result ["test_score" ], [0.625 , 0.6 , 0.625 ], atol = 1e-6 , rtol = 1e-6 )
5968
6069
6170@pytest .mark .parametrize ("dtype_target" , [None , object ])
@@ -68,9 +77,19 @@ def test_target_string_pos_label(data, dtype_target):
6877 X , y = data
6978 labels = np .array (["a" , "b" ], dtype = dtype_target )
7079 y = labels [y ]
71- ih_cv = InstanceHardnessCV (estimator = LogisticRegression (), pos_label = "a" )
72- cv_result = cross_validate (LogisticRegression (), X , y , cv = ih_cv )
73- assert_allclose (cv_result ["test_score" ], [0.965 , 0.975 , 0.965 , 0.955 , 0.96 ])
80+ ih_cv = InstanceHardnessCV (
81+ estimator = LogisticRegression (), pos_label = "a" , n_splits = 3
82+ )
83+ cv_result = cross_validate (
84+ LogisticRegression (),
85+ X ,
86+ y ,
87+ cv = ih_cv ,
88+ scoring = make_scorer (precision_score , pos_label = "a" ),
89+ )
90+ assert_allclose (
91+ cv_result ["test_score" ], [0.666667 , 0.666667 , 0.4 ], atol = 1e-6 , rtol = 1e-6
92+ )
7493
7594
7695@pytest .mark .parametrize ("n_splits" , [2 , 3 , 4 ])
0 commit comments