Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions ci/Numba-array-api-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,12 @@ array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_data_type_functions.py::test_finfo[complex64]
array_api_tests/test_manipulation_functions.py::test_squeeze
array_api_tests/test_has_names.py::test_has_names[utility-diff]
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum]
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_prod]
array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis]
array_api_tests/test_has_names.py::test_has_names[searching-count_nonzero]
array_api_tests/test_has_names.py::test_has_names[searching-searchsorted]
array_api_tests/test_signatures.py::test_func_signature[diff]
array_api_tests/test_signatures.py::test_func_signature[unstack]
array_api_tests/test_signatures.py::test_func_signature[take_along_axis]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
Expand Down Expand Up @@ -107,6 +105,5 @@ array_api_tests/test_searching_functions.py::test_count_nonzero
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_signatures.py::test_func_signature[cumulative_sum]
array_api_tests/test_signatures.py::test_func_signature[cumulative_prod]
array_api_tests/test_manipulation_functions.py::test_unstack
array_api_tests/test_signatures.py::test_func_signature[count_nonzero]
array_api_tests/test_signatures.py::test_func_signature[searchsorted]
2 changes: 2 additions & 0 deletions sparse/numba_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
sum,
tensordot,
tile,
unstack,
var,
vecdot,
zeros,
Expand Down Expand Up @@ -339,6 +340,7 @@
"zeros_like",
"repeat",
"tile",
"unstack",
]


Expand Down
31 changes: 31 additions & 0 deletions sparse/numba_backend/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3186,3 +3186,34 @@ def tile(a, reps):
a = a.reshape(tuple(np.column_stack(([1] * ndim, shape)).reshape(-1)))
a = a.broadcast_to(tuple(np.column_stack((reps, shape)).reshape(-1)))
return a.reshape(tuple(np.multiply(reps, shape)))


def unstack(x, axis=0):
"""
Splits an array into a sequence of arrays along the given axis.

Parameters
----------
x : SparseArray
Input sparse arrays.
axis : int
Axis along which the array will be split

Returns
-------
out : Tuple[SparseArray,...]
Tuple of slices along the given dimension. All the arrays have the same shape.
"""
ndim = x.ndim

if not (-ndim <= axis < ndim):
raise ValueError(f"axis must be in range [-{ndim}, {ndim}), got {axis}")

if not isinstance(x, SparseArray):
raise TypeError("`a` must be a SparseArray.")

if axis < 0:
axis = ndim + axis
new_order = (axis,) + tuple(i for i in range(ndim) if i != axis)
x = x.transpose(new_order)
return (*x,)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice short-cut!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was so happy to find it πŸ˜„

39 changes: 39 additions & 0 deletions sparse/numba_backend/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1981,3 +1981,42 @@ def test_tile(arr, reps):
result = sparse.tile(sparse_arr, reps).todense()

np.testing.assert_array_equal(result, expected, err_msg=f"Mismatch for shape={arr.shape}, reps={reps}")


@pytest.mark.parametrize("ndim", range(1, 5))
@pytest.mark.parametrize("shape_range", [3])
def test_unstack_matches_numpy(ndim, shape_range):
rng = np.random.default_rng(42)
shape = tuple(rng.integers(2, shape_range + 2) for _ in range(ndim))
a = rng.integers(0, 10, size=shape)
sparse_a = COO.from_numpy(a)

for axis in range(-ndim, ndim):
sparse_parts = sparse.unstack(sparse_a, axis=axis)
np_parts = np.moveaxis(a, axis, 0)

assert len(sparse_parts) == np_parts.shape[0], f"Wrong number of slices on axis {axis}"

for i, part in enumerate(sparse_parts):
expected = np_parts[i]
if isinstance(part, COO):
actual = part.todense()
elif np.isscalar(part):
actual = np.array(part)
else:
raise TypeError(f"Unexpected type returned from unstack: {type(part)}")

np.testing.assert_array_equal(actual, expected, err_msg=f"Mismatch at slice {i} on axis {axis}")


@pytest.mark.parametrize("axis", [-10, 10, 100, -100])
def test_unstack_invalid_axis(axis):
a = COO.from_numpy(np.arange(6).reshape(2, 3))
with pytest.raises(ValueError, match="axis must be in range"):
sparse.unstack(a, axis)


def test_unstack_invalid_type():
a = np.arange(6).reshape(2, 3) # not a sparse array
with pytest.raises(TypeError, match="must be a SparseArray"):
sparse.unstack(a, axis=0)
1 change: 1 addition & 0 deletions sparse/numba_backend/tests/test_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def test_namespace():
"uint8",
"unique_counts",
"unique_values",
"unstack",
"var",
"vecdot",
"where",
Expand Down
Loading