Skip to content

Commit e584bc2

Browse files
authored
Speed improvements (#4)
* Add fast paths for numpy numeric arrays
* Optional: ujson dependency
1 parent 6554807 commit e584bc2

File tree

4 files changed

+52
-12
lines changed

4 files changed

+52
-12
lines changed

.github/workflows/test.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ on:
1111

1212
jobs:
1313
required:
14-
name: "${{matrix.os}} / ${{matrix.python-version}} / numpy_nightly: ${{matrix.numpy_nightly}}"
14+
name: "${{matrix.os}} / ${{matrix.python-version}} / ujson: ${{matrix.ujson}}"
1515
runs-on: ${{matrix.os}}
1616
strategy:
1717
matrix:
1818
os: [ubuntu-latest]
1919
python-version: [3.8, "3.12"]
20-
numpy_nightly: [false, true]
20+
ujson: [false, true]
2121
steps:
2222
- name: Check out github
2323
uses: actions/checkout@v4
@@ -35,10 +35,10 @@ jobs:
3535
run: |
3636
python -m pip install .[test]
3737
38-
- name: Install numpy
39-
if: ${{ matrix.numpy_nightly }}
38+
- name: Install ujson
39+
if: ${{ matrix.ujson }}
4040
run: |
41-
pip install -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple -U numpy
41+
pip install -U ujson
4242
4343
- name: Run tests
4444
run: |

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ test = [
3333
"pytest",
3434
"pytest-cov",
3535
]
36+
ujson = [
37+
"ujson>=5.5.0"
38+
]
3639

3740
[tool.isort]
3841
profile = "black"

stanio/json.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
"""
22
Utilities for writing Stan Json files
33
"""
4-
import json
4+
try:
5+
import ujson as json
6+
7+
uj_version = tuple(map(int, json.__version__.split(".")))
8+
if uj_version < (5, 5, 0):
9+
raise ImportError("ujson version too old")
10+
UJSON_AVAILABLE = True
11+
except:
12+
UJSON_AVAILABLE = False
13+
import json
14+
515
from typing import Any, Mapping
616

717
import numpy as np
@@ -31,7 +41,17 @@ def process_value(val: Any) -> Any:
3141
or "xarray" in original_module
3242
or "pandas" in original_module
3343
):
34-
return process_value(np.asanyarray(val).tolist())
44+
numpy_val = np.asanyarray(val)
45+
# fast paths for numeric types
46+
if numpy_val.dtype.kind in "iuf":
47+
return numpy_val.tolist()
48+
if numpy_val.dtype.kind == "c":
49+
return np.stack([numpy_val.real, numpy_val.imag], axis=-1).tolist()
50+
if numpy_val.dtype.kind == "b":
51+
return numpy_val.astype(int).tolist()
52+
53+
# should only be object arrays (tuples, etc)
54+
return process_value(numpy_val.tolist())
3555

3656
return val
3757

@@ -75,5 +95,8 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
7595
copied before type conversion, not modified
7696
"""
7797
with open(path, "w") as fd:
78-
for chunk in json.JSONEncoder().iterencode(process_dictionary(data)):
79-
fd.write(chunk)
98+
if UJSON_AVAILABLE:
99+
json.dump(process_dictionary(data), fd)
100+
else:
101+
for chunk in json.JSONEncoder().iterencode(process_dictionary(data)):
102+
fd.write(chunk)

test/test_json.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def test_basic_array(TMPDIR) -> None:
5757

5858

5959
def test_bool(TMPDIR) -> None:
60-
dict_bool = {"a": False, "b": True}
60+
dict_bool = {"a": False, "b": True, "c": np.array([True, False])}
6161
file_bool = os.path.join(TMPDIR, "bool.json")
62-
dict_exp = {"a": 0, "b": 1}
62+
dict_exp = {"a": 0, "b": 1, "c": [1, 0]}
6363
after = compare_before_after(file_bool, dict_bool, dict_exp)
6464
assert isinstance(after["a"], int)
6565
assert not isinstance(after["a"], bool)
@@ -135,18 +135,32 @@ def test_special_values(TMPDIR) -> None:
135135
]
136136
)
137137
}
138+
139+
# we want very specific values here
140+
json_string = dump_stan_json(dict_inf_nan)
141+
assert json_string.count("Infinity") == 8
142+
assert json_string.count("NaN") == 4
143+
assert json_string.count("-Infinity") == 4
144+
138145
dict_inf_nan_exp = {"a": [[-np.inf, np.inf, np.nan]] * 4}
139146
file_fin = os.path.join(TMPDIR, "inf.json")
140147
compare_before_after(file_fin, dict_inf_nan, dict_inf_nan_exp)
141148

142149

143150
def test_complex_numbers(TMPDIR) -> None:
144-
dict_complex = {"a": np.array([np.complex64(3), 3 + 4j])}
151+
dict_complex = {"a": [3 + 0j, 3 + 4j]}
145152
dict_complex_exp = {"a": [[3, 0], [3, 4]]}
146153
file_complex = os.path.join(TMPDIR, "complex.json")
147154
compare_before_after(file_complex, dict_complex, dict_complex_exp)
148155

149156

157+
def test_complex_numbers_np(TMPDIR) -> None:
158+
dict_complex = {"a": np.array([np.complex64(3), 3 + 4j])}
159+
dict_complex_exp = {"a": [[3, 0], [3, 4]]}
160+
file_complex = os.path.join(TMPDIR, "complex_np.json")
161+
compare_before_after(file_complex, dict_complex, dict_complex_exp)
162+
163+
150164
def test_tuples(TMPDIR) -> None:
151165
dict_tuples = {
152166
"a": (1, 2, 3),

0 commit comments

Comments
 (0)