@@ -94,10 +94,11 @@ def test_cyto2_to_seg(data_dir, image_names, cellposemodel_fixture_24layer):
9494 clear_output (data_dir , image_names )
9595
9696
97- def test_class_3D_one_img (data_dir , image_names_3d , cellposemodel_fixture_2layer ):
97+ def test_class_3D_one_img_shape (data_dir , image_names_3d , cellposemodel_fixture_2layer ):
9898 clear_output (data_dir , image_names_3d )
9999
100100 img_file = data_dir / '3D' / image_names_3d [0 ]
101+ image_name = img_file .name
101102 img = io .imread_3D (img_file )
102103 masks_pred , flows_pred , _ = cellposemodel_fixture_2layer .eval (img , do_3D = True , channel_axis = - 1 , z_axis = 0 )
103104
@@ -125,12 +126,33 @@ def test_cli_2D(data_dir, image_names):
125126 clear_output (data_dir , image_names )
126127
127128
129+ @pytest .mark .parametrize ('diam, aniso' , [(None , 2.5 ), (25 , 2.5 ), (25 , None )])
128130@pytest .mark .slow
129- def test_cli_3D_diam (data_dir , image_names_3d ):
131+ def test_cli_3D_diam_anisotropy_shape (data_dir , image_names_3d , diam , aniso ):
130132 clear_output (data_dir , image_names_3d )
131- use_gpu = torch .cuda .is_available ()
133+ use_gpu = torch .cuda .is_available () or torch . backends . mps . is_available ()
132134 gpu_string = "--use_gpu" if use_gpu else ""
133- cmd = f"python -m cellpose --image_path { str (data_dir / '3D' / image_names_3d [0 ])} --do_3D --diameter 25 --save_tif { gpu_string } --verbose"
135+ anisotropy_text = f" { '--anisotropy ' + str (aniso ) if aniso else '' } "
136+ diam_text = f" { '--diameter ' + str (diam ) if diam else '' } "
137+ cmd = f"python -m cellpose --image_path { str (data_dir / '3D' / image_names_3d [0 ])} --do_3D --save_tif { gpu_string } --verbose" + anisotropy_text + diam_text
138+ print (cmd )
139+ try :
140+ cmd_stdout = check_output (cmd , stderr = STDOUT , shell = True ).decode ()
141+ print (cmd_stdout )
142+ except Exception as e :
143+ print (e )
144+ raise ValueError (e )
145+ compare_mask_shapes (data_dir , image_names_3d [0 ], "3D" )
146+ clear_output (data_dir , image_names_3d )
147+
148+
149+ @pytest .mark .slow
150+ def test_cli_3D_one_img (data_dir , image_names_3d ):
151+ clear_output (data_dir , image_names_3d )
152+ use_gpu = torch .cuda .is_available () or torch .backends .mps .is_available ()
153+ gpu_string = "--use_gpu" if use_gpu else ""
154+ cmd = f"python -m cellpose --image_path { str (data_dir / '3D' / image_names_3d [0 ])} --do_3D --save_tif { gpu_string } --verbose"
155+ print (cmd )
134156 try :
135157 cmd_stdout = check_output (cmd , stderr = STDOUT , shell = True ).decode ()
136158 print (cmd_stdout )
@@ -178,6 +200,8 @@ def compare_masks_cp4(data_dir, image_names, runtype):
178200 """
179201 data_dir_2D = data_dir .joinpath ("2D" )
180202 data_dir_3D = data_dir .joinpath ("3D" )
203+ if not isinstance (image_names , list ):
204+ image_names = [image_names ]
181205 for image_name in image_names :
182206 check = False
183207 if "2D" in runtype and "2D" in image_name :
@@ -215,3 +239,39 @@ def compare_masks_cp4(data_dir, image_names, runtype):
215239 else :
216240 print ("ERROR: no output file of name %s found" % output_test )
217241 assert False
242+
243+
244+ def compare_mask_shapes (data_dir , image_names , runtype ):
245+ """
246+ Helper function to check if outputs given by a test are exactly the same
247+ as the ground truth outputs.
248+ """
249+ data_dir_2D = data_dir .joinpath ("2D" )
250+ data_dir_3D = data_dir .joinpath ("3D" )
251+ if not isinstance (image_names , list ):
252+ image_names = [image_names ]
253+ for image_name in image_names :
254+ check = False
255+ if "2D" in runtype and "2D" in image_name :
256+ image_file = str (data_dir_2D .joinpath (image_name ))
257+ name = os .path .splitext (image_file )[0 ]
258+ output_test = name + "_cp_masks.png"
259+ output_true = name + "_cp4_gt_masks.png"
260+ check = True
261+ elif "3D" in runtype and "3D" in image_name :
262+ image_file = str (data_dir_3D .joinpath (image_name ))
263+ name = os .path .splitext (image_file )[0 ]
264+ output_test = name + "_cp_masks.tif"
265+ output_true = name + "_cp4_gt_masks.tif"
266+ check = True
267+
268+ if check :
269+ if os .path .exists (output_test ):
270+ print ("checking output %s" % output_test )
271+ masks_test = io .imread (output_test )
272+ masks_true = io .imread (output_true )
273+
274+ assert all ([a == b for a , b in zip (masks_test .shape , masks_true .shape )]), f'mask shape mismatch: { masks_test .shape } =/= { masks_true .shape } '
275+ else :
276+ print ("ERROR: no output file of name %s found" % output_test )
277+ assert False
0 commit comments