Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions ci/Numba-array-api-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,12 @@ array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_data_type_functions.py::test_finfo[complex64]
array_api_tests/test_manipulation_functions.py::test_squeeze
array_api_tests/test_has_names.py::test_has_names[utility-diff]
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum]
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_prod]
array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis]
array_api_tests/test_has_names.py::test_has_names[searching-count_nonzero]
array_api_tests/test_has_names.py::test_has_names[searching-searchsorted]
array_api_tests/test_signatures.py::test_func_signature[diff]
array_api_tests/test_signatures.py::test_func_signature[unstack]
array_api_tests/test_signatures.py::test_func_signature[take_along_axis]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
Expand Down Expand Up @@ -107,6 +105,5 @@ array_api_tests/test_searching_functions.py::test_count_nonzero
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_signatures.py::test_func_signature[cumulative_sum]
array_api_tests/test_signatures.py::test_func_signature[cumulative_prod]
array_api_tests/test_manipulation_functions.py::test_unstack
array_api_tests/test_signatures.py::test_func_signature[count_nonzero]
array_api_tests/test_signatures.py::test_func_signature[searchsorted]
2 changes: 2 additions & 0 deletions sparse/numba_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
sum,
tensordot,
tile,
unstack,
var,
vecdot,
zeros,
Expand Down Expand Up @@ -339,6 +340,7 @@
"zeros_like",
"repeat",
"tile",
"unstack",
]


Expand Down
31 changes: 31 additions & 0 deletions sparse/numba_backend/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3186,3 +3186,34 @@ def tile(a, reps):
a = a.reshape(tuple(np.column_stack(([1] * ndim, shape)).reshape(-1)))
a = a.broadcast_to(tuple(np.column_stack((reps, shape)).reshape(-1)))
return a.reshape(tuple(np.multiply(reps, shape)))


def unstack(x, axis=0):
"""
Splits an array into a sequence of arrays along the given axis.

Parameters
----------
x : SparseArray
Input sparse arrays.
axis : int
Axis along which the array will be split

Returns
-------
out : Tuple[SparseArray,...]
Tuple of slices along the given dimension. All the arrays have the same shape.
"""
ndim = x.ndim

if not (-ndim <= axis < ndim):
raise ValueError(f"axis must be in range [-{ndim}, {ndim}), got {axis}")

if not isinstance(x, SparseArray):
raise TypeError("`a` must be a SparseArray.")

if axis < 0:
axis = ndim + axis
new_order = (axis,) + tuple(i for i in range(ndim) if i != axis)
x = x.transpose(new_order)
return (*x,)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice short-cut!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was so happy to find it πŸ˜„

39 changes: 39 additions & 0 deletions sparse/numba_backend/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1981,3 +1981,42 @@ def test_tile(arr, reps):
result = sparse.tile(sparse_arr, reps).todense()

np.testing.assert_array_equal(result, expected, err_msg=f"Mismatch for shape={arr.shape}, reps={reps}")


@pytest.mark.parametrize("ndim", range(1, 5))
@pytest.mark.parametrize("shape_range", [3])
def test_unstack_matches_numpy(ndim, shape_range):
rng = np.random.default_rng(42)
shape = tuple(rng.integers(2, shape_range + 2) for _ in range(ndim))
a = rng.integers(0, 10, size=shape)
sparse_a = COO.from_numpy(a)

for axis in range(-ndim, ndim):
sparse_parts = sparse.unstack(sparse_a, axis=axis)
np_parts = np.moveaxis(a, axis, 0)

assert len(sparse_parts) == np_parts.shape[0], f"Wrong number of slices on axis {axis}"

for i, part in enumerate(sparse_parts):
expected = np_parts[i]
if isinstance(part, COO):
actual = part.todense()
elif np.isscalar(part):
actual = np.array(part)
else:
raise TypeError(f"Unexpected type returned from unstack: {type(part)}")

np.testing.assert_array_equal(actual, expected, err_msg=f"Mismatch at slice {i} on axis {axis}")


@pytest.mark.parametrize("axis", [-10, 10, 100, -100])
def test_unstack_invalid_axis(axis):
a = COO.from_numpy(np.arange(6).reshape(2, 3))
with pytest.raises(ValueError, match="axis must be in range"):
sparse.unstack(a, axis)


def test_unstack_invalid_type():
a = np.arange(6).reshape(2, 3) # not a sparse array
with pytest.raises(TypeError, match="must be a SparseArray"):
sparse.unstack(a, axis=0)
1 change: 1 addition & 0 deletions sparse/numba_backend/tests/test_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def test_namespace():
"uint8",
"unique_counts",
"unique_values",
"unstack",
"var",
"vecdot",
"where",
Expand Down
Loading