Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from pydantic import ConfigDict, Field
from supervision import OverlapFilter, move_boxes, move_masks

from inference.core.workflows.core_steps.common.utils import (
attach_parents_coordinates_to_sv_detections,
)
from inference.core.workflows.execution_engine.constants import (
IMAGE_DIMENSIONS_KEY,
PARENT_COORDINATES_KEY,
PARENT_DIMENSIONS_KEY,
PARENT_ID_KEY,
ROOT_PARENT_COORDINATES_KEY,
SCALING_RELATIVE_TO_PARENT_KEY,
)
from inference.core.workflows.execution_engine.entities.base import (
Expand Down Expand Up @@ -133,15 +135,16 @@ def run(
overlap_filtering_strategy: Optional[Literal["none", "nms", "nmm"]],
iou_threshold: Optional[float],
) -> BlockResult:
# Use reference image to ensure all masks have the same dimensions
reference_height, reference_width = reference_image.numpy_image.shape[:2]
resolution_wh = (reference_width, reference_height)

re_aligned_predictions = []
for detections in predictions:
detections_copy = deepcopy(detections)
resolution_wh = retrieve_crop_wh(detections=detections_copy)
offset = retrieve_crop_offset(detections=detections_copy)
detections_copy = manage_crops_metadata(
detections=detections_copy,
offset=offset,
parent_id=reference_image.parent_metadata.parent_id,
detections=detections_copy, image=reference_image
)
re_aligned_detections = move_detections(
detections=detections_copy,
Expand All @@ -160,21 +163,6 @@ def run(
return {"predictions": merged.with_nmm(threshold=iou_threshold)}


def retrieve_crop_wh(detections: sv.Detections) -> Optional[Tuple[int, int]]:
if len(detections) == 0:
return None
if PARENT_DIMENSIONS_KEY not in detections.data:
raise RuntimeError(
f"Dimensions for crops is expected to be saved in data key {PARENT_DIMENSIONS_KEY} "
f"of sv.Detections, but could not be found. Probably block producing sv.Detections "
f"lack this part of implementation or has a bug."
)
return (
detections.data[PARENT_DIMENSIONS_KEY][0][1].item(),
detections.data[PARENT_DIMENSIONS_KEY][0][0].item(),
)


def retrieve_crop_offset(detections: sv.Detections) -> Optional[np.ndarray]:
if len(detections) == 0:
return None
Expand All @@ -189,15 +177,11 @@ def retrieve_crop_offset(detections: sv.Detections) -> Optional[np.ndarray]:

def manage_crops_metadata(
detections: sv.Detections,
offset: Optional[np.ndarray],
parent_id: str,
image: WorkflowImageData,
) -> sv.Detections:
if len(detections) == 0:
return detections
if offset is None:
raise ValueError(
"To process non-empty detections offset is needed, but not given"
)

if SCALING_RELATIVE_TO_PARENT_KEY in detections.data:
scale = detections[SCALING_RELATIVE_TO_PARENT_KEY][0]
if abs(scale - 1.0) > 1e-4:
Expand All @@ -208,12 +192,14 @@ def manage_crops_metadata(
f"scaling cannot be used in the meantime. This error probably indicate "
f"wrong step output plugged as input of this step."
)
if PARENT_COORDINATES_KEY in detections.data:
detections.data[PARENT_COORDINATES_KEY] -= offset
if ROOT_PARENT_COORDINATES_KEY in detections.data:
detections.data[ROOT_PARENT_COORDINATES_KEY] -= offset
detections.data[PARENT_ID_KEY] = np.array([parent_id] * len(detections))
return detections

height, width = image.numpy_image.shape[:2]
detections[IMAGE_DIMENSIONS_KEY] = np.array([[height, width]] * len(detections))

return attach_parents_coordinates_to_sv_detections(
detections=detections,
image=image,
)


def move_detections(
Expand Down
Loading