Skip to content

Commit 643d0fd

Browse files
authored
Enable custom dtype arrays as return values (#1)
* Enable custom dtype arrays as return values * Version bump
1 parent eae5228 commit 643d0fd

File tree

3 files changed

+106
-22
lines changed

3 files changed

+106
-22
lines changed

stanio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
"stan_variables",
1212
]
1313

14-
__version__ = "0.3.1"
14+
__version__ = "0.4.0"

stanio/reshape.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,23 @@ class Variable:
4747
# list of nested parameters
4848
contents: List["Variable"]
4949

50+
def dtype(self, top=True):
51+
if self.type == VariableType.TUPLE:
52+
elts = [
53+
(str(i + 1), param.dtype(top=False))
54+
for i, param in enumerate(self.contents)
55+
]
56+
dtype = np.dtype(elts)
57+
elif self.type == VariableType.SCALAR:
58+
dtype = np.float64
59+
elif self.type == VariableType.COMPLEX:
60+
dtype = np.complex128
61+
62+
if top:
63+
return dtype
64+
else:
65+
return np.dtype((dtype, self.dimensions))
66+
5067
def columns(self) -> Iterable[int]:
5168
return range(self.start_idx, self.end_idx)
5269

@@ -81,7 +98,7 @@ def _extract_helper(self, src: np.ndarray, offset: int = 0):
8198
out[i, idx] = tuple(elt[i] for elt in elts)
8299
return out.reshape(-1, *self.dimensions, order="F")
83100

84-
def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]:
101+
def extract_reshape(self, src: np.ndarray, object=True) -> npt.NDArray[Any]:
85102
"""
86103
Given an array where the final dimension is the flattened output of a
87104
Stan model, (e.g. one row of a Stan CSV file), extract the variable
@@ -98,6 +115,10 @@ def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]:
98115
Indicies besides the final dimension are preserved
99116
in the output.
100117
118+
object : bool
119+
If True, the output of tuple types will be an object array,
120+
otherwise it will use custom dtypes to represent tuples.
121+
101122
Returns
102123
-------
103124
npt.NDArray[Any]
@@ -106,10 +127,14 @@ def extract_reshape(self, src: np.ndarray) -> npt.NDArray[Any]:
106127
otherwise it will have a dtype of either float64 or complex128.
107128
"""
108129
out = self._extract_helper(src)
130+
if not object:
131+
out = out.astype(self.dtype())
109132
if src.ndim > 1:
110-
return out.reshape(*src.shape[:-1], *self.dimensions, order="F")
133+
out = out.reshape(*src.shape[:-1], *self.dimensions, order="F")
111134
else:
112-
return out.squeeze(axis=0)
135+
out = out.squeeze(axis=0)
136+
137+
return out
113138

114139

115140
def _munge_first_tuple(tup: str) -> str:
@@ -194,7 +219,10 @@ def parse_header(header: str) -> Dict[str, Variable]:
194219

195220

196221
def stan_variables(
197-
parameters: Dict[str, Variable], source: npt.NDArray[np.float64]
222+
parameters: Dict[str, Variable],
223+
source: npt.NDArray[np.float64],
224+
*,
225+
object: bool = True,
198226
) -> Dict[str, npt.NDArray[Any]]:
199227
"""
200228
Given a dictionary of :class:`Variable` objects and a source array,
@@ -208,11 +236,17 @@ def stan_variables(
208236
like that returned by :func:`parse_header()`.
209237
source : npt.NDArray[np.float64]
210238
The array to extract from.
239+
object : bool
240+
If True, the output of tuple types will be an object array,
241+
otherwise it will use custom dtypes to represent tuples.
211242
212243
Returns
213244
-------
214245
Dict[str, npt.NDArray[Any]]
215246
A dictionary mapping the base name of each variable to the extracted
216247
and reshaped data.
217248
"""
218-
return {param.name: param.extract_reshape(source) for param in parameters.values()}
249+
return {
250+
param.name: param.extract_reshape(source, object=object)
251+
for param in parameters.values()
252+
}

test/test_reshape.py

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212

1313

1414
# see file data/rectangles/output.stan
15-
@pytest.fixture(scope="module")
16-
def rect_data():
15+
@pytest.fixture(scope="module", params=[True, False], ids=["use_object", "use_dtype"])
16+
def rect_data(request):
1717
files = [DATA / "rectangles" / f"output_{i}.csv" for i in range(1, 5)]
1818
header, data = read_csv(files)
1919
params = parse_header(header)
20-
yield stan_variables(params, data)
20+
yield stan_variables(params, data, object=request.param)
2121

2222

2323
def test_basic_shapes(rect_data):
@@ -91,43 +91,93 @@ def test_basic_values(rect_data):
9191

9292

9393
# see file data/tuples/output.stan
94-
@pytest.fixture(scope="module")
95-
def tuple_data():
94+
@pytest.fixture(scope="module", params=[True, False], ids=["use_object", "use_dtype"])
95+
def tuple_data(request):
9696
files = [DATA / "tuples" / f"output_{i}.csv" for i in range(1, 5)]
9797
header, data = read_csv(files)
9898
params = parse_header(header)
99-
yield stan_variables(params, data)
99+
yield stan_variables(params, data, object=request.param)
100100

101101

102102
def test_tuple_shapes(tuple_data):
103-
assert isinstance(tuple_data["pair"][0, 0], tuple)
104103
assert len(tuple_data["pair"][0, 0]) == 2
105104

106-
assert isinstance(tuple_data["nested"][0, 0], tuple)
107105
assert len(tuple_data["nested"][0, 0]) == 2
108-
assert isinstance(tuple_data["nested"][0, 0][1], tuple)
109106
assert len(tuple_data["nested"][0, 0][1]) == 2
110107

111108
assert tuple_data["arr_pair"].shape == (4, 1000, 2)
112-
assert isinstance(tuple_data["arr_pair"][0, 0, 0], tuple)
113109

114110
assert tuple_data["arr_very_nested"].shape == (4, 1000, 3)
111+
112+
assert tuple_data["arr_2d_pair"].shape == (4, 1000, 3, 2)
113+
114+
assert tuple_data["ultimate"].shape == (4, 1000, 2, 3)
115+
assert tuple_data["ultimate"][0, 0, 0, 0][0].shape == (2,)
116+
assert tuple_data["ultimate"][0, 0, 0, 0][0][0][1].shape == (2,)
117+
assert tuple_data["ultimate"][0, 0, 0, 0][1].shape == (4, 5)
118+
119+
120+
def check_tuple_shapes_objects(tuple_data):
121+
assert isinstance(tuple_data["pair"][0, 0], tuple)
122+
123+
assert isinstance(tuple_data["nested"][0, 0], tuple)
124+
assert isinstance(tuple_data["nested"][0, 0][1], tuple)
125+
126+
assert isinstance(tuple_data["arr_pair"][0, 0, 0], tuple)
127+
115128
assert isinstance(tuple_data["arr_very_nested"][0, 0, 0], tuple)
116129
assert isinstance(tuple_data["arr_very_nested"][0, 0, 0][0], tuple)
117130
assert isinstance(tuple_data["arr_very_nested"][0, 0, 0][0][1], tuple)
118131

119-
assert tuple_data["arr_2d_pair"].shape == (4, 1000, 3, 2)
120132
assert isinstance(tuple_data["arr_2d_pair"][0, 0, 0, 0], tuple)
121133

122-
assert tuple_data["ultimate"].shape == (4, 1000, 2, 3)
123134
assert isinstance(tuple_data["ultimate"][0, 0, 0, 0], tuple)
124-
assert tuple_data["ultimate"][0, 0, 0, 0][0].shape == (2,)
125135
assert isinstance(tuple_data["ultimate"][0, 0, 0, 0][0][0], tuple)
126-
assert tuple_data["ultimate"][0, 0, 0, 0][0][0][1].shape == (2,)
127-
assert tuple_data["ultimate"][0, 0, 0, 0][1].shape == (4, 5)
136+
137+
138+
def check_tuple_shapes_custom_dtypes(tuple_data):
139+
for value in tuple_data.values():
140+
assert not value.dtype.hasobject
141+
142+
pair_dtype = np.dtype([("1", "f8"), ("2", "f8")])
143+
assert tuple_data["pair"].dtype == pair_dtype
144+
145+
nested_dtype = np.dtype([("1", "f8"), ("2", [("1", "f8"), ("2", "c16")])])
146+
assert tuple_data["nested"].dtype == nested_dtype
147+
assert tuple_data["nested"][0, 0][1].dtype == nested_dtype[1]
148+
149+
assert tuple_data["arr_pair"].dtype == pair_dtype
150+
151+
very_nested_dtype = np.dtype(
152+
[
153+
("1", nested_dtype),
154+
("2", "f8"),
155+
]
156+
)
157+
assert tuple_data["arr_very_nested"].dtype == very_nested_dtype
158+
assert tuple_data["arr_very_nested"][0, 0, 0][0].dtype == nested_dtype
159+
assert tuple_data["arr_very_nested"][0, 0, 0][0][1].dtype == nested_dtype[1]
160+
161+
ultimate_dtype = np.dtype(
162+
[
163+
("1", ([("1", "f8"), ("2", "(2,)f8")], (2,))),
164+
("2", "(4,5)f8"),
165+
]
166+
)
167+
assert tuple_data["ultimate"].dtype == ultimate_dtype
168+
169+
170+
def test_tuple_dtypes(tuple_data):
171+
if isinstance(tuple_data["pair"][0, 0], tuple):
172+
check_tuple_shapes_objects(tuple_data)
173+
else:
174+
check_tuple_shapes_custom_dtypes(tuple_data)
128175

129176

130177
def assert_tuple_equal(t1, t2):
178+
if hasattr(t1, "dtype") and t1.dtype.kind == "V":
179+
t1 = t1.tolist()
180+
131181
assert len(t1) == len(t2)
132182
for x, y in zip(t1, t2):
133183
if isinstance(x, tuple):
@@ -140,7 +190,7 @@ def check_tuples(tuple_data, chain, draw):
140190
base = tuple_data["base"][chain, draw]
141191
base_i = tuple_data["base_i"][chain, draw]
142192
pair_exp = (base, 2 * base)
143-
np.testing.assert_almost_equal(tuple_data["pair"][chain, draw], pair_exp)
193+
assert_tuple_equal(tuple_data["pair"][chain, draw], pair_exp)
144194
nested_exp = (base * 3, (base_i, 4j * base))
145195
assert_tuple_equal(tuple_data["nested"][chain, draw], nested_exp)
146196

0 commit comments

Comments
 (0)