@@ -464,19 +464,19 @@ def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
464464class CellposeDenoiseModel ():
465465 """ model to run Cellpose and Image restoration """
466466 def __init__ (self , gpu = False , pretrained_model = False , model_type = None ,
467- restore_type = "denoise_cyto3" , chan2_denoise = False ,
467+ restore_type = "denoise_cyto3" , chan2_restore = False ,
468468 device = None ):
469469
470470 self .dn = DenoiseModel (gpu = gpu , model_type = restore_type ,
471- chan2 = chan2_denoise , device = device )
471+ chan2 = chan2_restore , device = device )
472472 self .cp = CellposeModel (gpu = gpu , model_type = model_type ,
473473 pretrained_model = pretrained_model , device = device )
474474
475475 def eval (self , x , batch_size = 8 , channels = None , channel_axis = None , z_axis = None ,
476476 normalize = True , rescale = None , diameter = None , tile = True , tile_overlap = 0.1 ,
477- resample = True , invert = False , flow_threshold = 0.4 , cellprob_threshold = 0.0 ,
478- do_3D = False , anisotropy = None , stitch_threshold = 0.0 , min_size = 15 ,
479- niter = None , interp = True ):
477+ augment = False , resample = True , invert = False , flow_threshold = 0.4 ,
478+ cellprob_threshold = 0.0 , do_3D = False , anisotropy = None , stitch_threshold = 0.0 ,
479+ min_size = 15 , niter = None , interp = True ):
480480 """
481481 Restore array or list of images using the image restoration model, and then segment.
482482
@@ -510,6 +510,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
510510 if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
511511 tile (bool, optional): tiles image to ensure GPU/CPU memory usage limited (recommended). Defaults to True.
512512 tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
513+ augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False.
513514 resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
514515 invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
515516 flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
@@ -549,7 +550,8 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
549550 diameter = self .dn .diam_mean if self .dn .ratio > 1 else diameter
550551 masks , flows , styles = self .cp .eval (img_restore , batch_size = batch_size , channels = channels_new , channel_axis = - 1 ,
551552 normalize = normalize_params , rescale = rescale , diameter = diameter ,
552- tile = tile , tile_overlap = tile_overlap , resample = resample , invert = invert ,
553+ tile = tile , tile_overlap = tile_overlap , augment = augment ,
554+ resample = resample , invert = invert ,
553555 flow_threshold = flow_threshold , cellprob_threshold = cellprob_threshold ,
554556 do_3D = do_3D , anisotropy = anisotropy , stitch_threshold = stitch_threshold ,
555557 min_size = min_size , niter = niter , interp = interp )
@@ -644,7 +646,7 @@ def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
644646 )
645647 if chan2 and builtin :
646648 chan2_path = model_path (os .path .split (self .pretrained_model )[- 1 ].split ("_" )[0 ] + "_nuclei" )
647- print (f"loading model for chan2: { os .path .split (str (chan2_path )[- 1 ]) } " )
649+ print (f"loading model for chan2: { os .path .split (str (chan2_path )) [- 1 ]} " )
648650 self .net_chan2 = CPnet (self .nbase , self .nclasses , sz = 3 ,
649651 mkldnn = self .mkldnn , max_pool = True ,
650652 diam_mean = 17. ).to (self .device )
0 commit comments