Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion grounded_sam2_local_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
TEXT_PROMPT = "car. tire."
IMG_PATH = "notebooks/images/truck.jpg"
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
SAM2_MODEL_CONFIG = "./sam2/configs/sam2.1/sam2.1_hiera_l.yaml"
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
BOX_THRESHOLD = 0.35
Expand Down
47 changes: 41 additions & 6 deletions sam2/build_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
import os

import torch
from hydra import compose
from hydra import compose, initialize_config_dir
from hydra.utils import instantiate
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

from pathlib import Path

import sam2

# Check if the user is running Python from the parent directory of the sam2 repo
Expand Down Expand Up @@ -73,27 +76,59 @@ def build_sam2(
ckpt_path=None,
device="cuda",
mode="eval",
hydra_overrides_extra=[],
hydra_overrides_extra=None,
apply_postprocessing=True,
config_root=".",
**kwargs,
):
"""
Builds SAM 2 model with support for loading config files from any local directory.
Compatible with Hydra >= 1.3.2.

Args:
config_file (str): Name of the Hydra config YAML file (e.g., "config.yaml").
ckpt_path (str): Path to the model checkpoint to load.
device (str): Device to place the model on ("cuda" or "cpu").
mode (str): "eval" or "train" mode for the model.
hydra_overrides_extra (list): Additional Hydra override strings.
apply_postprocessing (bool): Whether to apply extra SAM-specific config tweaks.
config_root (str): Path to the directory containing Hydra config files.
**kwargs: Additional unused arguments (for compatibility).

Returns:
torch.nn.Module: The instantiated and loaded model.
"""

hydra_overrides_extra = hydra_overrides_extra.copy() if hydra_overrides_extra else []

if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
]
# Read config and init model
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)

config_root = str(Path(config_root).resolve())

try:
if GlobalHydra.instance().is_initialized():
GlobalHydra.instance().clear()

with initialize_config_dir(config_dir=config_root, job_name="build_sam2"):
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
logging.info(f"Successfully loaded config: {config_file}")
except Exception as e:
logging.error(f"Failed to load config '{config_file}' from '{config_root}': {e}")
raise

OmegaConf.resolve(cfg)
model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path)
model = model.to(device)

if mode == "eval":
model.eval()

return model


Expand Down