Skip to content

Commit 3cf73bf

Browse files
authored
Fix sklearn compatibility issues (#148)
1 parent 6802b19 commit 3cf73bf

File tree

14 files changed

+6098
-2889
lines changed

14 files changed

+6098
-2889
lines changed

examples/2020ECAHM-scikit-downscale.ipynb

Lines changed: 221 additions & 2583 deletions
Large diffs are not rendered by default.

examples/bcsd_example.ipynb

Lines changed: 2779 additions & 27 deletions
Large diffs are not rendered by default.

examples/gard_example.ipynb

Lines changed: 2709 additions & 16 deletions
Large diffs are not rendered by default.

examples/utils.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -51,25 +51,34 @@ def get_sample_data(
5151
return (
5252
data['uas.hist.CanESM2.CRCM5-UQAM.day.NAM-44i.raw']
5353
.sel(lat=40.25, lon=-109.2, method='nearest')
54-
.to_dataset()
54+
.to_dataset()[['uas']]
55+
.convert_calendar('standard')
5556
.squeeze()
5657
.to_dataframe()[['uas']]
58+
.resample('1d')
59+
.first()
5760
)
5861
elif kind == 'wind-obs':
5962
return (
6063
data['uas.gridMET.NAM-44i']
6164
.sel(lat=40.25, lon=-109.2, method='nearest')
62-
.to_dataset()
65+
.to_dataset()[['uas']]
66+
.convert_calendar('standard')
6367
.squeeze()
6468
.to_dataframe()[['uas']]
69+
.resample('1d')
70+
.first()
6571
)
6672
elif kind == 'wind-rcp':
6773
return (
6874
data['uas.rcp85.CanESM2.CRCM5-UQAM.day.NAM-44i.raw']
6975
.sel(lat=40.25, lon=-109.2, method='nearest')
70-
.to_dataset()
76+
.to_dataset()[['uas']]
77+
.convert_calendar('standard')
7178
.squeeze()
7279
.to_dataframe()[['uas']]
80+
.resample('1d')
81+
.first()
7382
)
7483
else:
7584
raise ValueError(kind)

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,15 +58,18 @@
5858
"dask>=2024.11",
5959
"jupyterlab>=4.3",
6060
"matplotlib>=3.9.4",
61+
"mlinsights>=0.5.2",
6162
"netcdf4>=1.7",
6263
"numpydoc>=1.9.0",
64+
"probscale>=0.2.5",
6365
"s3fs>=2025.10.0",
6466
"scipy>=1.14",
67+
"seaborn>=0.13.2",
6568
"sphinx-gallery>=0.19.0",
6669
"sphinx-rtd-theme>=3.0.2",
6770
"sphinx>=8.0",
6871
"zarr>=2.18.2",
69-
]
72+
]
7073

7174
[tool.ruff]
7275
builtins = ["ellipsis"]

skdownscale/pointwise_models/arrm.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ def arrm_breakpoints(X, y, window_width, max_breakpoints):
6969

7070
# lower half of distribution
7171
# start at 0.4 or the first breakpoint
72-
start = min(breakpoints) if breakpoints else start
72+
start = min(breakpoints, default=start)
7373
# likely need this to avoid recomputing r2 and picking the same breakpoint twice
7474
start -= (min_width // 2) + 1
7575

@@ -82,7 +82,7 @@ def arrm_breakpoints(X, y, window_width, max_breakpoints):
8282
r2[mid] = np.corrcoef(X[s], y[s])[0, 1] ** 2
8383

8484
# find the last three breakpoints
85-
for bp in range(max_breakpoints // 2): # this means max_breakpoints must always be even
85+
for _ in range(max_breakpoints // 2): # this means max_breakpoints must always be even
8686
mind = np.argmin(r2[:start]) # find minimum r2, only look at the first part of the array
8787
breakpoints.append(mind) # breakpoint is in the center of the window
8888

skdownscale/pointwise_models/base.py

Lines changed: 41 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,45 @@ def _check_array(self, array, **kwargs):
2929

3030
return array
3131

32-
def _validate_data(self, X, y=None, reset=True, validate_separately=False, **check_params):
32+
def _check_n_features(self, X, reset):
33+
"""Check and set n_features_in_ attribute.
34+
35+
Parameters
36+
----------
37+
X : array-like
38+
Input data
39+
reset : bool
40+
Whether to reset n_features_in_ or check consistency
41+
"""
42+
n_features = X.shape[1] if hasattr(X, 'shape') and len(X.shape) > 1 else 1
43+
44+
if reset:
45+
self.n_features_in_ = n_features
46+
elif hasattr(self, 'n_features_in_'):
47+
if self.n_features_in_ != n_features:
48+
raise ValueError(
49+
f'X has {n_features} features, but {self.__class__.__name__} '
50+
f'was fitted with {self.n_features_in_} features.'
51+
)
52+
53+
def __sklearn_tags__(self):
54+
"""Get estimator tags for sklearn 1.6+.
55+
56+
Returns
57+
-------
58+
tags : Tags
59+
Tags object with estimator metadata.
60+
"""
61+
from dataclasses import replace
62+
63+
tags = super().__sklearn_tags__()
64+
# Update target_tags to indicate y is not required by default
65+
tags = replace(tags, target_tags=replace(tags.target_tags, required=False))
66+
return tags
67+
68+
def _validate_data(
69+
self, X, y=None, reset: bool = True, validate_separately: bool = False, **check_params: dict
70+
):
3371
"""Validate input data and set or check the `n_features_in_` attribute.
3472
3573
Parameters
@@ -60,7 +98,8 @@ def _validate_data(self, X, y=None, reset=True, validate_separately=False, **che
6098
"""
6199

62100
if y is None:
63-
if self._get_tags()['requires_y']:
101+
tags = self.__sklearn_tags__()
102+
if tags.target_tags.required:
64103
raise ValueError(
65104
f'This {self.__class__.__name__} estimator '
66105
f'requires y to be passed, but the target y is None.'
@@ -80,7 +119,6 @@ def _validate_data(self, X, y=None, reset=True, validate_separately=False, **che
80119
X, y = self._check_X_y(X, y, **check_params)
81120
out = X, y
82121

83-
# TO-DO: add check_n_features attribute
84122
if check_params.get('ensure_2d', True):
85123
self._check_n_features(X, reset=reset)
86124

skdownscale/pointwise_models/bcsd.py

Lines changed: 20 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -162,10 +162,7 @@ def predict(self, X):
162162
Xqm = self._qm_transform_by_group(self._create_groups(X, climate_trend=True))
163163

164164
# calculate the anomalies as a ratio of the training data
165-
if self.return_anoms:
166-
return self._calc_ratio_anoms(Xqm, self.y_climo_)
167-
else:
168-
return Xqm
165+
return self._calc_ratio_anoms(Xqm, self.y_climo_) if self.return_anoms else Xqm
169166

170167
def _calc_ratio_anoms(self, obj, climatology, climate_trend=False):
171168
"""helper function for dividing day groups by climatology"""
@@ -181,29 +178,13 @@ def _calc_ratio_anoms(self, obj, climatology, climate_trend=False):
181178

182179
return result
183180

184-
def _more_tags(self):
185-
return {
186-
'_xfail_checks': {
187-
'check_estimators_dtypes': 'BCSD only suppers 1 feature',
188-
'check_dtype_object': 'BCSD only suppers 1 feature',
189-
'check_fit_score_takes_y': 'BCSD only suppers 1 feature',
190-
'check_estimators_fit_returns_self': 'BCSD only suppers 1 feature',
191-
'check_estimators_fit_returns_self(readonly_memmap=True)': 'BCSD only suppers 1 feature',
192-
'check_pipeline_consistency': 'BCSD only suppers 1 feature',
193-
'check_estimators_nan_inf': 'BCSD only suppers 1 feature',
194-
'check_estimators_overwrite_params': 'BCSD only suppers 1 feature',
195-
'check_estimators_pickle': 'BCSD only suppers 1 feature',
196-
'check_fit2d_predict1d': 'BCSD only suppers 1 feature',
197-
'check_methods_subset_invariance': 'BCSD only suppers 1 feature',
198-
'check_fit2d_1sample': 'BCSD only suppers 1 feature',
199-
'check_dict_unchanged': 'BCSD only suppers 1 feature',
200-
'check_dont_overwrite_parameters': 'BCSD only suppers 1 feature',
201-
'check_fit_idempotent': 'BCSD only suppers 1 feature',
202-
'check_n_features_in': 'BCSD only suppers 1 feature',
203-
'check_fit_check_is_fitted': 'BCSD only suppers 1 feature',
204-
'check_methods_sample_order_invariance': 'temporal order matters',
205-
},
206-
}
181+
def __sklearn_tags__(self):
182+
from dataclasses import replace
183+
184+
tags = super().__sklearn_tags__()
185+
# Skip tests - only supports 1 feature, temporal order matters
186+
tags = replace(tags, _skip_test='BCSD only supports 1 feature and temporal order matters')
187+
return tags
207188

208189

209190
class BcsdTemperature(BcsdBase):
@@ -283,38 +264,20 @@ def rolling_func(x):
283264

284265
def _remove_climatology(self, obj, climatology, climate_trend=False):
285266
"""helper function to remove climatologies"""
286-
dfs = []
287-
for key, group in self._create_groups(obj, climate_trend):
288-
if self.timestep == 'monthly':
289-
dfs.append(group - climatology.loc[key].values)
290-
elif self.timestep == 'daily':
291-
dfs.append(group - climatology.loc[key].values)
292-
267+
dfs = [
268+
group - climatology.loc[key].values
269+
for key, group in self._create_groups(obj, climate_trend)
270+
if self.timestep in ['monthly', 'daily']
271+
]
293272
result = pd.concat(dfs).sort_index()
294273
if obj.shape != result.shape:
295274
raise ValueError('shape of climo is not equal to input array')
296275
return result
297276

298-
def _more_tags(self):
299-
return {
300-
'_xfail_checks': {
301-
'check_estimators_dtypes': 'BCSD only suppers 1 feature',
302-
'check_fit_score_takes_y': 'BCSD only suppers 1 feature',
303-
'check_estimators_fit_returns_self': 'BCSD only suppers 1 feature',
304-
'check_estimators_fit_returns_self(readonly_memmap=True)': 'BCSD only suppers 1 feature',
305-
'check_dtype_object': 'BCSD only suppers 1 feature',
306-
'check_pipeline_consistency': 'BCSD only suppers 1 feature',
307-
'check_estimators_nan_inf': 'BCSD only suppers 1 feature',
308-
'check_estimators_overwrite_params': 'BCSD only suppers 1 feature',
309-
'check_estimators_pickle': 'BCSD only suppers 1 feature',
310-
'check_fit2d_predict1d': 'BCSD only suppers 1 feature',
311-
'check_methods_subset_invariance': 'BCSD only suppers 1 feature',
312-
'check_fit2d_1sample': 'BCSD only suppers 1 feature',
313-
'check_dict_unchanged': 'BCSD only suppers 1 feature',
314-
'check_dont_overwrite_parameters': 'BCSD only suppers 1 feature',
315-
'check_fit_idempotent': 'BCSD only suppers 1 feature',
316-
'check_n_features_in': 'BCSD only suppers 1 feature',
317-
'check_fit_check_is_fitted': 'BCSD only suppers 1 feature',
318-
'check_methods_sample_order_invariance': 'temporal order matters',
319-
},
320-
}
277+
def __sklearn_tags__(self):
278+
from dataclasses import replace
279+
280+
tags = super().__sklearn_tags__()
281+
# Skip tests - only supports 1 feature, temporal order matters
282+
tags = replace(tags, _skip_test='BCSD only supports 1 feature and temporal order matters')
283+
return tags

0 commit comments

Comments
 (0)