Skip to content

Commit 992d7b4

Browse files
authored
Merge pull request #1246 from mario-koddenbrock/fix_aggregated_jaccard_index
Fix: Add missing `_label_overlap` function and integrate into metrics
2 parents 8ef8804 + 8483650 commit 992d7b4

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

cellpose/metrics.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,20 @@ def boundary_scores(masks_true, masks_pred, scales):
5555
return precision, recall, fscore
5656

5757

58+
def _label_overlap(masks_true, masks_pred):
59+
return csr_matrix((np.ones((masks_true.size,), "int"),
60+
(masks_true.flatten(), masks_pred.flatten())),
61+
shape=(masks_true.max() + 1, masks_pred.max() + 1))
62+
63+
5864
def aggregated_jaccard_index(masks_true, masks_pred):
59-
"""
60-
AJI = intersection of all matched masks / union of all masks
61-
65+
"""
66+
AJI = intersection of all matched masks / union of all masks
67+
6268
Args:
63-
masks_true (list of np.ndarrays (int) or np.ndarray (int)):
69+
masks_true (list of np.ndarrays (int) or np.ndarray (int)):
6470
where 0=NO masks; 1,2... are mask labels
65-
masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
71+
masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
6672
np.ndarray (int) where 0=NO masks; 1,2... are mask labels
6773
6874
Returns:
@@ -80,26 +86,26 @@ def aggregated_jaccard_index(masks_true, masks_pred):
8086

8187

8288
def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
83-
"""
89+
"""
8490
Average precision estimation: AP = TP / (TP + FP + FN)
8591
8692
This function is based heavily on the *fast* stardist matching functions
8793
(https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py)
8894
8995
Args:
90-
masks_true (list of np.ndarrays (int) or np.ndarray (int)):
96+
masks_true (list of np.ndarrays (int) or np.ndarray (int)):
9197
where 0=NO masks; 1,2... are mask labels
92-
masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
98+
masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
9399
np.ndarray (int) where 0=NO masks; 1,2... are mask labels
94100
95101
Returns:
96-
ap (array [len(masks_true) x len(threshold)]):
102+
ap (array [len(masks_true) x len(threshold)]):
97103
average precision at thresholds
98-
tp (array [len(masks_true) x len(threshold)]):
104+
tp (array [len(masks_true) x len(threshold)]):
99105
number of true positives at thresholds
100-
fp (array [len(masks_true) x len(threshold)]):
106+
fp (array [len(masks_true) x len(threshold)]):
101107
number of false positives at thresholds
102-
fn (array [len(masks_true) x len(threshold)]):
108+
fn (array [len(masks_true) x len(threshold)]):
103109
number of false negatives at thresholds
104110
"""
105111
not_list = False
@@ -149,7 +155,7 @@ def _intersection_over_union(masks_true, masks_pred):
149155
How it works:
150156
The overlap matrix is a lookup table of the area of intersection
151157
between each set of labels (true and predicted). The true labels
152-
are taken to be along axis 0, and the predicted labels are taken
158+
are taken to be along axis 0, and the predicted labels are taken
153159
to be along axis 1. The sum of the overlaps along axis 0 is thus
154160
an array giving the total overlap of the true labels with each of
155161
the predicted labels, and likewise the sum over axis 1 is the
@@ -159,13 +165,11 @@ def _intersection_over_union(masks_true, masks_pred):
159165
column vectors gives a 2D array with the areas of every label pair
160166
added together. This is equivalent to the union of the label areas
161167
except for the duplicated overlap area, so the overlap matrix is
162-
subtracted to find the union matrix.
168+
subtracted to find the union matrix.
163169
"""
164170
if masks_true.size != masks_pred.size:
165171
raise ValueError("masks_true.size != masks_pred.size")
166-
overlap = csr_matrix((np.ones((masks_true.size,), "int"),
167-
(masks_true.flatten(), masks_pred.flatten())),
168-
shape=(masks_true.max()+1, masks_pred.max()+1))
172+
overlap = _label_overlap(masks_true, masks_pred)
169173
overlap = overlap.toarray()
170174
n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
171175
n_pixels_true = np.sum(overlap, axis=1, keepdims=True)

0 commit comments

Comments
 (0)