diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 7b6cdf3f24ed..ac4b1676262b 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -15,7 +15,7 @@ from collections import defaultdict from collections.abc import Collection, Iterable from math import ceil -from typing import Optional, Union +from typing import Any, Optional, Union, overload import numpy as np @@ -26,7 +26,7 @@ get_image_size, infer_channel_dimension_format, ) -from .utils import ExplicitEnum, TensorType, is_torch_tensor +from .utils import ExplicitEnum, is_torch_tensor from .utils.import_utils import ( is_torch_available, is_vision_available, @@ -547,7 +547,15 @@ def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray: # 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py -def center_to_corners_format(bboxes_center: TensorType) -> TensorType: +@overload +def center_to_corners_format(bboxes_center: "torch.Tensor") -> "torch.Tensor": ... + + +@overload +def center_to_corners_format(bboxes_center: np.ndarray) -> np.ndarray: ... + + +def center_to_corners_format(bboxes_center: "torch.Tensor | np.ndarray") -> Any: """ Converts bounding boxes from center format to corners format. @@ -590,7 +598,15 @@ def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray: return bboxes_center -def corners_to_center_format(bboxes_corners: TensorType) -> TensorType: +@overload +def corners_to_center_format(bboxes_corners: "torch.Tensor") -> "torch.Tensor": ... + + +@overload +def corners_to_center_format(bboxes_corners: np.ndarray) -> np.ndarray: ... + + +def corners_to_center_format(bboxes_corners: "torch.Tensor | np.ndarray") -> Any: """ Converts bounding boxes from corners format to center format.