Skip to content

Commit 93b5071

Browse files
authored
Adding tile function (#880)
1 parent df97f54 commit 93b5071

File tree

5 files changed

+67
-3
lines changed

5 files changed

+67
-3
lines changed

ci/Numba-array-api-xfails.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,13 @@ array_api_tests/test_creation_functions.py::test_empty_like
7070
array_api_tests/test_data_type_functions.py::test_finfo[complex64]
7171
array_api_tests/test_manipulation_functions.py::test_squeeze
7272
array_api_tests/test_has_names.py::test_has_names[utility-diff]
73-
array_api_tests/test_has_names.py::test_has_names[manipulation-tile]
7473
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
7574
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum]
7675
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_prod]
7776
array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis]
7877
array_api_tests/test_has_names.py::test_has_names[searching-count_nonzero]
7978
array_api_tests/test_has_names.py::test_has_names[searching-searchsorted]
8079
array_api_tests/test_signatures.py::test_func_signature[diff]
81-
array_api_tests/test_signatures.py::test_func_signature[tile]
8280
array_api_tests/test_signatures.py::test_func_signature[unstack]
8381
array_api_tests/test_signatures.py::test_func_signature[take_along_axis]
8482
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
@@ -107,7 +105,6 @@ array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[1]
107105
array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[None]
108106
array_api_tests/test_searching_functions.py::test_count_nonzero
109107
array_api_tests/test_searching_functions.py::test_searchsorted
110-
array_api_tests/test_manipulation_functions.py::test_tile
111108
array_api_tests/test_signatures.py::test_func_signature[cumulative_sum]
112109
array_api_tests/test_signatures.py::test_func_signature[cumulative_prod]
113110
array_api_tests/test_manipulation_functions.py::test_unstack

sparse/numba_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
std,
128128
sum,
129129
tensordot,
130+
tile,
130131
var,
131132
vecdot,
132133
zeros,
@@ -337,6 +338,7 @@
337338
"zeros",
338339
"zeros_like",
339340
"repeat",
341+
"tile",
340342
]
341343

342344

sparse/numba_backend/_common.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3147,3 +3147,42 @@ def repeat(a, repeats, axis=None):
31473147
if not axis_is_none:
31483148
return a.reshape(new_shape)
31493149
return a.reshape(new_shape).flatten()
3150+
3151+
3152+
def tile(a, reps):
3153+
"""
3154+
Constructs an array by tiling an input array.
3155+
3156+
Parameters
3157+
----------
3158+
a : SparseArray
3159+
Input sparse arrays.
3160+
reps : int or tuple[int, ...]
3161+
The number of repetitions for each dimension.
3162+
If an integer, the same number of repetitions is applied to all dimensions.
3163+
3164+
Returns
3165+
-------
3166+
out : SparseArray
3167+
A tiled output array.
3168+
"""
3169+
if not isinstance(a, SparseArray):
3170+
a = as_coo(a)
3171+
3172+
if isinstance(reps, int):
3173+
reps = (reps,)
3174+
reps = tuple(reps)
3175+
3176+
if a.ndim == 0:
3177+
a = a.reshape((1,))
3178+
3179+
if len(reps) < a.ndim:
3180+
reps = (1,) * (a.ndim - len(reps)) + reps
3181+
elif len(reps) > a.ndim:
3182+
a = a.reshape((1,) * (len(reps) - a.ndim) + a.shape)
3183+
3184+
shape = a.shape
3185+
ndim = len(reps)
3186+
a = a.reshape(tuple(np.column_stack(([1] * ndim, shape)).reshape(-1)))
3187+
a = a.broadcast_to(tuple(np.column_stack((reps, shape)).reshape(-1)))
3188+
return a.reshape(tuple(np.multiply(reps, shape)))

sparse/numba_backend/tests/test_coo.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1956,3 +1956,28 @@ def test_repeat(ndim, repeats):
19561956
print(f"Expected: {expected}, Actual: {actual}")
19571957
assert actual.shape == expected.shape
19581958
np.testing.assert_array_equal(actual, expected)
1959+
1960+
1961+
def test_tile_invalid_input():
1962+
a = np.eye(3)
1963+
assert isinstance(sparse.tile(a, 2), sparse.COO)
1964+
1965+
1966+
@pytest.mark.parametrize(
1967+
"arr,reps",
1968+
[
1969+
(np.array([1, 2, 3]), (3,)),
1970+
(np.array([4, 5, 6, 7]), 3),
1971+
(np.array(1), 3),
1972+
(np.array([[1, 2], [3, 4]]), (2, 2)),
1973+
(np.array([[[1], [2]], [[3], [4]]]), (2, 1, 2)),
1974+
(np.random.default_rng(0).integers(0, 10, (2, 1, 3)), (2, 2, 2)),
1975+
(np.random.default_rng(1).integers(0, 5, (1, 3, 1, 2)), (2, 1, 3, 1)),
1976+
],
1977+
)
1978+
def test_tile(arr, reps):
1979+
sparse_arr = sparse.COO.from_numpy(arr)
1980+
expected = np.tile(arr, reps)
1981+
result = sparse.tile(sparse_arr, reps).todense()
1982+
1983+
np.testing.assert_array_equal(result, expected, err_msg=f"Mismatch for shape={arr.shape}, reps={reps}")

sparse/numba_backend/tests/test_namespace.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def test_namespace():
156156
"tan",
157157
"tanh",
158158
"tensordot",
159+
"tile",
159160
"tril",
160161
"triu",
161162
"trunc",

0 commit comments

Comments
 (0)