1010
1111import pytest
1212from sklearn .datasets import make_classification
13- from sklearn .linear_model import LogisticRegression
1413from sklearn .utils ._testing import (
1514 _get_func_name ,
1615 check_docstring_parameters ,
2423
2524import imblearn
2625from imblearn .base import is_sampler
27- from imblearn .utils ._sklearn_compat import _construct_instances
26+ from imblearn .under_sampling import NearMiss
27+ from imblearn .utils ._test_common .instance_generator import _tested_estimators
2828from imblearn .utils .estimator_checks import _set_checking_parameters
29- from imblearn .utils .testing import all_estimators
3029
3130# walk_packages() ignores DeprecationWarnings, now we need to ignore
3231# FutureWarnings
4342 )
4443
4544# functions to ignore args / docstring of
46- _DOCSTRING_IGNORES = [
47- "RUSBoostClassifier" , # TODO remove after releasing scikit-learn 1.0.1
48- "ValueDifferenceMetric" ,
49- ]
45+ _DOCSTRING_IGNORES = ["ValueDifferenceMetric" ]
46+ _IGNORE_ATTRIBUTES = {
47+ NearMiss : [ "nn_ver3_" ] ,
48+ }
5049
5150# Methods where y param should be ignored if y=None by default
5251_METHODS_IGNORE_NONE_Y = [
@@ -159,28 +158,19 @@ def test_tabs():
159158 )
160159
161160
162- def _construct_compose_pipeline_instance (Estimator ):
163- # Minimal / degenerate instances: only useful to test the docstrings.
164- if Estimator .__name__ == "Pipeline" :
165- return Estimator (steps = [("clf" , LogisticRegression ())])
166-
167-
168- @pytest .mark .parametrize ("name, Estimator" , all_estimators ())
169- def test_fit_docstring_attributes (name , Estimator ):
161+ @pytest .mark .parametrize ("estimator" , list (_tested_estimators ()))
162+ def test_fit_docstring_attributes (estimator ):
170163 pytest .importorskip ("numpydoc" )
171164 from numpydoc import docscrape
172165
166+ Estimator = estimator .__class__
173167 if Estimator .__name__ in _DOCSTRING_IGNORES :
174168 return
175169
176170 doc = docscrape .ClassDoc (Estimator )
177171 attributes = doc ["Attributes" ]
178172
179- if Estimator .__name__ == "Pipeline" :
180- est = _construct_compose_pipeline_instance (Estimator )
181- else :
182- est = next (_construct_instances (Estimator ))
183- _set_checking_parameters (est )
173+ _set_checking_parameters (estimator )
184174
185175 X , y = make_classification (
186176 n_samples = 20 ,
@@ -190,16 +180,16 @@ def test_fit_docstring_attributes(name, Estimator):
190180 random_state = 2 ,
191181 )
192182
193- y = _enforce_estimator_tags_y (est , y )
194- X = _enforce_estimator_tags_X (est , X )
183+ y = _enforce_estimator_tags_y (estimator , y )
184+ X = _enforce_estimator_tags_X (estimator , X )
195185
196- if "oob_score" in est .get_params ():
197- est .set_params (bootstrap = True , oob_score = True )
186+ if "oob_score" in estimator .get_params ():
187+ estimator .set_params (bootstrap = True , oob_score = True )
198188
199- if is_sampler (est ):
200- est .fit_resample (X , y )
189+ if is_sampler (estimator ):
190+ estimator .fit_resample (X , y )
201191 else :
202- est .fit (X , y )
192+ estimator .fit (X , y )
203193
204194 skipped_attributes = set (
205195 [
@@ -218,9 +208,11 @@ def test_fit_docstring_attributes(name, Estimator):
218208 continue
219209 # ignore deprecation warnings
220210 with ignore_warnings (category = FutureWarning ):
221- assert hasattr (est , attr .name )
211+ if attr .name in _IGNORE_ATTRIBUTES .get (Estimator , []):
212+ continue
213+ assert hasattr (estimator , attr .name )
222214
223- fit_attr = _get_all_fitted_attributes (est )
215+ fit_attr = _get_all_fitted_attributes (estimator )
224216 fit_attr_names = [attr .name for attr in attributes ]
225217 undocumented_attrs = set (fit_attr ).difference (fit_attr_names )
226218 undocumented_attrs = set (undocumented_attrs ).difference (skipped_attributes )
0 commit comments