@@ -55,14 +55,20 @@ def boundary_scores(masks_true, masks_pred, scales):
5555 return precision , recall , fscore
5656
5757
58+ def _label_overlap (masks_true , masks_pred ):
59+ return csr_matrix ((np .ones ((masks_true .size ,), "int" ),
60+ (masks_true .flatten (), masks_pred .flatten ())),
61+ shape = (masks_true .max () + 1 , masks_pred .max () + 1 ))
62+
63+
5864def aggregated_jaccard_index (masks_true , masks_pred ):
59- """
60- AJI = intersection of all matched masks / union of all masks
61-
65+ """
66+ AJI = intersection of all matched masks / union of all masks
67+
6268 Args:
63- masks_true (list of np.ndarrays (int) or np.ndarray (int)):
69+ masks_true (list of np.ndarrays (int) or np.ndarray (int)):
6470 where 0=NO masks; 1,2... are mask labels
65- masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
71+ masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
6672 np.ndarray (int) where 0=NO masks; 1,2... are mask labels
6773
6874 Returns:
@@ -80,26 +86,26 @@ def aggregated_jaccard_index(masks_true, masks_pred):
8086
8187
8288def average_precision (masks_true , masks_pred , threshold = [0.5 , 0.75 , 0.9 ]):
83- """
89+ """
8490 Average precision estimation: AP = TP / (TP + FP + FN)
8591
8692 This function is based heavily on the *fast* stardist matching functions
8793 (https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py)
8894
8995 Args:
90- masks_true (list of np.ndarrays (int) or np.ndarray (int)):
96+ masks_true (list of np.ndarrays (int) or np.ndarray (int)):
9197 where 0=NO masks; 1,2... are mask labels
92- masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
98+ masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
9399 np.ndarray (int) where 0=NO masks; 1,2... are mask labels
94100
95101 Returns:
96- ap (array [len(masks_true) x len(threshold)]):
102+ ap (array [len(masks_true) x len(threshold)]):
97103 average precision at thresholds
98- tp (array [len(masks_true) x len(threshold)]):
104+ tp (array [len(masks_true) x len(threshold)]):
99105 number of true positives at thresholds
100- fp (array [len(masks_true) x len(threshold)]):
106+ fp (array [len(masks_true) x len(threshold)]):
101107 number of false positives at thresholds
102- fn (array [len(masks_true) x len(threshold)]):
108+ fn (array [len(masks_true) x len(threshold)]):
103109 number of false negatives at thresholds
104110 """
105111 not_list = False
@@ -149,7 +155,7 @@ def _intersection_over_union(masks_true, masks_pred):
149155 How it works:
150156 The overlap matrix is a lookup table of the area of intersection
151157 between each set of labels (true and predicted). The true labels
152- are taken to be along axis 0, and the predicted labels are taken
158+ are taken to be along axis 0, and the predicted labels are taken
153159 to be along axis 1. The sum of the overlaps along axis 0 is thus
154160 an array giving the total overlap of the true labels with each of
155161 the predicted labels, and likewise the sum over axis 1 is the
@@ -159,13 +165,11 @@ def _intersection_over_union(masks_true, masks_pred):
159165 column vectors gives a 2D array with the areas of every label pair
160166 added together. This is equivalent to the union of the label areas
161167 except for the duplicated overlap area, so the overlap matrix is
162- subtracted to find the union matrix.
168+ subtracted to find the union matrix.
163169 """
164170 if masks_true .size != masks_pred .size :
165171 raise ValueError ("masks_true.size != masks_pred.size" )
166- overlap = csr_matrix ((np .ones ((masks_true .size ,), "int" ),
167- (masks_true .flatten (), masks_pred .flatten ())),
168- shape = (masks_true .max ()+ 1 , masks_pred .max ()+ 1 ))
172+ overlap = _label_overlap (masks_true , masks_pred )
169173 overlap = overlap .toarray ()
170174 n_pixels_pred = np .sum (overlap , axis = 0 , keepdims = True )
171175 n_pixels_true = np .sum (overlap , axis = 1 , keepdims = True )
0 commit comments