Commit 9b69b3c

Merge pull request #1654 from roboflow/fix/stitch-block-image-dimensions
Fix image dimensions in detections stitch block output
2 parents 3dee7a2 + b6db16f commit 9b69b3c

2 files changed (+441 / -33 lines)

inference/core/workflows/core_steps/fusion/detections_stitch/v1.py

Lines changed: 19 additions & 33 deletions
@@ -6,11 +6,13 @@
 from pydantic import ConfigDict, Field
 from supervision import OverlapFilter, move_boxes, move_masks
 
+from inference.core.workflows.core_steps.common.utils import (
+    attach_parents_coordinates_to_sv_detections,
+)
 from inference.core.workflows.execution_engine.constants import (
+    IMAGE_DIMENSIONS_KEY,
     PARENT_COORDINATES_KEY,
     PARENT_DIMENSIONS_KEY,
-    PARENT_ID_KEY,
-    ROOT_PARENT_COORDINATES_KEY,
     SCALING_RELATIVE_TO_PARENT_KEY,
 )
 from inference.core.workflows.execution_engine.entities.base import (
@@ -133,15 +135,16 @@ def run(
         overlap_filtering_strategy: Optional[Literal["none", "nms", "nmm"]],
         iou_threshold: Optional[float],
     ) -> BlockResult:
+        # Use reference image to ensure all masks have the same dimensions
+        reference_height, reference_width = reference_image.numpy_image.shape[:2]
+        resolution_wh = (reference_width, reference_height)
+
         re_aligned_predictions = []
         for detections in predictions:
             detections_copy = deepcopy(detections)
-            resolution_wh = retrieve_crop_wh(detections=detections_copy)
             offset = retrieve_crop_offset(detections=detections_copy)
             detections_copy = manage_crops_metadata(
-                detections=detections_copy,
-                offset=offset,
-                parent_id=reference_image.parent_metadata.parent_id,
+                detections=detections_copy, image=reference_image
             )
             re_aligned_detections = move_detections(
                 detections=detections_copy,
@@ -160,21 +163,6 @@ def run(
         return {"predictions": merged.with_nmm(threshold=iou_threshold)}
 
 
-def retrieve_crop_wh(detections: sv.Detections) -> Optional[Tuple[int, int]]:
-    if len(detections) == 0:
-        return None
-    if PARENT_DIMENSIONS_KEY not in detections.data:
-        raise RuntimeError(
-            f"Dimensions for crops is expected to be saved in data key {PARENT_DIMENSIONS_KEY} "
-            f"of sv.Detections, but could not be found. Probably block producing sv.Detections "
-            f"lack this part of implementation or has a bug."
-        )
-    return (
-        detections.data[PARENT_DIMENSIONS_KEY][0][1].item(),
-        detections.data[PARENT_DIMENSIONS_KEY][0][0].item(),
-    )
-
-
 def retrieve_crop_offset(detections: sv.Detections) -> Optional[np.ndarray]:
     if len(detections) == 0:
         return None
@@ -189,15 +177,11 @@ def retrieve_crop_offset(detections: sv.Detections) -> Optional[np.ndarray]:
 
 def manage_crops_metadata(
     detections: sv.Detections,
-    offset: Optional[np.ndarray],
-    parent_id: str,
+    image: WorkflowImageData,
 ) -> sv.Detections:
     if len(detections) == 0:
         return detections
-    if offset is None:
-        raise ValueError(
-            "To process non-empty detections offset is needed, but not given"
-        )
+
     if SCALING_RELATIVE_TO_PARENT_KEY in detections.data:
         scale = detections[SCALING_RELATIVE_TO_PARENT_KEY][0]
         if abs(scale - 1.0) > 1e-4:
@@ -208,12 +192,14 @@ def manage_crops_metadata(
                 f"scaling cannot be used in the meantime. This error probably indicate "
                 f"wrong step output plugged as input of this step."
             )
-    if PARENT_COORDINATES_KEY in detections.data:
-        detections.data[PARENT_COORDINATES_KEY] -= offset
-    if ROOT_PARENT_COORDINATES_KEY in detections.data:
-        detections.data[ROOT_PARENT_COORDINATES_KEY] -= offset
-    detections.data[PARENT_ID_KEY] = np.array([parent_id] * len(detections))
-    return detections
+
+    height, width = image.numpy_image.shape[:2]
+    detections[IMAGE_DIMENSIONS_KEY] = np.array([[height, width]] * len(detections))
+
+    return attach_parents_coordinates_to_sv_detections(
+        detections=detections,
+        image=image,
+    )
 
 
 def move_detections(
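
The substance of the fix, outside diff notation: each stitched detection now carries the reference image's dimensions instead of the originating crop's. Below is a minimal sketch of that behavior using only numpy and supervision; the literal "image_dimensions" key, the 480x640 image, and the box values are illustrative assumptions, not the block's actual data (in the block the key is the IMAGE_DIMENSIONS_KEY constant and the image is a WorkflowImageData).

import numpy as np
import supervision as sv

# Hypothetical reference image (H=480, W=640) standing in for
# reference_image.numpy_image from the block above.
reference_image = np.zeros((480, 640, 3), dtype=np.uint8)

# Two crop-level detections, already shifted back into reference coordinates.
detections = sv.Detections(
    xyxy=np.array([[10.0, 20.0, 110.0, 220.0], [300.0, 40.0, 420.0, 200.0]]),
    confidence=np.array([0.9, 0.8]),
    class_id=np.array([0, 1]),
)

# Mirror the fix: stamp the reference image's dimensions (not per-crop
# dimensions) onto every detection, so stitched outputs share one resolution.
height, width = reference_image.shape[:2]
detections.data["image_dimensions"] = np.array([[height, width]] * len(detections))

print(detections.data["image_dimensions"])
# [[480 640]
#  [480 640]]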

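The other half of the change is in run(): resolution_wh is computed once from the reference image and applied to every crop, so re-projected masks all come out at the same shape. A rough illustration with supervision's move_masks helper (the 100x100 crop mask, the (200, 150) offset, and the 640x480 resolution are made-up values for this sketch):

import numpy as np
from supervision import move_masks

# Hypothetical reference resolution as (W, H), mirroring the resolution_wh
# now derived once from reference_image in run().
resolution_wh = (640, 480)

# A single 100x100 crop mask and the crop's (x, y) offset inside the
# reference image.
crop_masks = np.zeros((1, 100, 100), dtype=bool)
crop_masks[0, 10:60, 10:60] = True
offset = np.array([200, 150])

# Because every crop is re-projected onto the same resolution_wh, all
# stitched masks share the reference image's shape rather than their
# crop's own size.
moved = move_masks(masks=crop_masks, offset=offset, resolution_wh=resolution_wh)
print(moved.shape)  # expected: (1, 480, 640)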