|
25 | 25 | from sklearn.metrics._classification import _check_targets, _prf_divide |
26 | 26 | from sklearn.preprocessing import LabelEncoder |
27 | 27 | from sklearn.utils._param_validation import Interval, StrOptions |
| 28 | +from sklearn.utils.fixes import parse_version |
28 | 29 | from sklearn.utils.multiclass import unique_labels |
29 | 30 | from sklearn.utils.validation import check_consistent_length, column_or_1d |
30 | 31 |
|
31 | | -from ..utils._sklearn_compat import validate_params |
| 32 | +from ..utils._sklearn_compat import sklearn_version, validate_params |
32 | 33 |
|
33 | 34 |
|
34 | 35 | @validate_params( |
@@ -166,7 +167,12 @@ def sensitivity_specificity_support( |
166 | 167 | if average not in average_options and average != "binary": |
167 | 168 | raise ValueError("average has to be one of " + str(average_options)) |
168 | 169 |
|
169 | | - y_type, y_true, y_pred = _check_targets(y_true, y_pred) |
| 170 | + if sklearn_version >= parse_version("1.8"): |
| 171 | + y_type, y_true, y_pred, sample_weight = _check_targets( |
| 172 | + y_true, y_pred, sample_weight |
| 173 | + ) |
| 174 | + else: |
| 175 | + y_type, y_true, y_pred = _check_targets(y_true, y_pred) |
170 | 176 | present_labels = unique_labels(y_true, y_pred) |
171 | 177 |
|
172 | 178 | if average == "binary": |
@@ -1119,11 +1125,18 @@ def macro_averaged_mean_absolute_error(y_true, y_pred, *, sample_weight=None): |
1119 | 1125 | >>> macro_averaged_mean_absolute_error(y_true_imbalanced, y_pred) |
1120 | 1126 | 0.16... |
1121 | 1127 | """ |
1122 | | - _, y_true, y_pred = _check_targets(y_true, y_pred) |
1123 | | - if sample_weight is not None: |
1124 | | - sample_weight = column_or_1d(sample_weight) |
| 1128 | + if sklearn_version >= parse_version("1.8"): |
| 1129 | + _, y_true, y_pred, sample_weight = _check_targets( |
| 1130 | + y_true, y_pred, sample_weight |
| 1131 | + ) |
| 1132 | + if sample_weight is None: |
| 1133 | + sample_weight = np.ones(y_true.shape) |
1125 | 1134 | else: |
1126 | | - sample_weight = np.ones(y_true.shape) |
| 1135 | + _, y_true, y_pred = _check_targets(y_true, y_pred) |
| 1136 | + if sample_weight is not None: |
| 1137 | + sample_weight = column_or_1d(sample_weight) |
| 1138 | + else: |
| 1139 | + sample_weight = np.ones(y_true.shape) |
1127 | 1140 | check_consistent_length(y_true, y_pred, sample_weight) |
1128 | 1141 | labels = unique_labels(y_true, y_pred) |
1129 | 1142 | mae = [] |
|
0 commit comments