diff --git a/supervision/detection/core.py b/supervision/detection/core.py index ffe5ed3fc..bda2e7de3 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -296,18 +296,24 @@ def from_ultralytics(cls, ultralytics_results) -> Detections: class_id=np.arange(len(ultralytics_results)), ) - class_id = ultralytics_results.boxes.cls.cpu().numpy().astype(int) - class_names = np.array([ultralytics_results.names[i] for i in class_id]) - return cls( - xyxy=ultralytics_results.boxes.xyxy.cpu().numpy(), - confidence=ultralytics_results.boxes.conf.cpu().numpy(), - class_id=class_id, - mask=extract_ultralytics_masks(ultralytics_results), - tracker_id=ultralytics_results.boxes.id.int().cpu().numpy() - if ultralytics_results.boxes.id is not None - else None, - data={CLASS_NAME_DATA_FIELD: class_names}, - ) + if ( + hasattr(ultralytics_results, "boxes") + and ultralytics_results.boxes is not None + ): + class_id = ultralytics_results.boxes.cls.cpu().numpy().astype(int) + class_names = np.array([ultralytics_results.names[i] for i in class_id]) + return cls( + xyxy=ultralytics_results.boxes.xyxy.cpu().numpy(), + confidence=ultralytics_results.boxes.conf.cpu().numpy(), + class_id=class_id, + mask=extract_ultralytics_masks(ultralytics_results), + tracker_id=ultralytics_results.boxes.id.int().cpu().numpy() + if ultralytics_results.boxes.id is not None + else None, + data={CLASS_NAME_DATA_FIELD: class_names}, + ) + + return cls.empty() @classmethod def from_yolo_nas(cls, yolo_nas_results) -> Detections: diff --git a/supervision/detection/utils/boxes.py b/supervision/detection/utils/boxes.py index 3b01fcb68..01fcdc769 100644 --- a/supervision/detection/utils/boxes.py +++ b/supervision/detection/utils/boxes.py @@ -95,24 +95,27 @@ def pad_boxes(xyxy: np.ndarray, px: int, py: int | None = None) -> np.ndarray: def denormalize_boxes( - normalized_xyxy: np.ndarray, + xyxy: np.ndarray, resolution_wh: tuple[int, int], normalization_factor: float = 1.0, ) -> np.ndarray: """ - Converts normalized bounding box coordinates to absolute pixel values. + Convert normalized bounding box coordinates to absolute pixel coordinates. + + Multiplies each bounding box coordinate by image size and divides by + `normalization_factor`, mapping values from normalized `[0, normalization_factor]` + to absolute pixel values for a given resolution. Args: - normalized_xyxy (np.ndarray): A numpy array of shape `(N, 4)` where each row - contains normalized coordinates in the format `(x_min, y_min, x_max, y_max)`, - with values between 0 and `normalization_factor`. - resolution_wh (Tuple[int, int]): A tuple `(width, height)` representing the - target image resolution. - normalization_factor (float, optional): The normalization range of the input - coordinates. Defaults to 1.0. + xyxy (`numpy.ndarray`): Normalized bounding boxes of shape `(N, 4)`, + where each row is `(x_min, y_min, x_max, y_max)`, values in + `[0, normalization_factor]`. + resolution_wh (`tuple[int, int]`): Target image resolution as `(width, height)`. + normalization_factor (`float`): Maximum value of input coordinate range. + Defaults to `1.0`. Returns: - np.ndarray: An array of shape `(N, 4)` with absolute coordinates in + (`numpy.ndarray`): Array of shape `(N, 4)` with absolute coordinates in `(x_min, y_min, x_max, y_max)` format. Examples: @@ -120,35 +123,39 @@ def denormalize_boxes( import numpy as np import supervision as sv - # Default normalization (0-1) - normalized_xyxy = np.array([ + xyxy = np.array([ [0.1, 0.2, 0.5, 0.6], - [0.3, 0.4, 0.7, 0.8] + [0.3, 0.4, 0.7, 0.8], + [0.2, 0.1, 0.6, 0.5] ]) - resolution_wh = (100, 200) - sv.denormalize_boxes(normalized_xyxy, resolution_wh) + + sv.denormalize_boxes(xyxy, (1280, 720)) # array([ - # [ 10., 40., 50., 120.], - # [ 30., 80., 70., 160.] + # [128., 144., 640., 432.], + # [384., 288., 896., 576.], + # [256., 72., 768., 360.] # ]) + ``` - # Custom normalization (0-100) - normalized_xyxy = np.array([ - [10., 20., 50., 60.], - [30., 40., 70., 80.] + ``` + import numpy as np + import supervision as sv + + xyxy = np.array([ + [256., 128., 768., 640.] ]) - sv.denormalize_boxes(normalized_xyxy, resolution_wh, normalization_factor=100.0) + + sv.denormalize_boxes(xyxy, (1280, 720), normalization_factor=1024.0) # array([ - # [ 10., 40., 50., 120.], - # [ 30., 80., 70., 160.] + # [320., 90., 960., 450.] # ]) ``` - """ # noqa E501 // docs + """ width, height = resolution_wh - result = normalized_xyxy.copy() + result = xyxy.copy() - result[[0, 2]] = (result[[0, 2]] * width) / normalization_factor - result[[1, 3]] = (result[[1, 3]] * height) / normalization_factor + result[:, [0, 2]] = (result[:, [0, 2]] * width) / normalization_factor + result[:, [1, 3]] = (result[:, [1, 3]] * height) / normalization_factor return result diff --git a/supervision/detection/vlm.py b/supervision/detection/vlm.py index 71207554e..2f9b60ddb 100644 --- a/supervision/detection/vlm.py +++ b/supervision/detection/vlm.py @@ -538,7 +538,7 @@ def from_google_gemini_2_0( return np.empty((0, 4)), None, np.empty((0,), dtype=str) labels = [] - boxes_list = [] + xyxy = [] for item in data: if "box_2d" not in item or "label" not in item: @@ -546,18 +546,16 @@ def from_google_gemini_2_0( labels.append(item["label"]) box = item["box_2d"] # Gemini bbox order is [y_min, x_min, y_max, x_max] - boxes_list.append( - denormalize_boxes( - np.array([box[1], box[0], box[3], box[2]]).astype(np.float64), - resolution_wh=(w, h), - normalization_factor=1000, - ) - ) + xyxy.append([box[1], box[0], box[3], box[2]]) - if not boxes_list: + if len(xyxy) == 0: return np.empty((0, 4)), None, np.empty((0,), dtype=str) - xyxy = np.array(boxes_list) + xyxy = denormalize_boxes( + np.array(xyxy, dtype=np.float64), + resolution_wh=(w, h), + normalization_factor=1000, + ) class_name = np.array(labels) class_id = None @@ -649,10 +647,10 @@ def from_google_gemini_2_5( box = item["box_2d"] # Gemini bbox order is [y_min, x_min, y_max, x_max] absolute_bbox = denormalize_boxes( - np.array([box[1], box[0], box[3], box[2]]).astype(np.float64), + np.array([[box[1], box[0], box[3], box[2]]]).astype(np.float64), resolution_wh=(w, h), normalization_factor=1000, - ) + )[0] boxes_list.append(absolute_bbox) if "mask" in item: @@ -735,7 +733,7 @@ def from_google_gemini_2_5( def from_moondream( result: dict, resolution_wh: tuple[int, int], -) -> tuple[np.ndarray]: +) -> np.ndarray: """ Parse and scale bounding boxes from moondream JSON output. @@ -773,7 +771,7 @@ def from_moondream( if "objects" not in result or not isinstance(result["objects"], list): return np.empty((0, 4), dtype=float) - denormalize_xyxy = [] + xyxy = [] for item in result["objects"]: if not all(k in item for k in ["x_min", "y_min", "x_max", "y_max"]): @@ -784,14 +782,12 @@ def from_moondream( x_max = item["x_max"] y_max = item["y_max"] - denormalize_xyxy.append( - denormalize_boxes( - np.array([x_min, y_min, x_max, y_max]).astype(np.float64), - resolution_wh=(w, h), - ) - ) + xyxy.append([x_min, y_min, x_max, y_max]) - if not denormalize_xyxy: + if len(xyxy) == 0: return np.empty((0, 4)) - return np.array(denormalize_xyxy, dtype=float) + return denormalize_boxes( + np.array(xyxy).astype(np.float64), + resolution_wh=(w, h), + ) diff --git a/test/detection/utils/test_boxes.py b/test/detection/utils/test_boxes.py index 919989287..66d0d999c 100644 --- a/test/detection/utils/test_boxes.py +++ b/test/detection/utils/test_boxes.py @@ -5,7 +5,12 @@ import numpy as np import pytest -from supervision.detection.utils.boxes import clip_boxes, move_boxes, scale_boxes +from supervision.detection.utils.boxes import ( + clip_boxes, + denormalize_boxes, + move_boxes, + scale_boxes, +) @pytest.mark.parametrize( @@ -142,3 +147,88 @@ def test_scale_boxes( with exception: result = scale_boxes(xyxy=xyxy, factor=factor) assert np.array_equal(result, expected_result) + + +@pytest.mark.parametrize( + "xyxy, resolution_wh, normalization_factor, expected_result, exception", + [ + ( + np.empty(shape=(0, 4)), + (1280, 720), + 1.0, + np.empty(shape=(0, 4)), + DoesNotRaise(), + ), # empty array + ( + np.array([[0.1, 0.2, 0.5, 0.6]]), + (1280, 720), + 1.0, + np.array([[128.0, 144.0, 640.0, 432.0]]), + DoesNotRaise(), + ), # single box with default normalization + ( + np.array([[0.1, 0.2, 0.5, 0.6], [0.3, 0.4, 0.7, 0.8]]), + (1280, 720), + 1.0, + np.array([[128.0, 144.0, 640.0, 432.0], [384.0, 288.0, 896.0, 576.0]]), + DoesNotRaise(), + ), # two boxes with default normalization + ( + np.array( + [[0.1, 0.2, 0.5, 0.6], [0.3, 0.4, 0.7, 0.8], [0.2, 0.1, 0.6, 0.5]] + ), + (1280, 720), + 1.0, + np.array( + [ + [128.0, 144.0, 640.0, 432.0], + [384.0, 288.0, 896.0, 576.0], + [256.0, 72.0, 768.0, 360.0], + ] + ), + DoesNotRaise(), + ), # three boxes - regression test for issue #1959 + ( + np.array([[10.0, 20.0, 50.0, 60.0]]), + (100, 200), + 100.0, + np.array([[10.0, 40.0, 50.0, 120.0]]), + DoesNotRaise(), + ), # single box with custom normalization factor + ( + np.array([[10.0, 20.0, 50.0, 60.0], [30.0, 40.0, 70.0, 80.0]]), + (100, 200), + 100.0, + np.array([[10.0, 40.0, 50.0, 120.0], [30.0, 80.0, 70.0, 160.0]]), + DoesNotRaise(), + ), # two boxes with custom normalization factor + ( + np.array([[0.0, 0.0, 1.0, 1.0]]), + (1920, 1080), + 1.0, + np.array([[0.0, 0.0, 1920.0, 1080.0]]), + DoesNotRaise(), + ), # full frame box + ( + np.array([[0.5, 0.5, 0.5, 0.5]]), + (640, 480), + 1.0, + np.array([[320.0, 240.0, 320.0, 240.0]]), + DoesNotRaise(), + ), # zero-area box (point) + ], +) +def test_denormalize_boxes( + xyxy: np.ndarray, + resolution_wh: tuple[int, int], + normalization_factor: float, + expected_result: np.ndarray, + exception: Exception, +) -> None: + with exception: + result = denormalize_boxes( + xyxy=xyxy, + resolution_wh=resolution_wh, + normalization_factor=normalization_factor, + ) + assert np.allclose(result, expected_result)