Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 61 additions & 15 deletions haystack_experimental/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def _initialize_fresh_execution(
requires_async: bool,
*,
system_prompt: Optional[str] = None,
generation_kwargs: Optional[dict[str, Any]] = None,
tools: Optional[Union[ToolsType, list[str]]] = None,
**kwargs: dict[str, Any],
) -> _ExecutionContext:
Expand All @@ -186,14 +187,28 @@ def _initialize_fresh_execution(
When passing tool names, tools are selected from the Agent's originally configured tools.
:param kwargs: Additional data to pass to the State used by the Agent.
"""
exe_context = super(Agent, self)._initialize_fresh_execution(
messages=messages,
streaming_callback=streaming_callback,
requires_async=requires_async,
system_prompt=system_prompt,
tools=tools,
**kwargs,
)
# The PR https://github.com/deepset-ai/haystack/pull/9616 added the generation_kwargs parameter to
# _initialize_fresh_execution. This change has been released in Haystack 2.20.0.
# To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly.
if inspect.signature(super(Agent, self)._initialize_fresh_execution).parameters.get("generation_kwargs"):
exe_context = super(Agent, self)._initialize_fresh_execution(
messages=messages,
streaming_callback=streaming_callback,
requires_async=requires_async,
system_prompt=system_prompt,
generation_kwargs=generation_kwargs,
tools=tools,
**kwargs,
)
else:
exe_context = super(Agent, self)._initialize_fresh_execution(
messages=messages,
streaming_callback=streaming_callback,
requires_async=requires_async,
system_prompt=system_prompt,
tools=tools,
**kwargs,
)
# NOTE: 1st difference with parent method to add this to tool_invoker_inputs
if self._tool_invoker:
exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
Expand All @@ -213,6 +228,7 @@ def _initialize_from_snapshot( # type: ignore[override]
streaming_callback: Optional[StreamingCallbackT],
requires_async: bool,
*,
generation_kwargs: Optional[dict[str, Any]] = None,
tools: Optional[Union[ToolsType, list[str]]] = None,
) -> _ExecutionContext:
"""
Expand All @@ -221,12 +237,26 @@ def _initialize_from_snapshot( # type: ignore[override]
:param snapshot: An AgentSnapshot containing the state of a previously saved agent execution.
:param streaming_callback: Optional callback for streaming responses.
:param requires_async: Whether the agent run requires asynchronous execution.
:param generation_kwargs: Additional keyword arguments for chat generator. These parameters will
override the parameters passed during component initialization.
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
When passing tool names, tools are selected from the Agent's originally configured tools.
"""
exe_context = super(Agent, self)._initialize_from_snapshot(
snapshot=snapshot, streaming_callback=streaming_callback, requires_async=requires_async, tools=tools
)
# The PR https://github.com/deepset-ai/haystack/pull/9616 added the generation_kwargs parameter to
# _initialize_from_snapshot. This change has been released in Haystack 2.20.0.
# To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly.
if inspect.signature(super(Agent, self)._initialize_from_snapshot).parameters.get("generation_kwargs"):
exe_context = super(Agent, self)._initialize_from_snapshot(
snapshot=snapshot,
streaming_callback=streaming_callback,
requires_async=requires_async,
generation_kwargs=generation_kwargs,
tools=tools,
)
else:
exe_context = super(Agent, self)._initialize_from_snapshot(
snapshot=snapshot, streaming_callback=streaming_callback, requires_async=requires_async, tools=tools
)
# NOTE: 1st difference with parent method to add this to tool_invoker_inputs
if self._tool_invoker:
exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = (
Expand All @@ -248,6 +278,7 @@ def run( # noqa: PLR0915
messages: list[ChatMessage],
streaming_callback: Optional[StreamingCallbackT] = None,
*,
generation_kwargs: Optional[dict[str, Any]] = None,
break_point: Optional[AgentBreakpoint] = None,
snapshot: Optional[AgentSnapshot] = None, # type: ignore[override]
system_prompt: Optional[str] = None,
Expand All @@ -260,6 +291,8 @@ def run( # noqa: PLR0915
:param messages: List of Haystack ChatMessage objects to process.
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
The same callback can be configured to emit tool results when a tool is called.
:param generation_kwargs: Additional keyword arguments for LLM. These parameters will
override the parameters passed during component initialization.
:param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
for "tool_invoker".
:param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains
Expand Down Expand Up @@ -290,20 +323,25 @@ def run( # noqa: PLR0915
# _runtime_checks. This change will be released in Haystack 2.20.0.
# To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly.
if len(inspect.signature(self._runtime_checks).parameters) == 2:
self._runtime_checks(break_point, snapshot)
self._runtime_checks(break_point, snapshot) # type: ignore[call-arg] # pylint: disable=too-many-function-args
else:
self._runtime_checks(break_point) # type: ignore[call-arg] # pylint: disable=no-value-for-parameter

if snapshot:
exe_context = self._initialize_from_snapshot(
snapshot=snapshot, streaming_callback=streaming_callback, requires_async=False, tools=tools
snapshot=snapshot,
streaming_callback=streaming_callback,
requires_async=False,
generation_kwargs=generation_kwargs,
tools=tools,
)
else:
exe_context = self._initialize_fresh_execution(
messages=messages,
streaming_callback=streaming_callback,
requires_async=False,
system_prompt=system_prompt,
generation_kwargs=generation_kwargs,
tools=tools,
**kwargs,
)
Expand Down Expand Up @@ -438,6 +476,7 @@ async def run_async(
messages: list[ChatMessage],
streaming_callback: Optional[StreamingCallbackT] = None,
*,
generation_kwargs: Optional[dict[str, Any]] = None,
break_point: Optional[AgentBreakpoint] = None,
snapshot: Optional[AgentSnapshot] = None, # type: ignore[override]
system_prompt: Optional[str] = None,
Expand All @@ -454,6 +493,8 @@ async def run_async(
:param messages: List of Haystack ChatMessage objects to process.
:param streaming_callback: An asynchronous callback that will be invoked when a response is streamed from the
LLM. The same callback can be configured to emit tool results when a tool is called.
:param generation_kwargs: Additional keyword arguments for LLM. These parameters will
override the parameters passed during component initialization.
:param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
for "tool_invoker".
:param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains
Expand Down Expand Up @@ -483,20 +524,25 @@ async def run_async(
# _runtime_checks. This change will be released in Haystack 2.20.0.
# To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly.
if len(inspect.signature(self._runtime_checks).parameters) == 2:
self._runtime_checks(break_point, snapshot)
self._runtime_checks(break_point, snapshot) # type: ignore[call-arg] # pylint: disable=too-many-function-args
else:
self._runtime_checks(break_point) # type: ignore[call-arg] # pylint: disable=no-value-for-parameter

if snapshot:
exe_context = self._initialize_from_snapshot(
snapshot=snapshot, streaming_callback=streaming_callback, requires_async=True, tools=tools
snapshot=snapshot,
streaming_callback=streaming_callback,
requires_async=True,
generation_kwargs=generation_kwargs,
tools=tools,
)
else:
exe_context = self._initialize_fresh_execution(
messages=messages,
streaming_callback=streaming_callback,
requires_async=True,
system_prompt=system_prompt,
generation_kwargs=generation_kwargs,
tools=tools,
**kwargs,
)
Expand Down
Loading