Skip to content

Commit 8fb22ee

Browse files
mathislucka and sjrl authored
feat: Agent can stream ChatGenerator responses (#233)
* feat: Agent can stream ChatGenerator responses * fix: unused import * Update haystack_experimental/components/agents/agent.py Co-authored-by: Sebastian Husch Lee <[email protected]> * add serde test --------- Co-authored-by: Sebastian Husch Lee <[email protected]>
1 parent 2c9b2c1 commit 8fb22ee

File tree

3 files changed

+182
-8
lines changed

3 files changed

+182
-8
lines changed

haystack_experimental/components/agents/agent.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from haystack.core.pipeline.base import PipelineError
1414
from haystack.core.serialization import component_from_dict
1515
from haystack.dataclasses import ChatMessage
16+
from haystack.dataclasses.streaming_chunk import SyncStreamingCallbackT
17+
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1618

1719
from haystack_experimental.components.tools import ToolInvoker
1820
from haystack_experimental.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
@@ -63,6 +65,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
6365
state_schema: Optional[Dict[str, Any]] = None,
6466
max_runs_per_component: int = 100,
6567
raise_on_tool_invocation_failure: bool = False,
68+
streaming_callback: Optional[SyncStreamingCallbackT] = None,
6669
):
6770
"""
6871
Initialize the agent component.
@@ -77,6 +80,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
7780
component exceeds the maximum number of runs per component.
7881
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
7982
If set to False, the exception will be turned into a chat message and passed to the LLM.
83+
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
8084
"""
8185
valid_exits = ["text"] + [tool.name for tool in tools or []]
8286
if exit_condition not in valid_exits:
@@ -92,6 +96,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
9296
self.exit_condition = exit_condition
9397
self.max_runs_per_component = max_runs_per_component
9498
self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
99+
self.streaming_callback = streaming_callback
95100

96101
output_types = {"messages": List[ChatMessage]}
97102
for param, config in self.state_schema.items():
@@ -178,6 +183,11 @@ def to_dict(self) -> Dict[str, Any]:
178183
179184
:return: Dictionary with serialized data
180185
"""
186+
if self.streaming_callback is not None:
187+
streaming_callback = serialize_callable(self.streaming_callback)
188+
else:
189+
streaming_callback = None
190+
181191
return default_to_dict(
182192
self,
183193
chat_generator=self.chat_generator.to_dict(),
@@ -187,6 +197,7 @@ def to_dict(self) -> Dict[str, Any]:
187197
state_schema=_schema_to_dict(self.state_schema),
188198
max_runs_per_component=self.max_runs_per_component,
189199
raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
200+
streaming_callback=streaming_callback
190201
)
191202

192203
@classmethod
@@ -201,10 +212,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "Agent":
201212

202213
init_params["chat_generator"] = Agent._load_component(init_params["chat_generator"])
203214

204-
# Deserialize type annotations
205215
if "state_schema" in init_params:
206216
init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])
207217

218+
if init_params.get("streaming_callback") is not None:
219+
init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"])
220+
221+
208222
deserialize_tools_inplace(init_params, key="tools")
209223

210224
return default_from_dict(cls, data)
@@ -232,11 +246,17 @@ def _load_component(component_data: Dict[str, Any]) -> Component:
232246

233247
return instance
234248

235-
def run(self, messages: List[ChatMessage], **kwargs) -> Dict[str, Any]:
249+
def run(
250+
self,
251+
messages: List[ChatMessage],
252+
streaming_callback: Optional[SyncStreamingCallbackT] = None,
253+
**kwargs
254+
) -> Dict[str, Any]:
236255
"""
237256
Process messages and execute tools until the exit condition is met.
238257
239258
:param messages: List of chat messages to process
259+
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
240260
:param kwargs: Additional keyword arguments matching the defined input types
241261
:return: Dictionary containing messages and outputs matching the defined output types
242262
"""
@@ -245,11 +265,17 @@ def run(self, messages: List[ChatMessage], **kwargs) -> Dict[str, Any]:
245265
if self.system_prompt is not None:
246266
messages = [ChatMessage.from_system(self.system_prompt)] + messages
247267

268+
generator_inputs = {"tools": self.tools}
269+
270+
selected_callback = streaming_callback or self.streaming_callback
271+
if selected_callback is not None:
272+
generator_inputs["streaming_callback"] = selected_callback
273+
248274
result = self.pipeline.run(
249275
data={
250276
"joiner": {"value": messages},
251277
"context_joiner": {"value": state},
252-
"generator": {"tools": self.tools},
278+
"generator": generator_inputs,
253279
},
254280
include_outputs_from={"context_joiner"},
255281
)

haystack_experimental/components/writers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,4 @@
44

55
from haystack_experimental.components.writers.chat_message_writer import ChatMessageWriter
66

7-
87
_all_ = ["ChatMessageWriter"]

test/components/agents/test_agent.py

Lines changed: 153 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,30 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from datetime import datetime
6+
from typing import Iterator
7+
8+
from unittest.mock import MagicMock, patch
59
import pytest
610

11+
from openai import Stream
12+
from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
13+
714
from haystack.components.builders.prompt_builder import PromptBuilder
815
from haystack.components.generators.chat.openai import OpenAIChatGenerator
9-
from haystack.utils import serialize_callable
16+
from haystack.dataclasses import ChatMessage
17+
from haystack.dataclasses.streaming_chunk import StreamingChunk
18+
from haystack.utils import serialize_callable, Secret
1019

1120
from haystack_experimental.components.agents import Agent
1221
from haystack_experimental.tools import Tool, ComponentTool
1322

1423
import os
1524

25+
26+
def streaming_callback_for_serde(chunk: StreamingChunk):
27+
pass
28+
1629
def weather_function(location):
1730
weather_info = {
1831
"Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
@@ -24,7 +37,6 @@ def weather_function(location):
2437

2538
weather_parameters = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
2639

27-
2840
@pytest.fixture
2941
def weather_tool():
3042
return Tool(
@@ -42,11 +54,47 @@ def component_tool():
4254
component=PromptBuilder(template="{{parrot}}")
4355
)
4456

57+
class OpenAIMockStream(Stream[ChatCompletionChunk]):
58+
def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs):
59+
client = client or MagicMock()
60+
super().__init__(client=client, *args, **kwargs)
61+
self.mock_chunk = mock_chunk
62+
63+
def __stream__(self) -> Iterator[ChatCompletionChunk]:
64+
yield self.mock_chunk
65+
66+
@pytest.fixture
67+
def openai_mock_chat_completion_chunk():
68+
"""
69+
Mock the OpenAI API completion chunk response and reuse it for tests
70+
"""
71+
72+
with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
73+
completion = ChatCompletionChunk(
74+
id="foo",
75+
model="gpt-4",
76+
object="chat.completion.chunk",
77+
choices=[
78+
chat_completion_chunk.Choice(
79+
finish_reason="stop",
80+
logprobs=None,
81+
index=0,
82+
delta=chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant"),
83+
)
84+
],
85+
created=int(datetime.now().timestamp()),
86+
usage=None,
87+
)
88+
mock_chat_completion_create.return_value = OpenAIMockStream(
89+
completion, cast_to=None, response=None, client=None
90+
)
91+
yield mock_chat_completion_create
92+
4593

4694
class TestAgent:
4795
def test_serde(self, weather_tool, component_tool):
48-
os.environ["OPENAI_API_KEY"] = "fake-key"
49-
generator = OpenAIChatGenerator()
96+
os.environ["FAKE_OPENAI_KEY"] = "fake-key"
97+
generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
5098
agent = Agent(
5199
chat_generator=generator,
52100
tools=[weather_tool, component_tool],
@@ -58,6 +106,7 @@ def test_serde(self, weather_tool, component_tool):
58106

59107
assert serialized_agent["type"] == "haystack_experimental.components.agents.agent.Agent"
60108
assert init_parameters["chat_generator"]["type"] == "haystack.components.generators.chat.openai.OpenAIChatGenerator"
109+
assert init_parameters["streaming_callback"] == None
61110
assert init_parameters["tools"][0]["data"]["function"] == serialize_callable(weather_function)
62111
assert init_parameters["tools"][1]["data"]["component"]["type"] == "haystack.components.builders.prompt_builder.PromptBuilder"
63112

@@ -68,4 +117,104 @@ def test_serde(self, weather_tool, component_tool):
68117
assert deserialized_agent.tools[0].function is weather_function
69118
assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
70119

120+
def test_serde_with_streaming_callback(self, weather_tool, component_tool):
121+
os.environ["FAKE_OPENAI_KEY"] = "fake-key"
122+
generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
123+
agent = Agent(
124+
chat_generator=generator,
125+
tools=[weather_tool, component_tool],
126+
streaming_callback=streaming_callback_for_serde,
127+
)
128+
129+
serialized_agent = agent.to_dict()
130+
131+
init_parameters = serialized_agent["init_parameters"]
132+
assert init_parameters["streaming_callback"] == "test.components.agents.test_agent.streaming_callback_for_serde"
133+
134+
deserialized_agent = Agent.from_dict(serialized_agent)
135+
assert deserialized_agent.streaming_callback is streaming_callback_for_serde
136+
137+
def test_run_with_params_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
138+
chat_generator = OpenAIChatGenerator(
139+
api_key=Secret.from_token("test-api-key")
140+
)
141+
142+
streaming_callback_called = False
143+
144+
def streaming_callback(chunk: StreamingChunk) -> None:
145+
nonlocal streaming_callback_called
146+
streaming_callback_called = True
147+
148+
149+
agent = Agent(chat_generator=chat_generator, streaming_callback=streaming_callback, tools=[weather_tool])
150+
agent.warm_up()
151+
response = agent.run([ChatMessage.from_user("Hello")])
152+
153+
# check we called the streaming callback
154+
assert streaming_callback_called is True
155+
156+
# check that the component still returns the correct response
157+
assert isinstance(response, dict)
158+
assert "messages" in response
159+
assert isinstance(response["messages"], list)
160+
assert len(response["messages"]) == 2
161+
assert [isinstance(reply, ChatMessage) for reply in response["messages"]]
162+
assert "Hello" in response["messages"][1].text # see openai_mock_chat_completion_chunk
163+
164+
165+
def test_run_with_run_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
166+
chat_generator = OpenAIChatGenerator(
167+
api_key=Secret.from_token("test-api-key")
168+
)
169+
170+
streaming_callback_called = False
171+
172+
def streaming_callback(chunk: StreamingChunk) -> None:
173+
nonlocal streaming_callback_called
174+
streaming_callback_called = True
175+
176+
177+
agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
178+
agent.warm_up()
179+
response = agent.run([ChatMessage.from_user("Hello")], streaming_callback=streaming_callback)
180+
181+
# check we called the streaming callback
182+
assert streaming_callback_called is True
183+
184+
# check that the component still returns the correct response
185+
assert isinstance(response, dict)
186+
assert "messages" in response
187+
assert isinstance(response["messages"], list)
188+
assert len(response["messages"]) == 2
189+
assert [isinstance(reply, ChatMessage) for reply in response["messages"]]
190+
assert "Hello" in response["messages"][1].text # see openai_mock_chat_completion_chunk
191+
192+
193+
def test_keep_generator_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
194+
streaming_callback_called = False
195+
196+
def streaming_callback(chunk: StreamingChunk) -> None:
197+
nonlocal streaming_callback_called
198+
streaming_callback_called = True
199+
200+
chat_generator = OpenAIChatGenerator(
201+
api_key=Secret.from_token("test-api-key"),
202+
streaming_callback=streaming_callback,
203+
)
204+
205+
agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
206+
agent.warm_up()
207+
response = agent.run([ChatMessage.from_user("Hello")])
208+
209+
# check we called the streaming callback
210+
assert streaming_callback_called is True
211+
212+
# check that the component still returns the correct response
213+
assert isinstance(response, dict)
214+
assert "messages" in response
215+
assert isinstance(response["messages"], list)
216+
assert len(response["messages"]) == 2
217+
assert [isinstance(reply, ChatMessage) for reply in response["messages"]]
218+
assert "Hello" in response["messages"][1].text # see openai_mock_chat_completion_chunk
219+
71220

0 commit comments

Comments (0)