diff --git a/ci/Numba-array-api-xfails.txt b/ci/Numba-array-api-xfails.txt index dd8ec9e1..94a55c29 100644 --- a/ci/Numba-array-api-xfails.txt +++ b/ci/Numba-array-api-xfails.txt @@ -70,14 +70,12 @@ array_api_tests/test_creation_functions.py::test_empty_like array_api_tests/test_data_type_functions.py::test_finfo[complex64] array_api_tests/test_manipulation_functions.py::test_squeeze array_api_tests/test_has_names.py::test_has_names[utility-diff] -array_api_tests/test_has_names.py::test_has_names[manipulation-unstack] array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum] array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_prod] array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis] array_api_tests/test_has_names.py::test_has_names[searching-count_nonzero] array_api_tests/test_has_names.py::test_has_names[searching-searchsorted] array_api_tests/test_signatures.py::test_func_signature[diff] -array_api_tests/test_signatures.py::test_func_signature[unstack] array_api_tests/test_signatures.py::test_func_signature[take_along_axis] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] @@ -107,6 +105,5 @@ array_api_tests/test_searching_functions.py::test_count_nonzero array_api_tests/test_searching_functions.py::test_searchsorted array_api_tests/test_signatures.py::test_func_signature[cumulative_sum] array_api_tests/test_signatures.py::test_func_signature[cumulative_prod] -array_api_tests/test_manipulation_functions.py::test_unstack array_api_tests/test_signatures.py::test_func_signature[count_nonzero] array_api_tests/test_signatures.py::test_func_signature[searchsorted] diff --git a/sparse/numba_backend/__init__.py b/sparse/numba_backend/__init__.py index efbfa1ca..9226da18 100644 --- a/sparse/numba_backend/__init__.py +++ b/sparse/numba_backend/__init__.py @@ -128,6 +128,7 @@ sum, tensordot, tile, + unstack, var, vecdot, zeros, @@ -339,6 +340,7 @@ "zeros_like", "repeat", "tile", + "unstack", ] diff --git a/sparse/numba_backend/_common.py b/sparse/numba_backend/_common.py index 9f1b051a..24b17686 100644 --- a/sparse/numba_backend/_common.py +++ b/sparse/numba_backend/_common.py @@ -3186,3 +3186,34 @@ def tile(a, reps): a = a.reshape(tuple(np.column_stack(([1] * ndim, shape)).reshape(-1))) a = a.broadcast_to(tuple(np.column_stack((reps, shape)).reshape(-1))) return a.reshape(tuple(np.multiply(reps, shape))) + + +def unstack(x, axis=0): + """ + Splits an array into a sequence of arrays along the given axis. + + Parameters + ---------- + x : SparseArray + Input sparse arrays. + axis : int + Axis along which the array will be split + + Returns + ------- + out : Tuple[SparseArray,...] + Tuple of slices along the given dimension. All the arrays have the same shape. + """ + ndim = x.ndim + + if not (-ndim <= axis < ndim): + raise ValueError(f"axis must be in range [-{ndim}, {ndim}), got {axis}") + + if not isinstance(x, SparseArray): + raise TypeError("`a` must be a SparseArray.") + + if axis < 0: + axis = ndim + axis + new_order = (axis,) + tuple(i for i in range(ndim) if i != axis) + x = x.transpose(new_order) + return (*x,) diff --git a/sparse/numba_backend/tests/test_coo.py b/sparse/numba_backend/tests/test_coo.py index 767ba83d..50740882 100644 --- a/sparse/numba_backend/tests/test_coo.py +++ b/sparse/numba_backend/tests/test_coo.py @@ -1981,3 +1981,42 @@ def test_tile(arr, reps): result = sparse.tile(sparse_arr, reps).todense() np.testing.assert_array_equal(result, expected, err_msg=f"Mismatch for shape={arr.shape}, reps={reps}") + + +@pytest.mark.parametrize("ndim", range(1, 5)) +@pytest.mark.parametrize("shape_range", [3]) +def test_unstack_matches_numpy(ndim, shape_range): + rng = np.random.default_rng(42) + shape = tuple(rng.integers(2, shape_range + 2) for _ in range(ndim)) + a = rng.integers(0, 10, size=shape) + sparse_a = COO.from_numpy(a) + + for axis in range(-ndim, ndim): + sparse_parts = sparse.unstack(sparse_a, axis=axis) + np_parts = np.moveaxis(a, axis, 0) + + assert len(sparse_parts) == np_parts.shape[0], f"Wrong number of slices on axis {axis}" + + for i, part in enumerate(sparse_parts): + expected = np_parts[i] + if isinstance(part, COO): + actual = part.todense() + elif np.isscalar(part): + actual = np.array(part) + else: + raise TypeError(f"Unexpected type returned from unstack: {type(part)}") + + np.testing.assert_array_equal(actual, expected, err_msg=f"Mismatch at slice {i} on axis {axis}") + + +@pytest.mark.parametrize("axis", [-10, 10, 100, -100]) +def test_unstack_invalid_axis(axis): + a = COO.from_numpy(np.arange(6).reshape(2, 3)) + with pytest.raises(ValueError, match="axis must be in range"): + sparse.unstack(a, axis) + + +def test_unstack_invalid_type(): + a = np.arange(6).reshape(2, 3) # not a sparse array + with pytest.raises(TypeError, match="must be a SparseArray"): + sparse.unstack(a, axis=0) diff --git a/sparse/numba_backend/tests/test_namespace.py b/sparse/numba_backend/tests/test_namespace.py index 40beed81..40ce4fe1 100644 --- a/sparse/numba_backend/tests/test_namespace.py +++ b/sparse/numba_backend/tests/test_namespace.py @@ -166,6 +166,7 @@ def test_namespace(): "uint8", "unique_counts", "unique_values", + "unstack", "var", "vecdot", "where",