@@ -372,8 +372,8 @@ def test_correct_detections_with_keypoints():
372372 corrected_detections = correct_detections (
373373 detections = detections ,
374374 perspective_transformer = transformer ,
375- transformed_rect_width = [ 100 ] ,
376- transformed_rect_height = [ 100 ] ,
375+ transformed_rect_width = 100 ,
376+ transformed_rect_height = 100 ,
377377 )
378378
379379 # then
@@ -397,6 +397,34 @@ def test_warp_image():
397397 parent_metadata = ImageParentMetadata (parent_id = "test" ), numpy_image = dummy_image
398398 )
399399
400+ # when
401+ result = perspective_correction_block .run (
402+ images = [workflow_image_data ],
403+ predictions = [dummy_predictions ],
404+ perspective_polygons = [[[1 , 1 ], [99 , 1 ], [99 , 99 ], [1 , 99 ]]],
405+ transformed_rect_width = 200 ,
406+ transformed_rect_height = 200 ,
407+ extend_perspective_polygon_by_detections_anchor = None ,
408+ warp_image = True ,
409+ )
410+
411+ # then
412+ assert "warped_image" in result [0 ], "warped_image key must be present in the result"
413+ assert isinstance (
414+ result [0 ]["warped_image" ], WorkflowImageData
415+ ), f"warped_image must be of type WorkflowImageData"
416+
417+
418+ def test_warp_image_batch_dims ():
419+ # given
420+ dummy_image = np .random .randint (0 , 255 , (100 , 100 , 3 ), dtype = np .uint8 )
421+ dummy_predictions = sv .Detections (xyxy = np .array ([[10 , 10 , 20 , 20 ]]))
422+ perspective_correction_block = PerspectiveCorrectionBlockV1 ()
423+
424+ workflow_image_data = WorkflowImageData (
425+ parent_metadata = ImageParentMetadata (parent_id = "test" ), numpy_image = dummy_image
426+ )
427+
400428 # when
401429 result = perspective_correction_block .run (
402430 images = [workflow_image_data ],
0 commit comments