Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,19 @@ def retrieve_selectors_from_union_definition(
+ union_definition.get(ONE_OF_KEY, [])
+ union_definition.get(ALL_OF_KEY, [])
)
# Check if any union variant is an array or dict type
# This handles Union[List[...], Selector(...)] patterns
contains_array_type = False
contains_dict_type = False
for type_definition in union_types:
if type_definition.get("type") == "array" and ITEMS_KEY in type_definition:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use constants for key names

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and also values

contains_array_type = True
if (
type_definition.get("type") == "object"
and "additionalProperties" in type_definition
):
contains_dict_type = True

results = []
for type_definition in union_types:
result = retrieve_selectors_from_simple_property(
Expand Down Expand Up @@ -468,8 +481,8 @@ def retrieve_selectors_from_union_definition(
property_name=property_name,
property_description=property_description,
allowed_references=merged_references,
is_list_element=is_list_element,
is_dict_element=is_dict_element,
is_list_element=is_list_element or contains_array_type,
is_dict_element=is_dict_element or contains_dict_type,
dimensionality_offset=property_dimensionality_offset,
is_dimensionality_reference_property=is_dimensionality_reference_property,
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,30 @@ def get_step_selectors(
step_manifest=step_manifest,
property_name=property_name,
)
if selector_definition.is_list_element:
# Check runtime type to handle Union[List[...], Selector(...)] patterns
# where the actual value determines which path to take
if selector_definition.is_list_element and isinstance(property_value, list):
selectors = retrieve_selectors_from_array(
step_name=step_manifest.name,
property_value=property_value,
selector_definition=selector_definition,
)
result.extend(selectors)
elif selector_definition.is_dict_element:
elif selector_definition.is_dict_element and isinstance(property_value, dict):
selectors = retrieve_selectors_from_dictionary(
step_name=step_manifest.name,
property_value=property_value,
selector_definition=selector_definition,
)
result.extend(selectors)
else:
elif is_selector(property_value):
selector = retrieve_selector_from_simple_property(
step_name=step_manifest.name,
property_value=property_value,
selector_definition=selector_definition,
)
result.append(selector)
# If none of the above, property_value is not a selector
return [r for r in result if r is not None]


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Integration tests for Union[List[...], Selector(...)] patterns in workflow blocks.

This test verifies that blocks can properly handle properties defined as
Union[List[T], Selector(...)] where the value can be either:
- A literal list: ["tag1", "tag2"]
- A selector to a list: $inputs.tags
- A mixed list with literals and selectors: ["literal", "$inputs.tag"]
"""

import numpy as np
import pytest

from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS
from inference.core.managers.base import ModelManager
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
from inference.core.workflows.execution_engine.core import ExecutionEngine


WORKFLOW_WITH_SELECTOR_TO_LIST = {
"version": "1.0",
"inputs": [
{"type": "WorkflowImage", "name": "image"},
{
"type": "WorkflowParameter",
"name": "classes_to_consider",
},
],
"steps": [
{
"type": "RoboflowObjectDetectionModel",
"name": "detection",
"image": "$inputs.image",
"model_id": "yolov8n-640",
},
{
"type": "DetectionsConsensus",
"name": "consensus",
"predictions_batches": ["$steps.detection.predictions"],
"required_votes": 1,
"classes_to_consider": "$inputs.classes_to_consider",
},
],
"outputs": [
{
"type": "JsonField",
"name": "predictions",
"selector": "$steps.consensus.predictions",
}
],
}


WORKFLOW_WITH_LITERAL_LIST = {
"version": "1.0",
"inputs": [
{"type": "WorkflowImage", "name": "image"},
],
"steps": [
{
"type": "RoboflowObjectDetectionModel",
"name": "detection",
"image": "$inputs.image",
"model_id": "yolov8n-640",
},
{
"type": "DetectionsConsensus",
"name": "consensus",
"predictions_batches": ["$steps.detection.predictions"],
"required_votes": 1,
"classes_to_consider": ["person"],
},
],
"outputs": [
{
"type": "JsonField",
"name": "predictions",
"selector": "$steps.consensus.predictions",
}
],
}


def test_union_list_selector_with_selector_to_list(
model_manager: ModelManager,
crowd_image: np.ndarray,
) -> None:
"""Test Union[List[str], Selector(...)] when using a selector to a list."""
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=WORKFLOW_WITH_SELECTOR_TO_LIST,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
result = execution_engine.run(
runtime_parameters={
"image": crowd_image,
"classes_to_consider": ["person"],
}
)

# then
assert isinstance(result, list), "Expected list of results"
assert len(result) == 1, "Expected single result"
assert "predictions" in result[0], "Expected predictions in output"
# Verify that the selector was properly resolved and the workflow executed


def test_union_list_selector_with_literal_list(
model_manager: ModelManager,
crowd_image: np.ndarray,
) -> None:
"""Test Union[List[str], Selector(...)] when using a literal list."""
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=WORKFLOW_WITH_LITERAL_LIST,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when
result = execution_engine.run(
runtime_parameters={
"image": crowd_image,
}
)

# then
assert isinstance(result, list), "Expected list of results"
assert len(result) == 1, "Expected single result"
assert "predictions" in result[0], "Expected predictions in output"
# Verify that the literal list was properly handled


def test_union_list_selector_validates_type_mismatch(
model_manager: ModelManager,
crowd_image: np.ndarray,
) -> None:
"""Test that type validation catches invalid selector resolution."""
# given
workflow_init_parameters = {
"workflows_core.model_manager": model_manager,
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
}
execution_engine = ExecutionEngine.init(
workflow_definition=WORKFLOW_WITH_SELECTOR_TO_LIST,
init_parameters=workflow_init_parameters,
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
)

# when/then - passing a string instead of a list should fail validation
with pytest.raises(Exception): # Should raise validation error
execution_engine.run(
runtime_parameters={
"image": crowd_image,
"classes_to_consider": "person", # String instead of list
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from inference.core.workflows.execution_engine.entities.types import (
BOOLEAN_KIND,
IMAGE_KIND,
LIST_OF_VALUES_KIND,
OBJECT_DETECTION_PREDICTION_KIND,
STRING_KIND,
Selector,
StepOutputImageSelector,
StepOutputSelector,
StepSelector,
Expand Down Expand Up @@ -531,3 +533,61 @@ def describe_outputs(cls) -> List[OutputDefinition]:
)
},
)


def test_parse_block_manifest_when_manifest_defines_union_of_list_or_selector() -> (
None
):
"""Test that Union[List[...], Selector(...)] properly sets is_list_element=True.

This is a regression test for the bug where the schema parser would not detect
that a property could receive a list when defined as Union[List[T], Selector(...)].
"""
# given

class Manifest(WorkflowBlockManifest):
type: Literal["MyManifest"]
name: str = Field(description="name field")
tags: Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])] = Field(
description="Tags can be a literal list or a selector to a list"
)

@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return []

# when
manifest_metadata = parse_block_manifest(manifest_type=Manifest)

# then
assert manifest_metadata == BlockManifestMetadata(
primitive_types={
"name": PrimitiveTypeDefinition(
property_name="name",
property_description="name field",
type_annotation="str",
),
"tags": PrimitiveTypeDefinition(
property_name="tags",
property_description="Tags can be a literal list or a selector to a list",
type_annotation="List[str]",
),
},
selectors={
"tags": SelectorDefinition(
property_name="tags",
property_description="Tags can be a literal list or a selector to a list",
allowed_references=[
ReferenceDefinition(
selected_element="any_data",
kind=[LIST_OF_VALUES_KIND],
points_to_batch={False},
),
],
is_list_element=True,
is_dict_element=False,
dimensionality_offset=0,
is_dimensionality_reference_property=False,
)
},
)
Loading