Skip to content

Commit 6578b4e

Browse files
committed
initial version of filter_segments_by_distance
1 parent 0bd6087 commit 6578b4e

File tree

4 files changed

+336
-1
lines changed

4 files changed

+336
-1
lines changed

docs/detection/utils/masks.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,9 @@ status: new
2222
</div>
2323

2424
:::supervision.detection.utils.masks.contains_multiple_segments
25+
26+
<div class="md-typeset">
27+
<h2><a href="#supervision.detection.utils.masks.filter_segments_by_distance">filter_segments_by_distance</a></h2>
28+
</div>
29+
30+
:::supervision.detection.utils.masks.filter_segments_by_distance

supervision/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
contains_holes,
8989
contains_multiple_segments,
9090
move_masks,
91+
filter_segments_by_distance,
9192
)
9293
from supervision.detection.utils.polygons import (
9394
approximate_polygon,
@@ -219,6 +220,7 @@
219220
"draw_text",
220221
"edit_distance",
221222
"filter_polygons_by_area",
223+
"filter_segments_by_distance",
222224
"fuzzy_match_index",
223225
"get_coco_class_index_mapping",
224226
"get_polygon_center",

supervision/detection/utils/masks.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Literal
4+
35
import cv2
46
import numpy as np
57
import numpy.typing as npt
@@ -260,3 +262,138 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
260262
resized_masks = masks[:, yv, xv]
261263

262264
return resized_masks.reshape(masks.shape[0], new_height, new_width)
265+
266+
267+
def filter_segments_by_distance(
    mask: npt.NDArray[np.bool_],
    absolute_distance: float | None = 100.0,
    relative_distance: float | None = None,
    connectivity: int = 8,
    mode: Literal["edge", "centroid"] = "edge",
) -> npt.NDArray[np.bool_]:
    """
    Keep the largest connected component and any other component that lies
    within a distance threshold of it.

    The threshold is either absolute (in pixels) or relative to the image
    diagonal. Components farther away than the threshold are removed.

    Args:
        mask: Boolean mask of shape `(H, W)`.
        absolute_distance: Max allowed distance in pixels to the main component.
            Ignored if `relative_distance` is provided.
        relative_distance: Fraction of the image diagonal. If set,
            `threshold = relative_distance * sqrt(H^2 + W^2)`.
        connectivity: Defines which neighboring pixels are considered connected.
            - 4-connectedness: Only orthogonal neighbors.
                ```
                [ ][X][ ]
                [X][O][X]
                [ ][X][ ]
                ```
            - 8-connectedness: Includes diagonal neighbors.
                ```
                [X][X][X]
                [X][O][X]
                [X][X][X]
                ```
            Default is 8.
        mode: Defines how distance between components is measured.
            - "edge": Euclidean distance between the closest pair of pixels
              (pixel-center to pixel-center), one from each component. Two
              components separated by a gap of `n` background pixels along
              an axis are therefore `n + 1` pixels apart.
            - "centroid": Euclidean distance between component centroids.

    Returns:
        Boolean mask of the same shape as `mask`, containing the main component
        and every component within the threshold. A new array is always
        returned; the input is never modified.

    Raises:
        TypeError: If `mask` is not of boolean dtype.
        ValueError: If `mask` is not 2-dimensional, if `mode` is not `"edge"`
            or `"centroid"`, or if both `absolute_distance` and
            `relative_distance` are `None`.

    Examples:
        ```python
        import numpy as np
        import supervision as sv

        mask = np.array([
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ], dtype=bool)

        sv.filter_segments_by_distance(
            mask,
            absolute_distance=3,
            mode="edge",
            connectivity=8
        ).astype(int)

        # np.array([
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        # ])

        # The nearby 2x2 block at columns 6-7 is kept: its closest pixel is
        # 3 pixels away from the main component (column 3 -> column 6).
        # The distant block at columns 9-10 is ~7.8 pixels away and is removed.
        ```
    """
    # Validate every argument up front so that mistakes are reported even when
    # the mask happens to be empty or to contain a single component.
    if mask.dtype != bool:
        raise TypeError("mask must be boolean")
    if mask.ndim != 2:
        raise ValueError(f"mask must be 2-dimensional (HxW), got shape {mask.shape}")
    if mode not in ("edge", "centroid"):
        raise ValueError("mode must be 'edge' or 'centroid'")

    height, width = mask.shape
    if relative_distance is not None:
        # The relative threshold takes precedence over the absolute one.
        threshold = float(relative_distance) * float(np.hypot(height, width))
    elif absolute_distance is not None:
        threshold = float(absolute_distance)
    else:
        raise ValueError(
            "either absolute_distance or relative_distance must be provided"
        )

    if not np.any(mask):
        return mask.copy()

    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        mask.astype(np.uint8), connectivity=connectivity
    )

    # Label 0 is the background, so two labels mean a single foreground
    # component: there is nothing to filter.
    if num_labels <= 2:
        return mask.copy()

    # Skip the background row when searching for the largest component.
    main_label = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))

    keep_labels = np.zeros(num_labels, dtype=bool)
    keep_labels[main_label] = True

    if mode == "centroid":
        offsets = centroids[1:] - centroids[main_label]
        distances = np.sqrt(np.sum(offsets**2, axis=1))
        keep_labels[1:] |= distances <= threshold
    else:
        # distanceTransform gives, for every non-zero pixel, the distance to the
        # nearest zero pixel. Zeroing only the main component therefore yields
        # each pixel's distance to the main component. DIST_MASK_PRECISE selects
        # the exact Euclidean transform; a 3x3 mask only approximates it and
        # underestimates distances, which would keep components that are
        # slightly beyond the threshold.
        outside_main = (labels != main_label).astype(np.uint8)
        distance_to_main = cv2.distanceTransform(
            outside_main, cv2.DIST_L2, cv2.DIST_MASK_PRECISE
        )
        # A component is kept when at least one of its pixels is within reach;
        # a single vectorized pass replaces one full-image scan per component.
        within_reach = (distance_to_main <= threshold) & (labels > 0)
        keep_labels[np.unique(labels[within_reach])] = True

    return keep_labels[labels]

test/detection/utils/test_masks.py

Lines changed: 191 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
calculate_masks_centroids,
1111
contains_holes,
1212
contains_multiple_segments,
13-
move_masks,
13+
move_masks, filter_segments_by_distance,
1414
)
1515

1616

@@ -500,3 +500,193 @@ def test_contains_multiple_segments(
500500
with exception:
501501
result = contains_multiple_segments(mask=mask, connectivity=connectivity)
502502
assert result == expected_result
503+
504+
505+
@pytest.mark.parametrize(
    "mask, connectivity, mode, absolute_distance, relative_distance, "
    "expected_result, exception",
    [
        # single component, unchanged
        (
            np.array([
                [0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ], dtype=bool),
            8,
            "edge",
            2.0,
            None,
            np.array([
                [0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 0, 0],
                [0, 1, 1, 1, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ], dtype=bool),
            DoesNotRaise(),
        ),
        # two components, edge distance 2 (one-pixel gap), kept with abs=2
        (
            np.array([
                [0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 1],
                [0, 1, 1, 1, 0, 1],
                [0, 1, 1, 1, 0, 1],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ], dtype=bool),
            8,
            "edge",
            2.0,
            None,
            np.array([
                [0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 1],
                [0, 1, 1, 1, 0, 1],
                [0, 1, 1, 1, 0, 1],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ], dtype=bool),
            DoesNotRaise(),
        ),
        # centroid mode, far centroids, dropped with small relative threshold
        (
            np.array([
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 1, 1],
                [0, 0, 0, 1, 1, 1],
            ], dtype=bool),
            8,
            "centroid",
            None,
            0.3,  # diagonal ~8.49, threshold ~2.55, centroid gap ~4.61
            np.array([
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ], dtype=bool),
            DoesNotRaise(),
        ),
        # centroid mode, larger relative threshold, kept
        (
            np.array([
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 1, 1],
                [0, 0, 0, 1, 1, 1],
            ], dtype=bool),
            8,
            "centroid",
            None,
            0.6,  # diagonal ~8.49, threshold ~5.09, centroid gap ~4.61
            np.array([
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 1, 1],
                [0, 0, 0, 1, 1, 1],
            ], dtype=bool),
            DoesNotRaise(),
        ),
        # empty mask
        (
            np.zeros((4, 4), dtype=bool),
            4,
            "edge",
            2.0,
            None,
            np.zeros((4, 4), dtype=bool),
            DoesNotRaise(),
        ),
        # full mask
        (
            np.ones((4, 4), dtype=bool),
            8,
            "centroid",
            None,
            0.2,
            np.ones((4, 4), dtype=bool),
            DoesNotRaise(),
        ),
        # two equally sized components, edge distance 2 (one-pixel gap),
        # kept with abs=2
        (
            np.array([
                [0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 1, 1, 1],
                [0, 1, 1, 1, 0, 1, 1, 1],
                [0, 1, 1, 1, 0, 1, 1, 1],
                [0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0],
            ], dtype=bool),
            8,
            "edge",
            2.0,
            None,
            np.array([
                [0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 1, 1, 1],
                [0, 1, 1, 1, 0, 1, 1, 1],
                [0, 1, 1, 1, 0, 1, 1, 1],
                [0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0],
            ], dtype=bool),
            DoesNotRaise(),
        ),
        # two components, edge distance 4 (three-pixel gap), dropped with abs=2
        (
            np.array([
                [0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 0, 0, 1, 1],
                [0, 1, 1, 1, 0, 0, 0, 1, 1],
                [0, 1, 1, 1, 0, 0, 0, 1, 1],
                [0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0],
            ], dtype=bool),
            8,
            "edge",
            2.0,  # below the edge distance, so the right blob is removed
            None,
            np.array([
                [0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0],
            ], dtype=bool),
            DoesNotRaise(),
        ),
        # non-boolean mask is rejected
        (
            np.zeros((4, 4), dtype=np.uint8),
            8,
            "edge",
            2.0,
            None,
            None,
            pytest.raises(TypeError),
        ),
        # unknown mode is rejected (two components, so the mode is consulted)
        (
            np.array([
                [0, 0, 0, 0, 0, 0],
                [0, 1, 1, 1, 0, 1],
                [0, 1, 1, 1, 0, 1],
                [0, 1, 1, 1, 0, 1],
                [0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0],
            ], dtype=bool),
            8,
            "corner",
            2.0,
            None,
            None,
            pytest.raises(ValueError),
        ),
    ],
)
def test_filter_segments_by_distance_sweep(
    mask: npt.NDArray,
    connectivity: int,
    mode: str,
    absolute_distance: float | None,
    relative_distance: float | None,
    expected_result: npt.NDArray | None,
    exception: Exception,
) -> None:
    # `exception` is a context manager: either DoesNotRaise() or pytest.raises(...).
    # The assertion sits inside the block so it is skipped when the call raises.
    with exception:
        result = filter_segments_by_distance(
            mask=mask,
            connectivity=connectivity,
            mode=mode,  # type: ignore[arg-type]
            absolute_distance=absolute_distance,
            relative_distance=relative_distance,
        )
        assert np.array_equal(result, expected_result)

0 commit comments

Comments
 (0)