Skip to content

Commit 9d59b2e

Browse files
authored
Merge pull request #1193 from MouseLand/ansiotropy_fix
Ansiotropy fix
2 parents 266fbe0 + 0ed9592 commit 9d59b2e

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

cellpose/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,8 @@ def eval(self, x, batch_size=8, resample=None, channels=None, channel_axis=None,
339339

340340
masks, dP, cellprob = masks.squeeze(), dP.squeeze(), cellprob.squeeze()
341341

342-
# undo diameter resizing:
343-
if image_scaling is not None:
342+
# undo resizing:
343+
if image_scaling is not None or anisotropy is not None:
344344
if do_3D:
345345
# Rescale xy then xz:
346346
masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)

tests/test_output.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,11 @@ def test_cyto2_to_seg(data_dir, image_names, cellposemodel_fixture_24layer):
9494
clear_output(data_dir, image_names)
9595

9696

97-
def test_class_3D_one_img(data_dir, image_names_3d, cellposemodel_fixture_2layer):
97+
def test_class_3D_one_img_shape(data_dir, image_names_3d, cellposemodel_fixture_2layer):
9898
clear_output(data_dir, image_names_3d)
9999

100100
img_file = data_dir / '3D' / image_names_3d[0]
101+
image_name = img_file.name
101102
img = io.imread_3D(img_file)
102103
masks_pred, flows_pred, _ = cellposemodel_fixture_2layer.eval(img, do_3D=True, channel_axis=-1, z_axis=0)
103104

@@ -125,12 +126,33 @@ def test_cli_2D(data_dir, image_names):
125126
clear_output(data_dir, image_names)
126127

127128

129+
@pytest.mark.parametrize('diam, aniso', [(None, 2.5), (25, 2.5), (25, None)])
128130
@pytest.mark.slow
129-
def test_cli_3D_diam(data_dir, image_names_3d):
131+
def test_cli_3D_diam_anisotropy_shape(data_dir, image_names_3d, diam, aniso):
130132
clear_output(data_dir, image_names_3d)
131-
use_gpu = torch.cuda.is_available()
133+
use_gpu = torch.cuda.is_available() or torch.backends.mps.is_available()
132134
gpu_string = "--use_gpu" if use_gpu else ""
133-
cmd = f"python -m cellpose --image_path {str(data_dir / '3D' / image_names_3d[0])} --do_3D --diameter 25 --save_tif {gpu_string} --verbose"
135+
anisotropy_text = f" {'--anisotropy ' + str(aniso) if aniso else ''}"
136+
diam_text = f" {'--diameter ' + str(diam) if diam else ''}"
137+
cmd = f"python -m cellpose --image_path {str(data_dir / '3D' / image_names_3d[0])} --do_3D --save_tif {gpu_string} --verbose" + anisotropy_text + diam_text
138+
print(cmd)
139+
try:
140+
cmd_stdout = check_output(cmd, stderr=STDOUT, shell=True).decode()
141+
print(cmd_stdout)
142+
except Exception as e:
143+
print(e)
144+
raise ValueError(e)
145+
compare_mask_shapes(data_dir, image_names_3d[0], "3D")
146+
clear_output(data_dir, image_names_3d)
147+
148+
149+
@pytest.mark.slow
150+
def test_cli_3D_one_img(data_dir, image_names_3d):
151+
clear_output(data_dir, image_names_3d)
152+
use_gpu = torch.cuda.is_available() or torch.backends.mps.is_available()
153+
gpu_string = "--use_gpu" if use_gpu else ""
154+
cmd = f"python -m cellpose --image_path {str(data_dir / '3D' / image_names_3d[0])} --do_3D --save_tif {gpu_string} --verbose"
155+
print(cmd)
134156
try:
135157
cmd_stdout = check_output(cmd, stderr=STDOUT, shell=True).decode()
136158
print(cmd_stdout)
@@ -178,6 +200,8 @@ def compare_masks_cp4(data_dir, image_names, runtype):
178200
"""
179201
data_dir_2D = data_dir.joinpath("2D")
180202
data_dir_3D = data_dir.joinpath("3D")
203+
if not isinstance(image_names, list):
204+
image_names = [image_names]
181205
for image_name in image_names:
182206
check = False
183207
if "2D" in runtype and "2D" in image_name:
@@ -215,3 +239,39 @@ def compare_masks_cp4(data_dir, image_names, runtype):
215239
else:
216240
print("ERROR: no output file of name %s found" % output_test)
217241
assert False
242+
243+
244+
def compare_mask_shapes(data_dir, image_names, runtype):
245+
"""
246+
Helper function to check if outputs given by a test are exactly the same
247+
as the ground truth outputs.
248+
"""
249+
data_dir_2D = data_dir.joinpath("2D")
250+
data_dir_3D = data_dir.joinpath("3D")
251+
if not isinstance(image_names, list):
252+
image_names = [image_names]
253+
for image_name in image_names:
254+
check = False
255+
if "2D" in runtype and "2D" in image_name:
256+
image_file = str(data_dir_2D.joinpath(image_name))
257+
name = os.path.splitext(image_file)[0]
258+
output_test = name + "_cp_masks.png"
259+
output_true = name + "_cp4_gt_masks.png"
260+
check = True
261+
elif "3D" in runtype and "3D" in image_name:
262+
image_file = str(data_dir_3D.joinpath(image_name))
263+
name = os.path.splitext(image_file)[0]
264+
output_test = name + "_cp_masks.tif"
265+
output_true = name + "_cp4_gt_masks.tif"
266+
check = True
267+
268+
if check:
269+
if os.path.exists(output_test):
270+
print("checking output %s" % output_test)
271+
masks_test = io.imread(output_test)
272+
masks_true = io.imread(output_true)
273+
274+
assert all([a == b for a, b in zip(masks_test.shape, masks_true.shape)]), f'mask shape mismatch: {masks_test.shape} =/= {masks_true.shape}'
275+
else:
276+
print("ERROR: no output file of name %s found" % output_test)
277+
assert False

0 commit comments

Comments
 (0)