Skip to content

Commit 83f00ce

Browse files
Merge pull request #875 from MouseLand/cli_restore
adding CLI for restore
2 parents d4857f5 + 307bcf1 commit 83f00ce

File tree

13 files changed

+621
-253
lines changed

13 files changed

+621
-253
lines changed

cellpose/__main__.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
from natsort import natsorted
88
from tqdm import tqdm
9-
from cellpose import utils, models, io, version_str, train
9+
from cellpose import utils, models, io, version_str, train, denoise
1010
from cellpose.cli import get_arg_parser
1111

1212
try:
@@ -90,9 +90,18 @@ def main():
9090
else:
9191
pretrained_model = args.pretrained_model
9292

93+
restore_type = args.restore_type
94+
if restore_type is not None:
95+
try:
96+
denoise.model_path(restore_type)
97+
except Exception as e:
98+
raise ValueError("restore_type invalid")
99+
if args.train or args.train_size:
100+
raise ValueError("restore_type cannot be used with training on CLI yet")
101+
93102
model_type = None
94103
if pretrained_model and not os.path.exists(pretrained_model):
95-
model_type = pretrained_model if pretrained_model is not None else "cyto"
104+
model_type = pretrained_model if pretrained_model is not None else "cyto3"
96105
model_strings = models.get_user_models()
97106
all_models = models.MODEL_NAMES.copy()
98107
all_models.extend(model_strings)
@@ -127,26 +136,39 @@ def main():
127136
">>>> running cellpose on %d images using chan_to_seg %s and chan (opt) %s"
128137
% (nimg, cstr0[channels[0]], cstr1[channels[1]]))
129138

130-
# handle built-in model exceptions; bacterial ones get no size model
131-
if builtin_size:
139+
# handle built-in model exceptions
140+
if builtin_size and restore_type is None:
132141
model = models.Cellpose(gpu=gpu, device=device, model_type=model_type)
133142
else:
143+
builtin_size = False
134144
if args.all_channels:
135145
channels = None
136146
pretrained_model = None if model_type is not None else pretrained_model
137-
model = models.CellposeModel(gpu=gpu, device=device,
138-
pretrained_model=pretrained_model,
139-
model_type=model_type)
147+
if restore_type is None:
148+
model = models.CellposeModel(gpu=gpu, device=device,
149+
pretrained_model=pretrained_model,
150+
model_type=model_type)
151+
else:
152+
model = denoise.CellposeDenoiseModel(gpu=gpu, device=device,
153+
pretrained_model=pretrained_model,
154+
model_type=model_type,
155+
restore_type=restore_type,
156+
chan2_restore=args.chan2_restore)
140157

141158
# handle diameters
142159
if args.diameter == 0:
143160
if builtin_size:
144161
diameter = None
145162
logger.info(">>>> estimating diameter for each image")
146163
else:
147-
logger.info(
148-
">>>> not using cyto, cyto2, or nuclei model, cannot auto-estimate diameter"
149-
)
164+
if restore_type is None:
165+
logger.info(
166+
">>>> not using cyto3, cyto, cyto2, or nuclei model, cannot auto-estimate diameter"
167+
)
168+
else:
169+
logger.info(
170+
">>>> cannot auto-estimate diameter for image restoration"
171+
)
150172
diameter = model.diam_labels
151173
logger.info(">>>> using diameter %0.3f for all images" % diameter)
152174
else:
@@ -168,17 +190,26 @@ def main():
168190
channel_axis=args.channel_axis, z_axis=args.z_axis,
169191
anisotropy=args.anisotropy, niter=args.niter)
170192
masks, flows = out[:2]
171-
if len(out) > 3:
193+
if len(out) > 3 and restore_type is None:
172194
diams = out[-1]
173195
else:
174196
diams = diameter
197+
ratio = 1.
198+
if restore_type is not None:
199+
imgs_dn = out[-1]
200+
ratio = diams / model.dn.diam_mean if "upsample" in restore_type else 1.
201+
diams = model.dn.diam_mean if "upsample" in restore_type and model.dn.diam_mean > diams else diams
202+
else:
203+
imgs_dn = None
175204
if args.exclude_on_edges:
176205
masks = utils.remove_edge_masks(masks)
177206
if not args.no_npy:
178-
io.masks_flows_to_seg(image, masks, flows, image_name,
179-
channels=channels, diams=diams)
207+
io.masks_flows_to_seg(image, masks, flows, image_name, imgs_restore=imgs_dn,
208+
channels=channels, diams=diams,
209+
restore_type=restore_type, ratio=1.)
180210
if saving_something:
181-
io.save_masks(image, masks, flows, image_name, png=args.save_png,
211+
io.save_masks(image, masks, flows, image_name,
212+
png=args.save_png,
182213
tif=args.save_tif, save_flows=args.save_flows,
183214
save_outlines=args.save_outlines,
184215
dir_above=args.dir_above, savedir=args.savedir,

cellpose/cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ def get_arg_parser():
6969
model_args.add_argument("--pretrained_model", required=False, default="cyto",
7070
type=str,
7171
help="model to use for running or starting training")
72+
model_args.add_argument("--restore_type", required=False, default=None,
73+
type=str,
74+
help="model to use for image restoration")
75+
model_args.add_argument("--chan2_restore", action="store_true",
76+
help="use nuclei restore model for second channel")
7277
model_args.add_argument(
7378
"--add_model", required=False, default=None, type=str,
7479
help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")

cellpose/denoise.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -464,19 +464,19 @@ def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
464464
class CellposeDenoiseModel():
465465
""" model to run Cellpose and Image restoration """
466466
def __init__(self, gpu=False, pretrained_model=False, model_type=None,
467-
restore_type="denoise_cyto3", chan2_denoise=False,
467+
restore_type="denoise_cyto3", chan2_restore=False,
468468
device=None):
469469

470470
self.dn = DenoiseModel(gpu=gpu, model_type=restore_type,
471-
chan2=chan2_denoise, device=device)
471+
chan2=chan2_restore, device=device)
472472
self.cp = CellposeModel(gpu=gpu, model_type=model_type,
473473
pretrained_model=pretrained_model, device=device)
474474

475475
def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
476476
normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1,
477-
resample=True, invert=False, flow_threshold=0.4, cellprob_threshold=0.0,
478-
do_3D=False, anisotropy=None, stitch_threshold=0.0, min_size=15,
479-
niter=None, interp=True):
477+
augment=False, resample=True, invert=False, flow_threshold=0.4,
478+
cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0,
479+
min_size=15, niter=None, interp=True):
480480
"""
481481
Restore array or list of images using the image restoration model, and then segment.
482482
@@ -510,6 +510,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
510510
if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
511511
tile (bool, optional): tiles image to ensure GPU/CPU memory usage limited (recommended). Defaults to True.
512512
tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
513+
augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False.
513514
resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
514515
invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
515516
flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
@@ -549,7 +550,8 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
549550
diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
550551
masks, flows, styles = self.cp.eval(img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1,
551552
normalize=normalize_params, rescale=rescale, diameter=diameter,
552-
tile=tile, tile_overlap=tile_overlap, resample=resample, invert=invert,
553+
tile=tile, tile_overlap=tile_overlap, augment=augment,
554+
resample=resample, invert=invert,
553555
flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold,
554556
do_3D=do_3D, anisotropy=anisotropy, stitch_threshold=stitch_threshold,
555557
min_size=min_size, niter=niter, interp=interp)
@@ -644,7 +646,7 @@ def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
644646
)
645647
if chan2 and builtin:
646648
chan2_path = model_path(os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei")
647-
print(f"loading model for chan2: {os.path.split(str(chan2_path)[-1])}")
649+
print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}")
648650
self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3,
649651
mkldnn=self.mkldnn, max_pool=True,
650652
diam_mean=17.).to(self.device)

0 commit comments

Comments
 (0)