Skip to content

Commit c263ef0

Browse files
committed
Fix compatibility with scikit-learn 1.8+
scikit-learn 1.8.0+ removed the algorithm parameter from AdaBoostClassifier scikit-learn/scikit-learn#32262 The changes maintain compatibility with previous version of scikit-learn
1 parent 5a5f6d7 commit c263ef0

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

doc/ensemble.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,7 @@ Several methods taking advantage of boosting have been designed.
9797
a boosting iteration :cite:`seiffert2009rusboost`::
9898

9999
>>> from imblearn.ensemble import RUSBoostClassifier
100-
>>> rusboost = RUSBoostClassifier(n_estimators=200, algorithm='SAMME.R',
101-
... random_state=0)
100+
>>> rusboost = RUSBoostClassifier(n_estimators=200, random_state=0)
102101
>>> rusboost.fit(X_train, y_train)
103102
RUSBoostClassifier(...)
104103
>>> y_pred = rusboost.predict(X_test)

examples/ensemble/plot_comparison_ensemble_classifier.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,16 @@
194194

195195
# %%
196196
from sklearn.ensemble import AdaBoostClassifier
197+
from sklearn.utils.fixes import parse_version
197198

198199
from imblearn.ensemble import EasyEnsembleClassifier, RUSBoostClassifier
200+
from imblearn.utils._sklearn_compat import sklearn_version
201+
202+
if sklearn_version < parse_version("1.6"):
203+
estimator = AdaBoostClassifier(n_estimators=10, algorithm="SAMME")
204+
else:
205+
estimator = AdaBoostClassifier(n_estimators=10)
199206

200-
estimator = AdaBoostClassifier(n_estimators=10, algorithm="SAMME")
201207
eec = EasyEnsembleClassifier(n_estimators=10, estimator=estimator)
202208
eec.fit(X_train, y_train)
203209
y_pred_eec = eec.predict(X_test)

imblearn/ensemble/_easy_ensemble.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,12 +226,14 @@ def _validate_y(self, y):
226226
self._sampling_strategy = self.sampling_strategy
227227
return y_encoded
228228

229-
def _validate_estimator(self, default=AdaBoostClassifier(algorithm="SAMME")):
229+
def _validate_estimator(self, default=None):
230230
"""Check the estimator and the n_estimator attribute, set the
231231
`estimator_` attribute."""
232232
if self.estimator is not None:
233233
estimator = clone(self.estimator)
234234
else:
235+
if default is None:
236+
default = self._get_estimator()
235237
estimator = clone(default)
236238

237239
sampler = RandomUnderSampler(
@@ -279,7 +281,7 @@ def base_estimator_(self):
279281

280282
def _get_estimator(self):
281283
if self.estimator is None:
282-
if parse_version("1.4") <= sklearn_version < parse_version("1.6"):
284+
if sklearn_version < parse_version("1.6"):
283285
return AdaBoostClassifier(algorithm="SAMME")
284286
else:
285287
return AdaBoostClassifier()

0 commit comments

Comments
 (0)