Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions inference/core/entities/requests/sam3.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def _validate_box_labels(cls, labels, values):
raise ValueError("box_labels must match boxes length when provided")
return labels

output_prob_thresh: Optional[float] = Field(
default=None,
description="Score threshold for this prompt's outputs. Overrides request-level threshold if set.",
)


class Sam3InferenceRequest(BaseRequest):
"""SAM3 inference request.
Expand Down Expand Up @@ -93,6 +98,17 @@ class Sam3SegmentationRequest(Sam3InferenceRequest):
description="List of prompts (text and/or visual)", min_items=1
)

nms_iou_threshold: Optional[float] = Field(
default=None,
description="IoU threshold for cross-prompt NMS. If None, NMS is disabled. Must be in [0.0, 1.0] when set.",
)

@validator("nms_iou_threshold")
def _validate_nms_iou_threshold(cls, v):
if v is not None and (v < 0.0 or v > 1.0):
raise ValueError("nms_iou_threshold must be between 0.0 and 1.0")
return v

@validator("prompts")
def _validate_prompts(cls, prompts: List[Sam3Prompt]):
if not prompts or len(prompts) == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import requests
from pydantic import ConfigDict, Field
from pydantic import ConfigDict, Field, validator

from inference.core.entities.responses.inference import (
InferenceResponseImage,
Expand All @@ -30,6 +30,7 @@
WorkflowImageData,
)
from inference.core.workflows.execution_engine.entities.types import (
BOOLEAN_KIND,
FLOAT_KIND,
IMAGE_KIND,
INSTANCE_SEGMENTATION_PREDICTION_KIND,
Expand Down Expand Up @@ -88,6 +89,48 @@ class BlockManifest(WorkflowBlockManifest):
default=0.5, description="Threshold for predicted mask scores", examples=[0.3]
)

confidence_thresholds: Optional[
Union[List[float], str, Selector(kind=[LIST_OF_VALUES_KIND, STRING_KIND])]
] = Field(
default=None,
title="Per-Class Confidence Thresholds",
description="List of thresholds per class (must match class_names length) or comma-separated string",
examples=[[0.3, 0.5, 0.7], "0.3,0.5,0.7"],
)

apply_nms: Union[Selector(kind=[BOOLEAN_KIND]), bool] = Field(
default=True,
title="Apply NMS",
description="Whether to apply Non-Maximum Suppression across prompts",
)

nms_iou_threshold: Union[Selector(kind=[FLOAT_KIND]), float] = Field(
default=0.9,
title="NMS IoU Threshold",
description="IoU threshold for cross-prompt NMS. Must be in [0.0, 1.0]",
examples=[0.5, 0.9],
)

@validator("nms_iou_threshold")
def _validate_nms_iou_threshold(cls, v):
if isinstance(v, (int, float)) and (v < 0.0 or v > 1.0):
raise ValueError("nms_iou_threshold must be between 0.0 and 1.0")
return v

@validator("confidence_thresholds", pre=True)
def _parse_confidence_thresholds(cls, v):
if v is None:
return None
if isinstance(v, str):
# Parse comma-separated string to list of floats
try:
return [float(x.strip()) for x in v.split(",") if x.strip()]
except ValueError:
raise ValueError(
"confidence_thresholds string must be comma-separated floats"
)
return v

@classmethod
def get_parameters_accepting_batches(cls) -> List[str]:
    """Return the parameter names this manifest declares as batch-accepting."""
    batch_parameters: List[str] = ["images", "boxes"]
    return batch_parameters
Expand Down Expand Up @@ -131,6 +174,9 @@ def run(
images: Batch[WorkflowImageData],
class_names: Optional[Union[List[str], str]],
threshold: float,
confidence_thresholds: Optional[Union[List[float], str]] = None,
apply_nms: bool = True,
nms_iou_threshold: float = 0.9,
) -> BlockResult:

if isinstance(class_names, str):
Expand All @@ -140,17 +186,43 @@ def run(
else:
raise ValueError(f"Invalid class names type: {type(class_names)}")

# Parse confidence_thresholds if string
parsed_thresholds = None
if confidence_thresholds is not None:
if isinstance(confidence_thresholds, str):
parsed_thresholds = [
float(x.strip())
for x in confidence_thresholds.split(",")
if x.strip()
]
else:
parsed_thresholds = confidence_thresholds

# Validate length matches class_names
if parsed_thresholds and class_names:
if len(parsed_thresholds) != len(class_names):
raise ValueError(
f"confidence_thresholds length ({len(parsed_thresholds)}) "
f"must match class_names length ({len(class_names)})"
)

return self.run_via_request(
images=images,
class_names=class_names,
threshold=threshold,
confidence_thresholds=parsed_thresholds,
apply_nms=apply_nms,
nms_iou_threshold=nms_iou_threshold,
)

def run_via_request(
self,
images: Batch[WorkflowImageData],
class_names: Optional[List[str]],
threshold: float,
confidence_thresholds: Optional[List[float]] = None,
apply_nms: bool = True,
nms_iou_threshold: float = 0.9,
) -> BlockResult:
predictions = []
if class_names is None:
Expand All @@ -168,8 +240,12 @@ def run_via_request(

# Build unified prompt list payloads for HTTP
http_prompts: List[dict] = []
for class_name in class_names:
http_prompts.append({"type": "text", "text": class_name})
for idx, class_name in enumerate(class_names):
prompt_data = {"type": "text", "text": class_name}
# Add per-prompt threshold if confidence_thresholds is set
if confidence_thresholds is not None and idx < len(confidence_thresholds):
prompt_data["output_prob_thresh"] = confidence_thresholds[idx]
http_prompts.append(prompt_data)

# Prepare image for remote API (base64)
http_image = {"type": "base64", "value": single_image.base64_image}
Expand All @@ -178,6 +254,7 @@ def run_via_request(
"image": http_image,
"prompts": http_prompts,
"output_prob_thresh": threshold,
"nms_iou_threshold": nms_iou_threshold if apply_nms else None,
}

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import requests
import supervision as sv
from pydantic import ConfigDict, Field
from pydantic import ConfigDict, Field, validator

from inference.core import logger
from inference.core.entities.requests.sam3 import Sam3Prompt, Sam3SegmentationRequest
Expand Down Expand Up @@ -116,6 +116,48 @@ class BlockManifest(WorkflowBlockManifest):
default=0.5, description="Threshold for predicted mask scores", examples=[0.3]
)

confidence_thresholds: Optional[
Union[List[float], str, Selector(kind=[LIST_OF_VALUES_KIND, STRING_KIND])]
] = Field(
default=None,
title="Per-Class Confidence Thresholds",
description="List of thresholds per class (must match class_names length) or comma-separated string",
examples=[[0.3, 0.5, 0.7], "0.3,0.5,0.7"],
)

apply_nms: Union[Selector(kind=[BOOLEAN_KIND]), bool] = Field(
default=True,
title="Apply NMS",
description="Whether to apply Non-Maximum Suppression across prompts",
)

nms_iou_threshold: Union[Selector(kind=[FLOAT_KIND]), float] = Field(
default=0.9,
title="NMS IoU Threshold",
description="IoU threshold for cross-prompt NMS. Must be in [0.0, 1.0]",
examples=[0.5, 0.9],
)

@validator("nms_iou_threshold")
def _validate_nms_iou_threshold(cls, v):
if isinstance(v, (int, float)) and (v < 0.0 or v > 1.0):
raise ValueError("nms_iou_threshold must be between 0.0 and 1.0")
return v

@validator("confidence_thresholds", pre=True)
def _parse_confidence_thresholds(cls, v):
if v is None:
return None
if isinstance(v, str):
# Parse comma-separated string to list of floats
try:
return [float(x.strip()) for x in v.split(",") if x.strip()]
except ValueError:
raise ValueError(
"confidence_thresholds string must be comma-separated floats"
)
return v

@classmethod
def get_parameters_accepting_batches(cls) -> List[str]:
    """Return the parameter names this manifest declares as batch-accepting."""
    batch_parameters: List[str] = ["images", "boxes"]
    return batch_parameters
Expand Down Expand Up @@ -160,6 +202,9 @@ def run(
model_id: str,
class_names: Optional[Union[List[str], str]],
threshold: float,
confidence_thresholds: Optional[Union[List[float], str]] = None,
apply_nms: bool = True,
nms_iou_threshold: float = 0.9,
) -> BlockResult:

if isinstance(class_names, str):
Expand All @@ -169,6 +214,26 @@ def run(
else:
raise ValueError(f"Invalid class names type: {type(class_names)}")

# Parse confidence_thresholds if string
parsed_thresholds = None
if confidence_thresholds is not None:
if isinstance(confidence_thresholds, str):
parsed_thresholds = [
float(x.strip())
for x in confidence_thresholds.split(",")
if x.strip()
]
else:
parsed_thresholds = confidence_thresholds

# Validate length matches class_names
if parsed_thresholds and class_names:
if len(parsed_thresholds) != len(class_names):
raise ValueError(
f"confidence_thresholds length ({len(parsed_thresholds)}) "
f"must match class_names length ({len(class_names)})"
)

exec_mode = self._step_execution_mode
if SAM3_EXEC_MODE == "local":
exec_mode = self._step_execution_mode
Expand All @@ -188,13 +253,19 @@ def run(
model_id=model_id,
class_names=class_names,
threshold=threshold,
confidence_thresholds=parsed_thresholds,
apply_nms=apply_nms,
nms_iou_threshold=nms_iou_threshold,
)
elif exec_mode is StepExecutionMode.REMOTE:
logger.debug(f"Running SAM3 remotely")
return self.run_via_request(
images=images,
class_names=class_names,
threshold=threshold,
confidence_thresholds=parsed_thresholds,
apply_nms=apply_nms,
nms_iou_threshold=nms_iou_threshold,
)
else:
raise ValueError(
Expand All @@ -207,6 +278,9 @@ def run_locally(
model_id: str,
class_names: Optional[List[str]],
threshold: float,
confidence_thresholds: Optional[List[float]] = None,
apply_nms: bool = True,
nms_iou_threshold: float = 0.9,
) -> BlockResult:
predictions = []
if class_names is None:
Expand All @@ -227,8 +301,14 @@ def run_locally(

# Build unified prompt list: one per class name
unified_prompts: List[Sam3Prompt] = []
for class_name in class_names:
unified_prompts.append(Sam3Prompt(type="text", text=class_name))
for idx, class_name in enumerate(class_names):
# Add per-prompt threshold if confidence_thresholds is set
prompt_thresh = None
if confidence_thresholds is not None and idx < len(confidence_thresholds):
prompt_thresh = confidence_thresholds[idx]
unified_prompts.append(
Sam3Prompt(type="text", text=class_name, output_prob_thresh=prompt_thresh)
)

# Single batched request with all prompts
inference_request = Sam3SegmentationRequest(
Expand All @@ -237,6 +317,7 @@ def run_locally(
api_key=self._api_key,
prompts=unified_prompts,
output_prob_thresh=threshold,
nms_iou_threshold=nms_iou_threshold if apply_nms else None,
)

sam3_response = self._model_manager.infer_from_request_sync(
Expand Down Expand Up @@ -281,6 +362,9 @@ def run_via_request(
images: Batch[WorkflowImageData],
class_names: Optional[List[str]],
threshold: float,
confidence_thresholds: Optional[List[float]] = None,
apply_nms: bool = True,
nms_iou_threshold: float = 0.9,
) -> BlockResult:
predictions = []
if class_names is None:
Expand All @@ -298,8 +382,12 @@ def run_via_request(

# Build unified prompt list payloads for HTTP
http_prompts: List[dict] = []
for class_name in class_names:
http_prompts.append({"type": "text", "text": class_name})
for idx, class_name in enumerate(class_names):
prompt_data = {"type": "text", "text": class_name}
# Add per-prompt threshold if confidence_thresholds is set
if confidence_thresholds is not None and idx < len(confidence_thresholds):
prompt_data["output_prob_thresh"] = confidence_thresholds[idx]
http_prompts.append(prompt_data)

# Prepare image for remote API (base64)
http_image = {"type": "base64", "value": single_image.base64_image}
Expand All @@ -308,6 +396,7 @@ def run_via_request(
"image": http_image,
"prompts": http_prompts,
"output_prob_thresh": threshold,
"nms_iou_threshold": nms_iou_threshold if apply_nms else None,
}

try:
Expand Down
Loading
Loading