Skip to content

Commit 2685b68

Browse files
committed
Support scikit-learn 1.8+ _check_targets API change
In scikit-learn 1.8, the _check_targets() function signature changed to accept and return sample_weight as an extra value. See: scikit-learn/scikit-learn#31701 This commit adds version checking to handle both APIs. This maintains backward compatibility with scikit-learn < 1.8.
1 parent c263ef0 commit 2685b68

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

imblearn/metrics/_classification.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525
from sklearn.metrics._classification import _check_targets, _prf_divide
2626
from sklearn.preprocessing import LabelEncoder
2727
from sklearn.utils._param_validation import Interval, StrOptions
28+
from sklearn.utils.fixes import parse_version
2829
from sklearn.utils.multiclass import unique_labels
2930
from sklearn.utils.validation import check_consistent_length, column_or_1d
3031

31-
from ..utils._sklearn_compat import validate_params
32+
from ..utils._sklearn_compat import sklearn_version, validate_params
3233

3334

3435
@validate_params(
@@ -166,7 +167,12 @@ def sensitivity_specificity_support(
166167
if average not in average_options and average != "binary":
167168
raise ValueError("average has to be one of " + str(average_options))
168169

169-
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
170+
if sklearn_version >= parse_version("1.8"):
171+
y_type, y_true, y_pred, sample_weight = _check_targets(
172+
y_true, y_pred, sample_weight
173+
)
174+
else:
175+
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
170176
present_labels = unique_labels(y_true, y_pred)
171177

172178
if average == "binary":
@@ -1119,11 +1125,18 @@ def macro_averaged_mean_absolute_error(y_true, y_pred, *, sample_weight=None):
11191125
>>> macro_averaged_mean_absolute_error(y_true_imbalanced, y_pred)
11201126
0.16...
11211127
"""
1122-
_, y_true, y_pred = _check_targets(y_true, y_pred)
1123-
if sample_weight is not None:
1124-
sample_weight = column_or_1d(sample_weight)
1128+
if sklearn_version >= parse_version("1.8"):
1129+
_, y_true, y_pred, sample_weight = _check_targets(
1130+
y_true, y_pred, sample_weight
1131+
)
1132+
if sample_weight is None:
1133+
sample_weight = np.ones(y_true.shape)
11251134
else:
1126-
sample_weight = np.ones(y_true.shape)
1135+
_, y_true, y_pred = _check_targets(y_true, y_pred)
1136+
if sample_weight is not None:
1137+
sample_weight = column_or_1d(sample_weight)
1138+
else:
1139+
sample_weight = np.ones(y_true.shape)
11271140
check_consistent_length(y_true, y_pred, sample_weight)
11281141
labels = unique_labels(y_true, y_pred)
11291142
mae = []

0 commit comments

Comments
 (0)