diff --git a/haystack_experimental/components/agents/agent.py b/haystack_experimental/components/agents/agent.py index 786cbdc9..ce6673e6 100644 --- a/haystack_experimental/components/agents/agent.py +++ b/haystack_experimental/components/agents/agent.py @@ -172,6 +172,7 @@ def _initialize_fresh_execution( requires_async: bool, *, system_prompt: Optional[str] = None, + generation_kwargs: Optional[dict[str, Any]] = None, tools: Optional[Union[ToolsType, list[str]]] = None, **kwargs: dict[str, Any], ) -> _ExecutionContext: @@ -186,14 +187,28 @@ def _initialize_fresh_execution( When passing tool names, tools are selected from the Agent's originally configured tools. :param kwargs: Additional data to pass to the State used by the Agent. """ - exe_context = super(Agent, self)._initialize_fresh_execution( - messages=messages, - streaming_callback=streaming_callback, - requires_async=requires_async, - system_prompt=system_prompt, - tools=tools, - **kwargs, - ) + # The PR https://github.com/deepset-ai/haystack/pull/9616 added the generation_kwargs parameter to + # _initialize_fresh_execution. This change has been released in Haystack 2.20.0. + # To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly. + if inspect.signature(super(Agent, self)._initialize_fresh_execution).parameters.get("generation_kwargs"): + exe_context = super(Agent, self)._initialize_fresh_execution( + messages=messages, + streaming_callback=streaming_callback, + requires_async=requires_async, + system_prompt=system_prompt, + generation_kwargs=generation_kwargs, + tools=tools, + **kwargs, + ) + else: + exe_context = super(Agent, self)._initialize_fresh_execution( + messages=messages, + streaming_callback=streaming_callback, + requires_async=requires_async, + system_prompt=system_prompt, + tools=tools, + **kwargs, + ) # NOTE: 1st difference with parent method to add this to tool_invoker_inputs if self._tool_invoker: exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = ( @@ -213,6 +228,7 @@ def _initialize_from_snapshot( # type: ignore[override] streaming_callback: Optional[StreamingCallbackT], requires_async: bool, *, + generation_kwargs: Optional[dict[str, Any]] = None, tools: Optional[Union[ToolsType, list[str]]] = None, ) -> _ExecutionContext: """ @@ -221,12 +237,26 @@ def _initialize_from_snapshot( # type: ignore[override] :param snapshot: An AgentSnapshot containing the state of a previously saved agent execution. :param streaming_callback: Optional callback for streaming responses. :param requires_async: Whether the agent run requires asynchronous execution. + :param generation_kwargs: Additional keyword arguments for chat generator. These parameters will + override the parameters passed during component initialization. :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run. When passing tool names, tools are selected from the Agent's originally configured tools. """ - exe_context = super(Agent, self)._initialize_from_snapshot( - snapshot=snapshot, streaming_callback=streaming_callback, requires_async=requires_async, tools=tools - ) + # The PR https://github.com/deepset-ai/haystack/pull/9616 added the generation_kwargs parameter to + # _initialize_from_snapshot. This change has been released in Haystack 2.20.0. + # To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly. + if inspect.signature(super(Agent, self)._initialize_from_snapshot).parameters.get("generation_kwargs"): + exe_context = super(Agent, self)._initialize_from_snapshot( + snapshot=snapshot, + streaming_callback=streaming_callback, + requires_async=requires_async, + generation_kwargs=generation_kwargs, + tools=tools, + ) + else: + exe_context = super(Agent, self)._initialize_from_snapshot( + snapshot=snapshot, streaming_callback=streaming_callback, requires_async=requires_async, tools=tools + ) # NOTE: 1st difference with parent method to add this to tool_invoker_inputs if self._tool_invoker: exe_context.tool_invoker_inputs["enable_streaming_callback_passthrough"] = ( @@ -248,6 +278,7 @@ def run( # noqa: PLR0915 messages: list[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, *, + generation_kwargs: Optional[dict[str, Any]] = None, break_point: Optional[AgentBreakpoint] = None, snapshot: Optional[AgentSnapshot] = None, # type: ignore[override] system_prompt: Optional[str] = None, @@ -260,6 +291,8 @@ def run( # noqa: PLR0915 :param messages: List of Haystack ChatMessage objects to process. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param generation_kwargs: Additional keyword arguments for LLM. These parameters will + override the parameters passed during component initialization. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains @@ -290,13 +323,17 @@ def run( # noqa: PLR0915 # _runtime_checks. This change will be released in Haystack 2.20.0. # To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly. if len(inspect.signature(self._runtime_checks).parameters) == 2: - self._runtime_checks(break_point, snapshot) + self._runtime_checks(break_point, snapshot) # type: ignore[call-arg] # pylint: disable=too-many-function-args else: self._runtime_checks(break_point) # type: ignore[call-arg] # pylint: disable=no-value-for-parameter if snapshot: exe_context = self._initialize_from_snapshot( - snapshot=snapshot, streaming_callback=streaming_callback, requires_async=False, tools=tools + snapshot=snapshot, + streaming_callback=streaming_callback, + requires_async=False, + generation_kwargs=generation_kwargs, + tools=tools, ) else: exe_context = self._initialize_fresh_execution( @@ -304,6 +341,7 @@ def run( # noqa: PLR0915 streaming_callback=streaming_callback, requires_async=False, system_prompt=system_prompt, + generation_kwargs=generation_kwargs, tools=tools, **kwargs, ) @@ -438,6 +476,7 @@ async def run_async( messages: list[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, *, + generation_kwargs: Optional[dict[str, Any]] = None, break_point: Optional[AgentBreakpoint] = None, snapshot: Optional[AgentSnapshot] = None, # type: ignore[override] system_prompt: Optional[str] = None, @@ -454,6 +493,8 @@ async def run_async( :param messages: List of Haystack ChatMessage objects to process. :param streaming_callback: An asynchronous callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param generation_kwargs: Additional keyword arguments for LLM. These parameters will + override the parameters passed during component initialization. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains @@ -483,13 +524,17 @@ async def run_async( # _runtime_checks. This change will be released in Haystack 2.20.0. # To maintain compatibility with Haystack 2.19 we check the number of parameters and call accordingly. if len(inspect.signature(self._runtime_checks).parameters) == 2: - self._runtime_checks(break_point, snapshot) + self._runtime_checks(break_point, snapshot) # type: ignore[call-arg] # pylint: disable=too-many-function-args else: self._runtime_checks(break_point) # type: ignore[call-arg] # pylint: disable=no-value-for-parameter if snapshot: exe_context = self._initialize_from_snapshot( - snapshot=snapshot, streaming_callback=streaming_callback, requires_async=True, tools=tools + snapshot=snapshot, + streaming_callback=streaming_callback, + requires_async=True, + generation_kwargs=generation_kwargs, + tools=tools, ) else: exe_context = self._initialize_fresh_execution( @@ -497,6 +542,7 @@ async def run_async( streaming_callback=streaming_callback, requires_async=True, system_prompt=system_prompt, + generation_kwargs=generation_kwargs, tools=tools, **kwargs, )