Skip to content

Commit af83b26

Browse files
authored
Adding unstack function (#883)
Signed-off-by: Pradyot Ranjan <[email protected]>
1 parent 92842c2 commit af83b26

File tree

5 files changed

+73
-3
lines changed

5 files changed

+73
-3
lines changed

ci/Numba-array-api-xfails.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,12 @@ array_api_tests/test_creation_functions.py::test_empty_like
7070
array_api_tests/test_data_type_functions.py::test_finfo[complex64]
7171
array_api_tests/test_manipulation_functions.py::test_squeeze
7272
array_api_tests/test_has_names.py::test_has_names[utility-diff]
73-
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
7473
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum]
7574
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_prod]
7675
array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis]
7776
array_api_tests/test_has_names.py::test_has_names[searching-count_nonzero]
7877
array_api_tests/test_has_names.py::test_has_names[searching-searchsorted]
7978
array_api_tests/test_signatures.py::test_func_signature[diff]
80-
array_api_tests/test_signatures.py::test_func_signature[unstack]
8179
array_api_tests/test_signatures.py::test_func_signature[take_along_axis]
8280
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
8381
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
@@ -107,6 +105,5 @@ array_api_tests/test_searching_functions.py::test_count_nonzero
107105
array_api_tests/test_searching_functions.py::test_searchsorted
108106
array_api_tests/test_signatures.py::test_func_signature[cumulative_sum]
109107
array_api_tests/test_signatures.py::test_func_signature[cumulative_prod]
110-
array_api_tests/test_manipulation_functions.py::test_unstack
111108
array_api_tests/test_signatures.py::test_func_signature[count_nonzero]
112109
array_api_tests/test_signatures.py::test_func_signature[searchsorted]

sparse/numba_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@
128128
sum,
129129
tensordot,
130130
tile,
131+
unstack,
131132
var,
132133
vecdot,
133134
zeros,
@@ -339,6 +340,7 @@
339340
"zeros_like",
340341
"repeat",
341342
"tile",
343+
"unstack",
342344
]
343345

344346

sparse/numba_backend/_common.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3186,3 +3186,34 @@ def tile(a, reps):
31863186
a = a.reshape(tuple(np.column_stack(([1] * ndim, shape)).reshape(-1)))
31873187
a = a.broadcast_to(tuple(np.column_stack((reps, shape)).reshape(-1)))
31883188
return a.reshape(tuple(np.multiply(reps, shape)))
3189+
3190+
3191+
def unstack(x, axis=0):
3192+
"""
3193+
Splits an array into a sequence of arrays along the given axis.
3194+
3195+
Parameters
3196+
----------
3197+
x : SparseArray
3198+
Input sparse arrays.
3199+
axis : int
3200+
Axis along which the array will be split
3201+
3202+
Returns
3203+
-------
3204+
out : Tuple[SparseArray,...]
3205+
Tuple of slices along the given dimension. All the arrays have the same shape.
3206+
"""
3207+
ndim = x.ndim
3208+
3209+
if not (-ndim <= axis < ndim):
3210+
raise ValueError(f"axis must be in range [-{ndim}, {ndim}), got {axis}")
3211+
3212+
if not isinstance(x, SparseArray):
3213+
raise TypeError("`a` must be a SparseArray.")
3214+
3215+
if axis < 0:
3216+
axis = ndim + axis
3217+
new_order = (axis,) + tuple(i for i in range(ndim) if i != axis)
3218+
x = x.transpose(new_order)
3219+
return (*x,)

sparse/numba_backend/tests/test_coo.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1981,3 +1981,42 @@ def test_tile(arr, reps):
19811981
result = sparse.tile(sparse_arr, reps).todense()
19821982

19831983
np.testing.assert_array_equal(result, expected, err_msg=f"Mismatch for shape={arr.shape}, reps={reps}")
1984+
1985+
1986+
@pytest.mark.parametrize("ndim", range(1, 5))
1987+
@pytest.mark.parametrize("shape_range", [3])
1988+
def test_unstack_matches_numpy(ndim, shape_range):
1989+
rng = np.random.default_rng(42)
1990+
shape = tuple(rng.integers(2, shape_range + 2) for _ in range(ndim))
1991+
a = rng.integers(0, 10, size=shape)
1992+
sparse_a = COO.from_numpy(a)
1993+
1994+
for axis in range(-ndim, ndim):
1995+
sparse_parts = sparse.unstack(sparse_a, axis=axis)
1996+
np_parts = np.moveaxis(a, axis, 0)
1997+
1998+
assert len(sparse_parts) == np_parts.shape[0], f"Wrong number of slices on axis {axis}"
1999+
2000+
for i, part in enumerate(sparse_parts):
2001+
expected = np_parts[i]
2002+
if isinstance(part, COO):
2003+
actual = part.todense()
2004+
elif np.isscalar(part):
2005+
actual = np.array(part)
2006+
else:
2007+
raise TypeError(f"Unexpected type returned from unstack: {type(part)}")
2008+
2009+
np.testing.assert_array_equal(actual, expected, err_msg=f"Mismatch at slice {i} on axis {axis}")
2010+
2011+
2012+
@pytest.mark.parametrize("axis", [-10, 10, 100, -100])
2013+
def test_unstack_invalid_axis(axis):
2014+
a = COO.from_numpy(np.arange(6).reshape(2, 3))
2015+
with pytest.raises(ValueError, match="axis must be in range"):
2016+
sparse.unstack(a, axis)
2017+
2018+
2019+
def test_unstack_invalid_type():
2020+
a = np.arange(6).reshape(2, 3) # not a sparse array
2021+
with pytest.raises(TypeError, match="must be a SparseArray"):
2022+
sparse.unstack(a, axis=0)

sparse/numba_backend/tests/test_namespace.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def test_namespace():
166166
"uint8",
167167
"unique_counts",
168168
"unique_values",
169+
"unstack",
169170
"var",
170171
"vecdot",
171172
"where",

0 commit comments

Comments
 (0)