Skip to content

Commit 90c5cd1

Browse files
fix(remap): in_place always making a copy (#52)
* fix(remap): in_place=True would always make a copy * test: check if in_place actually changes underlying array * docs: use more natural cython binding for remap, renumber for help
1 parent 623eef1 commit 90c5cd1

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

automated_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,19 @@ def test_remap_broken():
213213
res = fastremap.remap(labels, {5:5}, preserve_missing_labels=True)
214214
assert np.all(res == labels)
215215

216+
def test_remap_in_place_broken():
217+
data = np.array([0])
218+
result = fastremap.remap(data, {0: 1}, in_place=True)
219+
assert result[0] == 1
220+
assert data[0] == 1
221+
222+
def test_renumber_in_place_broken():
223+
data = np.array([5])
224+
result, mapping = fastremap.renumber(data, in_place=True)
225+
assert result[0] == 1
226+
assert data[0] == 1
227+
assert mapping == { 0: 0, 5: 1 }
228+
216229
@pytest.mark.parametrize("dtype", DTYPES)
217230
@pytest.mark.parametrize("in_place", [ True, False ])
218231
def test_mask(dtype, in_place):

fastremap/fastremap.pyx

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,9 @@ def indices(cnp.ndarray[NUMBER, cast=True, ndim=1] arr, NUMBER value):
118118

119119
return np.asarray(all_indices, dtype=np.uint64)
120120

121+
@cython.binding(True)
121122
def renumber(arr, start=1, preserve_zero=True, in_place=False):
122123
"""
123-
renumber(arr, start=1, preserve_zero=True, in_place=False)
124-
125124
Given an array of integers, renumber all the unique values starting
126125
from 1. This can allow us to reduce the size of the data width required
127126
to represent it.
@@ -652,11 +651,9 @@ def _inverse_component_map(
652651

653652
return remap
654653

654+
@cython.binding(True)
655655
def remap(arr, table, preserve_missing_labels=False, in_place=False):
656656
"""
657-
remap(cnp.ndarray[COMPLEX_NUMBER] arr, dict table,
658-
preserve_missing_labels=False, in_place=False)
659-
660657
Remap an input numpy array in-place according to the values in the given
661658
dictionary "table".
662659
@@ -687,10 +684,12 @@ def remap(arr, table, preserve_missing_labels=False, in_place=False):
687684
arr = refit(arr, fit_value, increase_only=True)
688685

689686
make_copy = (
690-
(not in_place)
691-
or (original_dtype == arr.dtype) # avoid two copies b/c copied in refit
692-
or (not arr.flags.writeable)
693-
or not (arr.flags.f_contiguous or arr.flags.c_contiguous)
687+
(
688+
(not in_place)
689+
or (not arr.flags.writeable)
690+
or not (arr.flags.f_contiguous or arr.flags.c_contiguous)
691+
)
692+
and (original_dtype == arr.dtype) # avoid two copies b/c copied in refit if dtype doesn't match
694693
)
695694

696695
if make_copy:

0 commit comments

Comments
 (0)