diff --git a/inference/core/workflows/execution_engine/introspection/schema_parser.py b/inference/core/workflows/execution_engine/introspection/schema_parser.py index 2dc7b40b8c..3b24c98395 100644 --- a/inference/core/workflows/execution_engine/introspection/schema_parser.py +++ b/inference/core/workflows/execution_engine/introspection/schema_parser.py @@ -420,6 +420,19 @@ def retrieve_selectors_from_union_definition( + union_definition.get(ONE_OF_KEY, []) + union_definition.get(ALL_OF_KEY, []) ) + # Check if any union variant is an array or dict type + # This handles Union[List[...], Selector(...)] patterns + contains_array_type = False + contains_dict_type = False + for type_definition in union_types: + if type_definition.get("type") == "array" and ITEMS_KEY in type_definition: + contains_array_type = True + if ( + type_definition.get("type") == "object" + and "additionalProperties" in type_definition + ): + contains_dict_type = True + results = [] for type_definition in union_types: result = retrieve_selectors_from_simple_property( @@ -468,8 +481,8 @@ def retrieve_selectors_from_union_definition( property_name=property_name, property_description=property_description, allowed_references=merged_references, - is_list_element=is_list_element, - is_dict_element=is_dict_element, + is_list_element=is_list_element or contains_array_type, + is_dict_element=is_dict_element or contains_dict_type, dimensionality_offset=property_dimensionality_offset, is_dimensionality_reference_property=is_dimensionality_reference_property, ) diff --git a/inference/core/workflows/execution_engine/introspection/selectors_parser.py b/inference/core/workflows/execution_engine/introspection/selectors_parser.py index b8d4b788bf..af7ce80882 100644 --- a/inference/core/workflows/execution_engine/introspection/selectors_parser.py +++ b/inference/core/workflows/execution_engine/introspection/selectors_parser.py @@ -23,27 +23,30 @@ def get_step_selectors( step_manifest=step_manifest, property_name=property_name, ) - if selector_definition.is_list_element: + # Check runtime type to handle Union[List[...], Selector(...)] patterns + # where the actual value determines which path to take + if selector_definition.is_list_element and isinstance(property_value, list): selectors = retrieve_selectors_from_array( step_name=step_manifest.name, property_value=property_value, selector_definition=selector_definition, ) result.extend(selectors) - elif selector_definition.is_dict_element: + elif selector_definition.is_dict_element and isinstance(property_value, dict): selectors = retrieve_selectors_from_dictionary( step_name=step_manifest.name, property_value=property_value, selector_definition=selector_definition, ) result.extend(selectors) - else: + elif is_selector(property_value): selector = retrieve_selector_from_simple_property( step_name=step_manifest.name, property_value=property_value, selector_definition=selector_definition, ) result.append(selector) + # If none of the above, property_value is not a selector return [r for r in result if r is not None] diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_union_list_selector.py b/tests/workflows/integration_tests/execution/test_workflow_with_union_list_selector.py new file mode 100644 index 0000000000..53f49da606 --- /dev/null +++ b/tests/workflows/integration_tests/execution/test_workflow_with_union_list_selector.py @@ -0,0 +1,167 @@ +"""Integration tests for Union[List[...], Selector(...)] patterns in workflow blocks. + +This test verifies that blocks can properly handle properties defined as +Union[List[T], Selector(...)] where the value can be either: +- A literal list: ["tag1", "tag2"] +- A selector to a list: $inputs.tags +- A mixed list with literals and selectors: ["literal", "$inputs.tag"] +""" + +import numpy as np +import pytest + +from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS +from inference.core.managers.base import ModelManager +from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.execution_engine.core import ExecutionEngine + + +WORKFLOW_WITH_SELECTOR_TO_LIST = { + "version": "1.0", + "inputs": [ + {"type": "WorkflowImage", "name": "image"}, + { + "type": "WorkflowParameter", + "name": "classes_to_consider", + }, + ], + "steps": [ + { + "type": "RoboflowObjectDetectionModel", + "name": "detection", + "image": "$inputs.image", + "model_id": "yolov8n-640", + }, + { + "type": "DetectionsConsensus", + "name": "consensus", + "predictions_batches": ["$steps.detection.predictions"], + "required_votes": 1, + "classes_to_consider": "$inputs.classes_to_consider", + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "predictions", + "selector": "$steps.consensus.predictions", + } + ], +} + + +WORKFLOW_WITH_LITERAL_LIST = { + "version": "1.0", + "inputs": [ + {"type": "WorkflowImage", "name": "image"}, + ], + "steps": [ + { + "type": "RoboflowObjectDetectionModel", + "name": "detection", + "image": "$inputs.image", + "model_id": "yolov8n-640", + }, + { + "type": "DetectionsConsensus", + "name": "consensus", + "predictions_batches": ["$steps.detection.predictions"], + "required_votes": 1, + "classes_to_consider": ["person"], + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "predictions", + "selector": "$steps.consensus.predictions", + } + ], +} + + +def test_union_list_selector_with_selector_to_list( + model_manager: ModelManager, + crowd_image: np.ndarray, +) -> None: + """Test Union[List[str], Selector(...)] when using a selector to a list.""" + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=WORKFLOW_WITH_SELECTOR_TO_LIST, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + result = execution_engine.run( + runtime_parameters={ + "image": crowd_image, + "classes_to_consider": ["person"], + } + ) + + # then + assert isinstance(result, list), "Expected list of results" + assert len(result) == 1, "Expected single result" + assert "predictions" in result[0], "Expected predictions in output" + # Verify that the selector was properly resolved and the workflow executed + + +def test_union_list_selector_with_literal_list( + model_manager: ModelManager, + crowd_image: np.ndarray, +) -> None: + """Test Union[List[str], Selector(...)] when using a literal list.""" + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=WORKFLOW_WITH_LITERAL_LIST, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + result = execution_engine.run( + runtime_parameters={ + "image": crowd_image, + } + ) + + # then + assert isinstance(result, list), "Expected list of results" + assert len(result) == 1, "Expected single result" + assert "predictions" in result[0], "Expected predictions in output" + # Verify that the literal list was properly handled + + +def test_union_list_selector_validates_type_mismatch( + model_manager: ModelManager, + crowd_image: np.ndarray, +) -> None: + """Test that type validation catches invalid selector resolution.""" + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=WORKFLOW_WITH_SELECTOR_TO_LIST, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when/then - passing a string instead of a list should fail validation + with pytest.raises(Exception): # Should raise validation error + execution_engine.run( + runtime_parameters={ + "image": crowd_image, + "classes_to_consider": "person", # String instead of list + } + ) diff --git a/tests/workflows/unit_tests/execution_engine/introspection/test_schema_parser.py b/tests/workflows/unit_tests/execution_engine/introspection/test_schema_parser.py index 6d91895818..fce8ec41d2 100644 --- a/tests/workflows/unit_tests/execution_engine/introspection/test_schema_parser.py +++ b/tests/workflows/unit_tests/execution_engine/introspection/test_schema_parser.py @@ -6,8 +6,10 @@ from inference.core.workflows.execution_engine.entities.types import ( BOOLEAN_KIND, IMAGE_KIND, + LIST_OF_VALUES_KIND, OBJECT_DETECTION_PREDICTION_KIND, STRING_KIND, + Selector, StepOutputImageSelector, StepOutputSelector, StepSelector, @@ -531,3 +533,61 @@ def describe_outputs(cls) -> List[OutputDefinition]: ) }, ) + + +def test_parse_block_manifest_when_manifest_defines_union_of_list_or_selector() -> ( + None +): + """Test that Union[List[...], Selector(...)] properly sets is_list_element=True. + + This is a regression test for the bug where the schema parser would not detect + that a property could receive a list when defined as Union[List[T], Selector(...)]. + """ + # given + + class Manifest(WorkflowBlockManifest): + type: Literal["MyManifest"] + name: str = Field(description="name field") + tags: Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])] = Field( + description="Tags can be a literal list or a selector to a list" + ) + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [] + + # when + manifest_metadata = parse_block_manifest(manifest_type=Manifest) + + # then + assert manifest_metadata == BlockManifestMetadata( + primitive_types={ + "name": PrimitiveTypeDefinition( + property_name="name", + property_description="name field", + type_annotation="str", + ), + "tags": PrimitiveTypeDefinition( + property_name="tags", + property_description="Tags can be a literal list or a selector to a list", + type_annotation="List[str]", + ), + }, + selectors={ + "tags": SelectorDefinition( + property_name="tags", + property_description="Tags can be a literal list or a selector to a list", + allowed_references=[ + ReferenceDefinition( + selected_element="any_data", + kind=[LIST_OF_VALUES_KIND], + points_to_batch={False}, + ), + ], + is_list_element=True, + is_dict_element=False, + dimensionality_offset=0, + is_dimensionality_reference_property=False, + ) + }, + ) diff --git a/tests/workflows/unit_tests/execution_engine/introspection/test_selectors_parser.py b/tests/workflows/unit_tests/execution_engine/introspection/test_selectors_parser.py index fe020eb78c..0d83fdc3d7 100644 --- a/tests/workflows/unit_tests/execution_engine/introspection/test_selectors_parser.py +++ b/tests/workflows/unit_tests/execution_engine/introspection/test_selectors_parser.py @@ -6,7 +6,9 @@ from inference.core.workflows.execution_engine.entities.types import ( BOOLEAN_KIND, IMAGE_KIND, + LIST_OF_VALUES_KIND, STRING_KIND, + Selector, StepOutputSelector, WorkflowImageSelector, WorkflowParameterSelector, @@ -160,3 +162,101 @@ def describe_outputs(cls) -> List[OutputDefinition]: assert ( selectors[0].definition.property_name == "param" ), "Selector definition must hold in terms of property name" + + +def test_get_step_selectors_when_union_of_list_or_selector_receives_selector() -> None: + """Test Union[List[str], Selector(...)] pattern when value is a selector to a list.""" + # given + + class Manifest(WorkflowBlockManifest): + type: Literal["UnionListTest"] + name: str = Field(description="name field") + tags: Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])] = Field( + description="Tags can be a literal list or a selector to a list" + ) + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [] + + step_manifest = Manifest( + type="UnionListTest", name="my_step", tags="$inputs.tags" + ) + + # when + selectors = get_step_selectors(step_manifest=step_manifest) + + # then + assert len(selectors) == 1, "One selector should be found" + assert selectors[0].value == "$inputs.tags" + assert selectors[0].definition.property_name == "tags" + assert ( + selectors[0].definition.is_list_element is True + ), "Should be marked as list element due to Union[List[...], Selector(...)]" + assert selectors[0].index is None, "Direct selector should not have index" + + +def test_get_step_selectors_when_union_of_list_or_selector_receives_literal_list() -> ( + None +): + """Test Union[List[str], Selector(...)] pattern when value is a literal list.""" + # given + + class Manifest(WorkflowBlockManifest): + type: Literal["UnionListTest"] + name: str = Field(description="name field") + tags: Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])] = Field( + description="Tags can be a literal list or a selector to a list" + ) + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [] + + step_manifest = Manifest( + type="UnionListTest", name="my_step", tags=["tag1", "tag2", "tag3"] + ) + + # when + selectors = get_step_selectors(step_manifest=step_manifest) + + # then + assert ( + len(selectors) == 0 + ), "No selectors should be found in literal list without selectors" + + +def test_get_step_selectors_when_union_of_list_or_selector_receives_mixed_list() -> ( + None +): + """Test Union[List[str], Selector(...)] when list contains both literals and selectors.""" + # given + + class Manifest(WorkflowBlockManifest): + type: Literal["UnionListTest"] + name: str = Field(description="name field") + tags: Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])] = Field( + description="Tags can be a literal list or a selector to a list" + ) + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [] + + step_manifest = Manifest( + type="UnionListTest", + name="my_step", + tags=["literal_tag", "$inputs.tag", "$inputs.another_tag"], + ) + + # when + selectors = get_step_selectors(step_manifest=step_manifest) + + # then + assert len(selectors) == 2, "Two selectors should be found in the mixed list" + assert selectors[0].value == "$inputs.tag" + assert selectors[0].index == 1, "First selector is at index 1" + assert selectors[1].value == "$inputs.another_tag" + assert selectors[1].index == 2, "Second selector is at index 2" + assert selectors[0].definition.is_list_element is True + assert selectors[1].definition.is_list_element is True