Skip to content

Commit 13e0f1b

Browse files
committed
Fix np.matrix complex types
1 parent 799f296 commit 13e0f1b

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

stanio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
"stan_variables",
1212
]
1313

14-
__version__ = "0.5.0"
14+
__version__ = "0.5.1"

stanio/json.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def process_value(val: Any) -> Any:
4646
if numpy_val.dtype.kind in "iuf":
4747
return numpy_val.tolist()
4848
if numpy_val.dtype.kind == "c":
49-
return np.stack([numpy_val.real, numpy_val.imag], axis=-1).tolist()
49+
return np.stack([np.asarray(numpy_val.real), np.asarray(numpy_val.imag)], axis=-1).tolist()
5050
if numpy_val.dtype.kind == "b":
5151
return numpy_val.astype(int).tolist()
5252

test/test_json.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,15 @@ def test_complex_numbers_np(TMPDIR) -> None:
161161
compare_before_after(file_complex, dict_complex, dict_complex_exp)
162162

163163

164+
def test_complex_numbers_np_matrix(TMPDIR) -> None:
165+
a_raw = np.random.rand(21, 21, 2)
166+
a = np.matrix(a_raw[:, :, 0] + 1j * a_raw[:, :, 1])
167+
dict_complex = {"a": a}
168+
dict_complex_exp = {"a": a_raw}
169+
file_complex = os.path.join(TMPDIR, "complex_np_mat.json")
170+
compare_before_after(file_complex, dict_complex, dict_complex_exp)
171+
172+
164173
def test_tuples(TMPDIR) -> None:
165174
dict_tuples = {
166175
"a": (1, 2, 3),

0 commit comments

Comments
 (0)