Commit 1c7bd91

Authored by davidsbatista, sjrl, and Amnah199
feat: adding agents back to the experimental repo (#326)
* adding agents back to the experimental repo
* adding tests
* wip
* adding back dependencies
* for this branch, make the dependency on the haystack main
* reverting dependency and commenting out test
* reverting dependency and commenting out test
* reverting dependency and commenting out test
* feat: breakpoints structure (#330)
* initial import
* formatting and renaming file
* solving shadowing issues
* solving shadowing issues
* solving formatting issues
* adding type annotation
* adding visit count to Tool
* updating AgentBreakpoint
* removing whitespace
* renaming shadow variable names
* improving eq in ToolBreakpoint
* setting dependency from haystack-ai main
* reverting pyproject.toml
* syncing with new haystack release
* typing
* updating all agent-related code with the latest haystack release
* more typing issues
* feat: adding breakpoints to `Agent` (#328)
* adding agents back to the experimental repo
* wip: adding breakpoints PoC to Agents
* wip
* typing
* formatting
* pylint issues
* fix typing
* reverting pyproject.toml
* fixing bug in component_visits
* linting
* adding/fixing tests
* all tests fixed
* cleaning up
* adding breakpoints to async version
* cleaning up
* refactoring tests
* adding more tests
* cleaning up
* formatting issues
* typing
* cleaning
* adding async tests
* nit
* refactoring breakpoints logic to avoid duplicated code
* converting to staticmethod
* updating Agent + tests to use Breakpoint data structures
* updating async tests to use Breakpoint data structures
* refactoring tests
* refactoring tests
* wip: Pipeline accepting Breakpoint and AgentBreakpoint
* Pipeline accepting Breakpoint and AgentBreakpoint + fixing tests
* fix validate_input issue
* updating tests
* wip: adding breakpoints PoC to Agents
* WIP: testing breakpoints in agent within a pipeline
* wip: adding test for multiple breakpoints in Agent and Pipeline
* wip: adding test for multiple breakpoints in Agent and Pipeline
* adding tests for agent in pipeline
* nit
* reverting validate
* fixing agent in pipeline tests
* Update haystack_experimental/components/agents/agent.py (Co-authored-by: Sebastian Husch Lee <[email protected]>)
* having one single BreakpointException
* parameters
* wip: fixing async tests
* nit: cleaning
* fixing imports
* reverting back to a single Breakpoint
* updating all tests due to reverting back to a single Breakpoint
* updating all tests due to reverting back to a single Breakpoint
* WIP: fixing tests
* updating tests without break_on_first and properly catching all the exceptions
* updating integration tests, reverting back to a single breakpoint
* enabling back linter rules
* wip: resuming a pipeline agent, saving agent_name to state
* cleaning up
* all tests passing - working on pipeline agent resume from state
* updating tests to consider mandatory agent_name when having an Agent breakpoint
* serialisation of the pipeline where the agent is running, plus the agent status serialisations, into the same resume JSON file
* updating linting and types due to new changes
* updating all tests due to new changes
* resuming an Agent within a pipeline
* replacing Path by PosixPath for Windows compatibility
* reverting back to Path
* wip
* cleaning up - small improvements
* cleaning up validate pipeline
* simplifying the save state
* reducing parameters for save_state agent case
* saving all the main pipeline data nested under a single key
* more cleaning and small improvements
* helper function to handle processing the state of the host pipeline where the agent is running
* helper function to handle file saving
* adding tests for agent in pipeline break and resume on a tool
* refactoring pipeline.run
* refactoring pipeline methods
* removing hard-coded path for JSON resume files, using pytest tmpfile
* wip
* review comments/improvements
* review comments/improvements + updating tests
* Fix validate_breakpoint
* moving static methods to breakpoint.py
* Update haystack_experimental/core/pipeline/pipeline.py (Co-authored-by: Amna Mubashar <[email protected]>)
* fixing LICENSE header in breakpoint.py
* Update haystack_experimental/dataclasses/breakpoints.py (Co-authored-by: Sebastian Husch Lee <[email protected]>)
* Small changes
* wip
* attending PR comments
* attending PR comments
* homogenization of var names and function names

Co-authored-by: Sebastian Husch Lee <[email protected]>
Co-authored-by: Amna Mubashar <[email protected]>
Co-authored-by: Sebastian Husch Lee <[email protected]>
1 parent 5895cfc commit 1c7bd91

25 files changed: +4304 −190 lines
Lines changed: 17 additions & 0 deletions

```python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import sys
from typing import TYPE_CHECKING

from lazy_imports import LazyImporter

_import_structure = {"agent": ["Agent"], "state": ["State"]}

if TYPE_CHECKING:
    from .agent import Agent
    from .state import State
else:
    sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
```
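The `LazyImporter` used above (from the third-party `lazy_imports` package) defers the submodule imports until an attribute is first accessed, so importing the package stays cheap. A minimal stdlib-only sketch of the same idea follows; the `_LazyModule` class here is illustrative and is not the actual `lazy_imports` implementation:

```python
import importlib


class _LazyModule:
    """Resolve public names to submodule imports only on first access (sketch)."""

    def __init__(self, import_structure):
        # Maps each public name to the module that provides it,
        # mirroring the _import_structure dict in the diff above.
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }
        self._cache = {}

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, i.e. first access.
        if name not in self._name_to_module:
            raise AttributeError(name)
        if name not in self._cache:
            module = importlib.import_module(self._name_to_module[name])
            self._cache[name] = getattr(module, name)
        return self._cache[name]


# Using the stdlib json module as a stand-in for the .agent / .state submodules:
lazy = _LazyModule({"json": ["loads", "dumps"]})
parsed = lazy.loads('{"a": 1}')  # json is imported here, on first use
```

The real `LazyImporter` additionally replaces the module object in `sys.modules`, which is why the `else` branch assigns to `sys.modules[__name__]`.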

haystack_experimental/components/agents/agent.py

Lines changed: 739 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 8 additions & 0 deletions

```python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .state import State
from .state_utils import merge_lists, replace_values

__all__ = ["State", "merge_lists", "replace_values"]
```
Lines changed: 179 additions & 0 deletions

```python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional

from haystack.dataclasses import ChatMessage
from haystack.utils import _deserialize_value_with_schema, _serialize_value_with_schema
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.type_serialization import deserialize_type, serialize_type

from .state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values


def _schema_to_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert a schema dictionary to a serializable format.

    Converts each parameter's type and optional handler function into a serializable
    format using type and callable serialization utilities.

    :param schema: Dictionary mapping parameter names to their type and handler configs
    :returns: Dictionary with serialized type and handler information
    """
    serialized_schema = {}
    for param, config in schema.items():
        serialized_schema[param] = {"type": serialize_type(config["type"])}
        if config.get("handler"):
            serialized_schema[param]["handler"] = serialize_callable(config["handler"])

    return serialized_schema


def _schema_from_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert a serialized schema dictionary back to its original format.

    Deserializes the type and optional handler function for each parameter from their
    serialized format back into Python types and callables.

    :param schema: Dictionary containing serialized schema information
    :returns: Dictionary with deserialized type and handler configurations
    """
    deserialized_schema = {}
    for param, config in schema.items():
        deserialized_schema[param] = {"type": deserialize_type(config["type"])}

        if config.get("handler"):
            deserialized_schema[param]["handler"] = deserialize_callable(config["handler"])

    return deserialized_schema


def _validate_schema(schema: Dict[str, Any]) -> None:
    """
    Validate that a schema dictionary meets all required constraints.

    Checks that each parameter definition has a valid type field and that any handler
    specified is a callable function.

    :param schema: Dictionary mapping parameter names to their type and handler configs
    :raises ValueError: If schema validation fails due to missing or invalid fields
    """
    for param, definition in schema.items():
        if "type" not in definition:
            raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.")
        if not _is_valid_type(definition["type"]):
            raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
        if definition.get("handler") is not None and not callable(definition["handler"]):
            raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
        if param == "messages" and definition["type"] != List[ChatMessage]:
            raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}")


class State:
    """
    A class that wraps a StateSchema and maintains an internal _data dictionary.

    Each schema entry has:
    "parameter_name": {
        "type": SomeType,
        "handler": Optional[Callable[[Any, Any], Any]]
    }
    """

    def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None):
        """
        Initialize a State object with a schema and optional data.

        :param schema: Dictionary mapping parameter names to their type and handler configs.
            Type must be a valid Python type, and handler must be a callable function or None.
            If handler is None, the default handler for the type will be used. The default handlers are:
            - For list types: `haystack.agents.state.state_utils.merge_lists`
            - For all other types: `haystack.agents.state.state_utils.replace_values`
        :param data: Optional dictionary of initial data to populate the state
        """
        _validate_schema(schema)
        self.schema = deepcopy(schema)
        if self.schema.get("messages") is None:
            self.schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists}
        self._data = data or {}

        # Set default handlers if not provided in schema
        for definition in self.schema.values():
            # Skip if handler is already defined and not None
            if definition.get("handler") is not None:
                continue
            # Set default handler based on type
            if _is_list_type(definition["type"]):
                definition["handler"] = merge_lists
            else:
                definition["handler"] = replace_values

    def get(self, key: str, default: Any = None) -> Any:
        """
        Retrieve a value from the state by key.

        :param key: Key to look up in the state
        :param default: Value to return if key is not found
        :returns: Value associated with key or default if not found
        """
        return deepcopy(self._data.get(key, default))

    def set(self, key: str, value: Any, handler_override: Optional[Callable[[Any, Any], Any]] = None) -> None:
        """
        Set or merge a value in the state according to schema rules.

        Value is merged or overwritten according to these rules:
        - if handler_override is given, use that
        - else use the handler defined in the schema for 'key'

        :param key: Key to store the value under
        :param value: Value to store or merge
        :param handler_override: Optional function to override the default merge behavior
        """
        # If key not in schema, we throw an error
        definition = self.schema.get(key, None)
        if definition is None:
            raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}")

        # Get current value from state and apply handler
        current_value = self._data.get(key, None)
        handler = handler_override or definition["handler"]
        self._data[key] = handler(current_value, value)

    @property
    def data(self):
        """
        All current data of the state.
        """
        return self._data

    def has(self, key: str) -> bool:
        """
        Check if a key exists in the state.

        :param key: Key to check for existence
        :returns: True if key exists in state, False otherwise
        """
        return key in self._data

    def to_dict(self):
        """
        Convert the State object to a dictionary.
        """
        serialized = {}
        serialized["schema"] = _schema_to_dict(self.schema)
        serialized["data"] = _serialize_value_with_schema(self._data)
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "State":
        """
        Convert a dictionary back to a State object.
        """
        schema = _schema_from_dict(data.get("schema", {}))
        deserialized_data = _deserialize_value_with_schema(data.get("data", {}))
        return State(schema, deserialized_data)
```
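The core behavior of `State.set` is that list-typed entries accumulate via `merge_lists` while everything else is overwritten via `replace_values`. Since running the real class requires the `haystack` package, here is a stripped-down, stdlib-only sketch of that merge logic; `MiniState` is an illustrative stand-in, not the actual class:

```python
from typing import Any, Callable, Dict, List, Optional, get_origin


def merge_lists(current, new):
    # Same semantics as state_utils.merge_lists: coerce both sides to lists and concatenate.
    current_list = [] if current is None else current if isinstance(current, list) else [current]
    new_list = new if isinstance(new, list) else [new]
    return current_list + new_list


def replace_values(current, new):
    # Same semantics as state_utils.replace_values: discard the old value.
    return new


class MiniState:
    """Illustrative stand-in for State: schema-driven set/get with default handlers."""

    def __init__(self, schema: Dict[str, Any]):
        self.schema = {k: dict(v) for k, v in schema.items()}
        self._data: Dict[str, Any] = {}
        for definition in self.schema.values():
            if definition.get("handler") is None:
                # List-typed entries accumulate; everything else is overwritten.
                is_list = definition["type"] == list or get_origin(definition["type"]) == list
                definition["handler"] = merge_lists if is_list else replace_values

    def set(self, key: str, value: Any, handler_override: Optional[Callable] = None) -> None:
        definition = self.schema[key]
        handler = handler_override or definition["handler"]
        self._data[key] = handler(self._data.get(key), value)

    def get(self, key: str, default: Any = None) -> Any:
        return self._data.get(key, default)


state = MiniState({"notes": {"type": List[str]}, "count": {"type": int}})
state.set("notes", "first")
state.set("notes", ["second"])  # merged with the existing list, not replaced
state.set("count", 1)
state.set("count", 2)           # replaced
```

The real `State` additionally deep-copies on `get`, validates the schema, and always injects a `messages: List[ChatMessage]` entry; those parts are omitted here for brevity.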
Lines changed: 77 additions & 0 deletions

```python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Any, List, TypeVar, Union, get_origin

T = TypeVar("T")


def _is_valid_type(obj: Any) -> bool:
    """
    Check if an object is a valid type annotation.

    Valid types include:
    - Normal classes (str, dict, CustomClass)
    - Generic types (List[str], Dict[str, int])
    - Union types (Union[str, int], Optional[str])

    :param obj: The object to check
    :return: True if the object is a valid type annotation, False otherwise

    Example usage:
        >>> _is_valid_type(str)
        True
        >>> _is_valid_type(List[int])
        True
        >>> _is_valid_type(Union[str, int])
        True
        >>> _is_valid_type(42)
        False
    """
    # Handle Union types (including Optional)
    if hasattr(obj, "__origin__") and obj.__origin__ == Union:
        return True

    # Handle normal classes and generic types
    return inspect.isclass(obj) or type(obj).__name__ in {"_GenericAlias", "GenericAlias"}


def _is_list_type(type_hint: Any) -> bool:
    """
    Check if a type hint represents a list type.

    :param type_hint: The type hint to check
    :return: True if the type hint represents a list, False otherwise
    """
    return type_hint == list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) == list)


def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
    """
    Merges two values into a single list.

    If either `current` or `new` is not already a list, it is converted into one.
    The function ensures that both inputs are treated as lists and concatenates them.

    If `current` is None, it is treated as an empty list.

    :param current: The existing value(s), either a single item or a list.
    :param new: The new value(s) to merge, either a single item or a list.
    :return: A list containing elements from both `current` and `new`.
    """
    current_list = [] if current is None else current if isinstance(current, list) else [current]
    new_list = new if isinstance(new, list) else [new]
    return current_list + new_list


def replace_values(current: Any, new: Any) -> Any:
    """
    Replace the `current` value with the `new` value.

    :param current: The existing value
    :param new: The new value to replace
    :return: The new value
    """
    return new
```
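The two default handlers are pure functions with no haystack dependencies, so they can be exercised directly. This sketch restates them from the diff above and shows the merge semantics the `State` class relies on:

```python
from typing import Any, List, TypeVar, Union

T = TypeVar("T")


def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
    # None becomes [], scalars are wrapped, then the two lists are concatenated.
    current_list = [] if current is None else current if isinstance(current, list) else [current]
    new_list = new if isinstance(new, list) else [new]
    return current_list + new_list


def replace_values(current: Any, new: Any) -> Any:
    # The previous value is discarded unconditionally.
    return new


merged = merge_lists(None, "a")           # None is treated as an empty list
merged = merge_lists(merged, ["b", "c"])  # list + list concatenates
replaced = replace_values({"old": 1}, {"new": 2})
```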

haystack_experimental/core/errors.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -5,7 +5,7 @@
 from typing import Any, Dict, Optional


-class PipelineBreakpointException(Exception):
+class BreakpointException(Exception):
     """
     Exception raised when a pipeline breakpoint is triggered.
     """
```
