Skip to content

Commit 67ea4b1

Browse files
authored
Adding diff function (#888)
1 parent 2b5f7c4 commit 67ea4b1

File tree

5 files changed

+76
-2
lines changed

5 files changed

+76
-2
lines changed

ci/Numba-array-api-xfails.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,11 @@ array_api_tests/test_has_names.py::test_has_names[fft-irfftn]
6969
array_api_tests/test_creation_functions.py::test_empty_like
7070
array_api_tests/test_data_type_functions.py::test_finfo[complex64]
7171
array_api_tests/test_manipulation_functions.py::test_squeeze
72-
array_api_tests/test_has_names.py::test_has_names[utility-diff]
7372
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum]
7473
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_prod]
7574
array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis]
7675
array_api_tests/test_has_names.py::test_has_names[searching-count_nonzero]
7776
array_api_tests/test_has_names.py::test_has_names[searching-searchsorted]
78-
array_api_tests/test_signatures.py::test_func_signature[diff]
7977
array_api_tests/test_signatures.py::test_func_signature[take_along_axis]
8078
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
8179
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]

sparse/numba_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
can_cast,
9696
concat,
9797
concatenate,
98+
diff,
9899
dot,
99100
einsum,
100101
empty,
@@ -341,6 +342,7 @@
341342
"repeat",
342343
"tile",
343344
"unstack",
345+
"diff",
344346
]
345347

346348

sparse/numba_backend/_common.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3217,3 +3217,36 @@ def unstack(x, axis=0):
32173217
new_order = (axis,) + tuple(i for i in range(ndim) if i != axis)
32183218
x = x.transpose(new_order)
32193219
return (*x,)
3220+
3221+
3222+
def diff(x, axis=-1, n=1, prepend=None, append=None):
3223+
"""
3224+
Calculates the n-th discrete difference along the given axis.
3225+
3226+
Parameters
3227+
----------
3228+
x : SparseArray
3229+
Input sparse arrays.
3230+
n : int
3231+
The number of times values are differenced. Default: 1.
3232+
axis : int
3233+
The axis along which the difference is taken. Default: -1.
3234+
3235+
Returns
3236+
-------
3237+
out : SparseArray
3238+
An array containing the n-th discrete difference along the given axis.
3239+
"""
3240+
if not isinstance(x, SparseArray):
3241+
raise TypeError("`x` must be a SparseArray.")
3242+
3243+
if axis < 0:
3244+
axis = x.ndim + axis
3245+
if prepend is not None:
3246+
x = concatenate([prepend, x], axis=axis)
3247+
if append is not None:
3248+
x = concatenate([x, append], axis=axis)
3249+
result = x
3250+
for _ in range(n):
3251+
result = result[(slice(None),) * axis + (slice(1, None),)] - result[(slice(None),) * axis + (slice(None, -1),)]
3252+
return result

sparse/numba_backend/tests/test_coo.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,3 +2030,43 @@ def test_unstack_invalid_type():
20302030
a = np.arange(6).reshape(2, 3) # not a sparse array
20312031
with pytest.raises(TypeError, match="must be a SparseArray"):
20322032
sparse.unstack(a, axis=0)
2033+
2034+
2035+
@pytest.mark.parametrize("ndim", range(1, 4))
2036+
@pytest.mark.parametrize("shape_range", [3])
2037+
@pytest.mark.parametrize("n", [1, 2])
2038+
@pytest.mark.parametrize("use_prepend, use_append", [(False, False), (True, False), (False, True), (True, True)])
2039+
def test_diff_matches_numpy(ndim, shape_range, n, use_prepend, use_append):
2040+
rng = np.random.default_rng(42)
2041+
shape = tuple(rng.integers(2, shape_range + 2) for _ in range(ndim))
2042+
x = rng.integers(0, 10, size=shape)
2043+
sparse_x = COO.from_numpy(x)
2044+
2045+
for axis in range(-ndim, ndim):
2046+
prepend = rng.integers(0, 10, size=x.shape).astype(x.dtype) if use_prepend else None
2047+
append = rng.integers(0, 10, size=x.shape).astype(x.dtype) if use_append else None
2048+
2049+
sparse_prepend = COO.from_numpy(prepend) if prepend is not None else None
2050+
sparse_append = COO.from_numpy(append) if append is not None else None
2051+
2052+
sparse_result = sparse.diff(sparse_x, axis=axis, n=n, prepend=sparse_prepend, append=sparse_append)
2053+
2054+
kwargs = {}
2055+
if prepend is not None:
2056+
kwargs["prepend"] = prepend
2057+
if append is not None:
2058+
kwargs["append"] = append
2059+
2060+
dense_result = np.diff(x, axis=axis, n=n, **kwargs)
2061+
2062+
np.testing.assert_array_equal(
2063+
sparse_result.todense(),
2064+
dense_result,
2065+
err_msg=f"Mismatch at axis={axis}, n={n}, prepend={use_prepend}, append={use_append}",
2066+
)
2067+
2068+
2069+
def test_diff_invalid_type():
2070+
a = np.arange(6).reshape(2, 3)
2071+
with pytest.raises(TypeError, match="must be a SparseArray"):
2072+
sparse.diff(a)

sparse/numba_backend/tests/test_namespace.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def test_namespace():
5151
"cosh",
5252
"diagonal",
5353
"diagonalize",
54+
"diff",
5455
"divide",
5556
"dot",
5657
"e",

0 commit comments

Comments
 (0)