Skip to content

Commit 212bbde

Browse files
authored
fix: 🐞 add optional dependencies and improve package management in models and late import for torch to fix ci (#1251)
Signed-off-by: Onuralp SEZER <[email protected]>
1 parent 838f0da commit 212bbde

File tree

5 files changed

+28
-25
lines changed

5 files changed

+28
-25
lines changed

pyproject.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ maintainers = [
4747
[project.scripts]
4848
sahi = "sahi.cli:app"
4949

50-
5150
[dependency-groups]
5251
dev = [
5352
"pytest>=7.2.2,<9.0.0",
@@ -78,6 +77,14 @@ build = [
7877
]
7978

8079
[project.optional-dependencies]
80+
yolov5 = ["yolov5>=6.0.0,<8.0.0"]
81+
yolo = ["ultralytics>=8.0.0"]
82+
torch = ["torch", "torchvision"]
83+
transformers = ["transformers>=4.49.0;python_version>='3.9'",]
84+
roboflow = ["inference>=0.51.5;python_version>='3.12'", "rfdetr>=1.1.0;python_version>='3.12'"]
85+
onnx = ["onnx;python_version>='3.10'", "onnxruntime;python_version>='3.10'"]
86+
all = ["yolov5>=6.0.0,<8.0.0", "ultralytics>=8.0.0", "torch", "torchvision", "transformers>=4.49.0;python_version>='3.9'", "inference>=0.51.5;python_version>='3.12'", "rfdetr>=1.1.0;python_version>='3.12'", "onnx;python_version>='3.10'", "onnxruntime;python_version>='3.10'"]
87+
8188
mmdet = [
8289
# MMDetection dependencies - only for Python 3.11 and not on macOS ARM64
8390
"mmengine;python_version=='3.11' and (platform_system!='Darwin' or platform_machine!='arm64')",

sahi/models/ultralytics.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import cv2
66
import numpy as np
7-
import torch
87

98
from sahi.logger import logger
109
from sahi.models.base import DetectionModel
@@ -72,6 +71,9 @@ def perform_inference(self, image: np.ndarray):
7271
"""
7372

7473
# Confirm model is loaded
74+
75+
import torch
76+
7577
if self.model is None:
7678
raise ValueError("Model is not loaded, load it by calling .load_model()")
7779

sahi/models/yolov5.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,17 @@
88
from sahi.models.base import DetectionModel
99
from sahi.prediction import ObjectPrediction
1010
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
11-
from sahi.utils.import_utils import check_package_minimum_version, check_requirements
11+
from sahi.utils.import_utils import check_package_minimum_version
1212

1313

1414
class Yolov5DetectionModel(DetectionModel):
15-
def check_dependencies(self) -> None:
16-
check_requirements(["torch", "yolov5"])
15+
def __init__(self, *args, **kwargs):
16+
existing_packages = getattr(self, "required_packages", None) or []
17+
self.required_packages = [*list(existing_packages), "yolov5", "torch"]
18+
super().__init__(*args, **kwargs)
1719

1820
def load_model(self):
1921
"""Detection model is initialized and set to self.model."""
20-
2122
import yolov5
2223

2324
try:
@@ -71,13 +72,8 @@ def num_categories(self):
7172
@property
7273
def has_mask(self):
7374
"""Returns if model output contains segmentation mask."""
74-
import yolov5
75-
from packaging import version
7675

77-
if version.parse(yolov5.__version__) < version.parse("6.2.0"):
78-
return False
79-
else:
80-
return False # fix when yolov5 supports segmentation models
76+
return False # fix when yolov5 supports segmentation models
8177

8278
@property
8379
def category_names(self):

sahi/predict.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,14 @@
33
import os
44
import time
55
from collections.abc import Generator
6-
7-
from PIL import Image
8-
9-
from sahi.logger import logger
10-
from sahi.utils.import_utils import is_available
11-
12-
# TODO: This does nothing for this module. The issue named here does not exist
13-
# https://github.com/obss/sahi/issues/526
14-
if is_available("torch"):
15-
import torch # noqa: F401
16-
176
from functools import cmp_to_key
187

198
import numpy as np
9+
from PIL import Image
2010
from tqdm import tqdm
2111

2212
from sahi.auto_model import AutoDetectionModel
13+
from sahi.logger import logger
2314
from sahi.models.base import DetectionModel
2415
from sahi.postprocess.combine import (
2516
GreedyNMMPostprocess,

sahi/utils/torch_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@
22

33
import re
44
from os import environ
5-
from typing import Any
5+
from typing import TYPE_CHECKING, Any
66

77
import numpy as np
8-
import torch
98
from PIL.Image import Image
109

10+
if TYPE_CHECKING:
11+
import torch
12+
1113

1214
def empty_cuda_cache() -> None:
15+
import torch
16+
1317
torch.cuda.empty_cache()
1418

1519

@@ -22,6 +26,8 @@ def to_float_tensor(img: np.ndarray | Image) -> torch.Tensor:
2226
Returns:
2327
torch.tensor
2428
"""
29+
import torch
30+
2531
nparray: np.ndarray
2632
if isinstance(img, np.ndarray):
2733
nparray = img
@@ -54,6 +60,7 @@ def select_device(device: str | None = None) -> torch.device:
5460
5561
Inspired by https://github.com/ultralytics/yolov5/blob/6371de8879e7ad7ec5283e8b95cc6dd85d6a5e72/utils/torch_utils.py#L107
5662
"""
63+
import torch
5764

5865
if device == "cuda" or device is None:
5966
device = "cuda:0"

0 commit comments

Comments
 (0)