Skip to content

Commit 1f5e5f2

Browse files
authored
[ENH] refactor datatypes mtypes - example fixtures (#458)
This PR refactors the data type specifications and converters to classes. Related to sktime/sktime#3512 and sktime/sktime#2957.

Contains:

* a base class for datatype examples, `BaseExample`, to replace the more ad-hoc dictionary design
* a complete refactor of the `Table` and `Proba` mtype submodules to this interface
* a full refactor of the public framework module with `get_examples` logic, in `datatypes`, to allow extensibility with this design

Partial mirror in `skpro` of sktime/sktime#6033
1 parent 31da07d commit 1f5e5f2

File tree

7 files changed

+345
-162
lines changed

7 files changed

+345
-162
lines changed

skpro/datatypes/_base/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Base module for datatypes."""
22

3-
from skpro.datatypes._base._base import BaseConverter, BaseDatatype
3+
from skpro.datatypes._base._base import BaseConverter, BaseDatatype, BaseExample
44

5-
__all__ = ["BaseConverter", "BaseDatatype"]
5+
__all__ = ["BaseConverter", "BaseDatatype", "BaseExample"]

skpro/datatypes/_base/_base.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,43 @@ def _get_key(self):
328328
return (mtype_from, mtype_to, scitype)
329329

330330

331+
class BaseExample(BaseObject):
    """Base class for example fixtures used in tests and ``get_examples``.

    Subclasses are expected to fill in the ``scitype``, ``mtype`` and
    ``index`` class tags and to override ``build`` so that it returns the
    example object.
    """

    _tags = {
        "object_type": "datatype_example",
        "scitype": None,
        "mtype": None,
        "python_version": None,
        "python_dependencies": None,
        "index": None,  # integer index of the example to match with other mtypes
        "lossy": False,  # whether the example is lossy
    }

    def __init__(self):
        super().__init__()

    def _get_key(self):
        """Return the unique dictionary key corresponding to self.

        Private function, used in collecting a dictionary of examples.

        Returns
        -------
        tuple
            ``(mtype, scitype, index)``, each read from the class tag of the
            same name.
        """
        return tuple(
            self.get_class_tag(tag_name)
            for tag_name in ("mtype", "scitype", "index")
        )

    def build(self):
        """Build example.

        The base class provides no example; this method always raises.

        Returns
        -------
        obj : any
            Example object.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError
366+
367+
331368
def _coerce_str_to_cls(cls_or_str):
332369
"""Get class from string.
333370

skpro/datatypes/_examples.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
e.g., metadata such as column names are missing
1414
"""
1515

16+
from functools import lru_cache
17+
1618
from skpro.datatypes._registry import mtype_to_scitype
1719

1820
__author__ = ["fkiraly"]
@@ -21,29 +23,36 @@
2123
"get_examples",
2224
]
2325

24-
from skpro.datatypes._proba import (
25-
example_dict_lossy_Proba,
26-
example_dict_metadata_Proba,
27-
example_dict_Proba,
28-
)
29-
from skpro.datatypes._table import (
30-
example_dict_lossy_Table,
31-
example_dict_metadata_Table,
32-
example_dict_Table,
33-
)
3426

35-
# pool example_dict-s
36-
example_dict = dict()
37-
example_dict.update(example_dict_Proba)
38-
example_dict.update(example_dict_Table)
27+
@lru_cache(maxsize=1)
def generate_example_dicts(soft_deps="present"):
    """Generate the example lookup dictionaries by crawling ``skpro.datatypes``.

    All classes found in ``skpro.datatypes`` that subclass ``BaseExample`` and
    whose name does not start with ``"Base"`` are instantiated and registered.

    The result is memoized with ``lru_cache(maxsize=1)``, so repeated calls
    with the same argument return the very same dict objects; callers should
    treat them as read-only.

    Parameters
    ----------
    soft_deps : str, default="present"
        If ``"present"``, only classes for which
        ``_check_estimator_deps(cls, severity="none")`` is truthy are kept
        (presumably: classes whose soft dependencies are installed).
        Any other value disables this filter.

    Returns
    -------
    example_dict : dict
        Maps ``example._get_key()`` to the example instance.
    example_dict_lossy : dict
        Same keys as ``example_dict``; values are the ``"lossy"`` class tag,
        ``False`` if the tag is absent.
    example_dict_metadata : dict
        Maps ``(key[1], key[2])`` to the ``"metadata"`` class tag,
        ``{}`` if the tag is absent.
    """
    from skbase.utils.dependencies import _check_estimator_deps

    from skpro.datatypes._base import BaseExample
    from skpro.utils.retrieval import _all_classes

    # _all_classes yields tuples whose second entry is the class object
    fixture_classes = [
        entry[1]
        for entry in _all_classes("skpro.datatypes")
        if issubclass(entry[1], BaseExample)
        and not entry[1].__name__.startswith("Base")
    ]

    # subset only to data types with soft dependencies present
    if soft_deps == "present":
        fixture_classes = [
            klass
            for klass in fixture_classes
            if _check_estimator_deps(klass, severity="none")
        ]

    example_dict = {}
    example_dict_lossy = {}
    example_dict_metadata = {}
    for klass in fixture_classes:
        fixture = klass()
        key = fixture._get_key()
        tags = fixture.get_class_tags()

        example_dict[key] = fixture
        example_dict_lossy[key] = tags.get("lossy", False)
        # metadata is shared across mtypes, hence keyed without the mtype entry
        example_dict_metadata[(key[1], key[2])] = tags.get("metadata", {})

    return example_dict, example_dict_lossy, example_dict_metadata
4756

4857

4958
def get_examples(
@@ -79,6 +88,8 @@ def get_examples(
7988
if as_scitype is None:
8089
as_scitype = mtype_to_scitype(mtype)
8190

91+
example_dict, example_dict_lossy, example_dict_metadata = generate_example_dicts()
92+
8293
# retrieve all keys that match the query
8394
exkeys = example_dict.keys()
8495
keys = [k for k in exkeys if k[0] == mtype and k[1] == as_scitype]
@@ -88,14 +99,14 @@ def get_examples(
8899

89100
for k in keys:
90101
if return_lossy:
91-
fixtures[k[2]] = (example_dict.get(k), example_dict_lossy.get(k))
102+
fixtures[k[2]] = (example_dict.get(k).build(), example_dict_lossy.get(k))
92103
elif return_metadata:
93104
fixtures[k[2]] = (
94-
example_dict.get(k),
105+
example_dict.get(k).build(),
95106
example_dict_lossy.get(k),
96107
example_dict_metadata.get((k[1], k[2])),
97108
)
98109
else:
99-
fixtures[k[2]] = example_dict.get(k)
110+
fixtures[k[2]] = example_dict.get(k).build()
100111

101112
return fixtures

skpro/datatypes/_proba/__init__.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,11 @@
22

33
from skpro.datatypes._proba._check import check_dict as check_dict_Proba
44
from skpro.datatypes._proba._convert import convert_dict as convert_dict_Proba
5-
from skpro.datatypes._proba._examples import example_dict as example_dict_Proba
6-
from skpro.datatypes._proba._examples import (
7-
example_dict_lossy as example_dict_lossy_Proba,
8-
)
9-
from skpro.datatypes._proba._examples import (
10-
example_dict_metadata as example_dict_metadata_Proba,
11-
)
125
from skpro.datatypes._proba._registry import MTYPE_LIST_PROBA, MTYPE_REGISTER_PROBA
136

147
__all__ = [
158
"check_dict_Proba",
169
"convert_dict_Proba",
1710
"MTYPE_LIST_PROBA",
1811
"MTYPE_REGISTER_PROBA",
19-
"example_dict_Proba",
20-
"example_dict_lossy_Proba",
21-
"example_dict_metadata_Proba",
2212
]

skpro/datatypes/_proba/_examples.py

Lines changed: 95 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -31,64 +31,112 @@
3131
import numpy as np
3232
import pandas as pd
3333

34-
example_dict = dict()
35-
example_dict_lossy = dict()
36-
example_dict_metadata = dict()
34+
from skpro.datatypes._base import BaseExample
3735

3836
###
3937
# example 0: univariate
4038

41-
pred_q = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
42-
pred_q.columns = pd.MultiIndex.from_product([["foo"], [0.2, 0.6]])
4339

44-
# we need to use this due to numerical inaccuracies from the binary based representation
45-
pseudo_0_2 = 2 * np.abs(0.6 - 0.5)
40+
class _ProbaUniv(BaseExample):
    """Common tags for example 0 (univariate) of the ``"Proba"`` scitype.

    Carries no ``mtype`` and no ``build``; both are supplied by subclasses.
    """

    _tags = {
        "scitype": "Proba",
        "index": 0,
        "metadata": dict(is_univariate=True, is_empty=False, has_nans=False),
    }
4650

47-
example_dict[("pred_quantiles", "Proba", 0)] = pred_q
48-
example_dict_lossy[("pred_quantiles", "Proba", 0)] = False
4951

50-
pred_int = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
51-
pred_int.columns = pd.MultiIndex.from_tuples(
52-
[("foo", 0.6, "lower"), ("foo", pseudo_0_2, "upper")]
53-
)
52+
class _ProbaUnivPredQ(_ProbaUniv):
    """Example 0 of the ``"Proba"`` scitype in the ``pred_quantiles`` mtype."""

    _tags = {
        "mtype": "pred_quantiles",
        "python_dependencies": None,
        "lossy": False,
    }

    def build(self):
        """Build the example.

        Returns
        -------
        pd.DataFrame
            Three rows, two columns; columns are the ``pd.MultiIndex`` product
            of ``["foo"]`` and ``[0.2, 0.6]``.
        """
        quantile_columns = pd.MultiIndex.from_product([["foo"], [0.2, 0.6]])

        frame = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
        frame.columns = quantile_columns
        return frame
64+
65+
66+
class _ProbaUnivPredInt(_ProbaUniv):
    """Example 0 of the ``"Proba"`` scitype in the ``pred_interval`` mtype."""

    _tags = {
        "mtype": "pred_interval",
        "python_dependencies": None,
        "lossy": False,
    }

    def build(self):
        """Build the example.

        Returns
        -------
        pd.DataFrame
            Three rows, two columns; columns are a three-level
            ``pd.MultiIndex`` with entries ``("foo", 0.6, "lower")`` and
            ``("foo", pseudo_0_2, "upper")``.
        """
        # we need to use this due to numerical inaccuracies
        # from the binary based representation
        pseudo_0_2 = 2 * np.abs(0.6 - 0.5)

        interval_columns = pd.MultiIndex.from_tuples(
            [("foo", 0.6, "lower"), ("foo", pseudo_0_2, "upper")]
        )

        frame = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
        frame.columns = interval_columns
        return frame
5884

59-
example_dict_metadata[("Proba", 0)] = {
60-
"is_univariate": True,
61-
"is_empty": False,
62-
"has_nans": False,
63-
}
6485

6586
###
6687
# example 1: multi
6788

68-
pred_q = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]})
69-
pred_q.columns = pd.MultiIndex.from_product([["foo", "bar"], [0.2, 0.6]])
70-
71-
example_dict[("pred_quantiles", "Proba", 1)] = pred_q
72-
example_dict_lossy[("pred_quantiles", "Proba", 1)] = False
73-
74-
pred_int = pd.DataFrame(
75-
{0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]}
76-
)
77-
pred_int.columns = pd.MultiIndex.from_tuples(
78-
[
79-
("foo", 0.6, "lower"),
80-
("foo", pseudo_0_2, "upper"),
81-
("bar", 0.6, "lower"),
82-
("bar", pseudo_0_2, "upper"),
83-
]
84-
)
85-
86-
example_dict[("pred_interval", "Proba", 1)] = pred_int
87-
example_dict_lossy[("pred_interval", "Proba", 1)] = False
88-
89-
90-
example_dict_metadata[("Proba", 1)] = {
91-
"is_univariate": False,
92-
"is_empty": False,
93-
"has_nans": False,
94-
}
89+
90+
class _ProbaMulti(BaseExample):
    """Common tags for example 1 (multivariate) of the ``"Proba"`` scitype.

    Carries no ``mtype`` and no ``build``; both are supplied by subclasses.
    """

    _tags = {
        "scitype": "Proba",
        "index": 1,
        "metadata": dict(is_univariate=False, is_empty=False, has_nans=False),
    }
100+
101+
102+
class _ProbaMultiPredQ(_ProbaMulti):
    """Example 1 of the ``"Proba"`` scitype in the ``pred_quantiles`` mtype."""

    _tags = {
        "mtype": "pred_quantiles",
        "python_dependencies": None,
        "lossy": False,
    }

    def build(self):
        """Build the example.

        Returns
        -------
        pd.DataFrame
            Three rows, four columns; columns are the ``pd.MultiIndex``
            product of ``["foo", "bar"]`` and ``[0.2, 0.6]``.
        """
        quantile_columns = pd.MultiIndex.from_product([["foo", "bar"], [0.2, 0.6]])

        # the dict keys only serve as placeholders, they are replaced below
        frame = pd.DataFrame(
            {0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]}
        )
        frame.columns = quantile_columns
        return frame
116+
117+
118+
class _ProbaMultiPredInt(_ProbaMulti):
    """Example 1 of the ``"Proba"`` scitype in the ``pred_interval`` mtype."""

    _tags = {
        "mtype": "pred_interval",
        "python_dependencies": None,
        "lossy": False,
    }

    def build(self):
        """Build the example.

        Returns
        -------
        pd.DataFrame
            Three rows, four columns; columns are a three-level
            ``pd.MultiIndex`` holding, for each of ``"foo"`` and ``"bar"``,
            the entries ``(name, 0.6, "lower")`` and
            ``(name, pseudo_0_2, "upper")``.
        """
        # we need to use this due to numerical inaccuracies
        # from the binary based representation
        pseudo_0_2 = 2 * np.abs(0.6 - 0.5)

        # order: foo/lower, foo/upper, bar/lower, bar/upper
        bounds = ((0.6, "lower"), (pseudo_0_2, "upper"))
        interval_columns = pd.MultiIndex.from_tuples(
            [
                (var_name, coverage, side)
                for var_name in ("foo", "bar")
                for coverage, side in bounds
            ]
        )

        # the dict keys only serve as placeholders, they are replaced below
        frame = pd.DataFrame(
            {0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]}
        )
        frame.columns = interval_columns
        return frame

skpro/datatypes/_table/__init__.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,10 @@
11
"""Module exports: Table type converters and mtype registry."""
22

33
from skpro.datatypes._table._convert import convert_dict as convert_dict_Table
4-
from skpro.datatypes._table._examples import example_dict as example_dict_Table
5-
from skpro.datatypes._table._examples import (
6-
example_dict_lossy as example_dict_lossy_Table,
7-
)
8-
from skpro.datatypes._table._examples import (
9-
example_dict_metadata as example_dict_metadata_Table,
10-
)
114
from skpro.datatypes._table._registry import MTYPE_LIST_TABLE, MTYPE_REGISTER_TABLE
125

136
__all__ = [
147
"convert_dict_Table",
158
"MTYPE_LIST_TABLE",
169
"MTYPE_REGISTER_TABLE",
17-
"example_dict_Table",
18-
"example_dict_lossy_Table",
19-
"example_dict_metadata_Table",
2010
]

0 commit comments

Comments
 (0)