 #
 # SPDX-License-Identifier: Apache-2.0
 
+from datetime import datetime
+from typing import Iterator
+
+from unittest.mock import MagicMock, patch
 import pytest
 
+from openai import Stream
+from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
+
 from haystack.components.builders.prompt_builder import PromptBuilder
 from haystack.components.generators.chat.openai import OpenAIChatGenerator
-from haystack.utils import serialize_callable
+from haystack.dataclasses import ChatMessage
+from haystack.dataclasses.streaming_chunk import StreamingChunk
+from haystack.utils import serialize_callable, Secret
 
 from haystack_experimental.components.agents import Agent
 from haystack_experimental.tools import Tool, ComponentTool
 
 import os
 
+
+def streaming_callback_for_serde(chunk: StreamingChunk):
+    pass
+
 def weather_function(location):
     weather_info = {
         "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
@@ -24,7 +37,6 @@ def weather_function(location):
 
 weather_parameters = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
 
-
 @pytest.fixture
 def weather_tool():
     return Tool(
@@ -42,11 +54,47 @@ def component_tool():
         component=PromptBuilder(template="{{parrot}}")
     )
 
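+# Minimal openai Stream subclass: iterating it yields the single mocked ChatCompletionChunk instead of reading from the API.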
+class OpenAIMockStream(Stream[ChatCompletionChunk]):
+    def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs):
+        client = client or MagicMock()
+        super().__init__(client=client, *args, **kwargs)
+        self.mock_chunk = mock_chunk
+
+    def __stream__(self) -> Iterator[ChatCompletionChunk]:
+        yield self.mock_chunk
+
+@pytest.fixture
+def openai_mock_chat_completion_chunk():
+    """
+    Mock the OpenAI API completion chunk response and reuse it for tests
+    """
+
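+    # Patch the OpenAI chat completions endpoint so the tests never make a real API call.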
+    with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
+        completion = ChatCompletionChunk(
+            id="foo",
+            model="gpt-4",
+            object="chat.completion.chunk",
+            choices=[
+                chat_completion_chunk.Choice(
+                    finish_reason="stop",
+                    logprobs=None,
+                    index=0,
+                    delta=chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant"),
+                )
+            ],
+            created=int(datetime.now().timestamp()),
+            usage=None,
+        )
+        mock_chat_completion_create.return_value = OpenAIMockStream(
+            completion, cast_to=None, response=None, client=None
+        )
+        yield mock_chat_completion_create
+
 
 class TestAgent:
     def test_serde(self, weather_tool, component_tool):
-        os.environ["OPENAI_API_KEY"] = "fake-key"
-        generator = OpenAIChatGenerator()
+        os.environ["FAKE_OPENAI_KEY"] = "fake-key"
+        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
         agent = Agent(
             chat_generator=generator,
             tools=[weather_tool, component_tool],
@@ -58,6 +106,7 @@ def test_serde(self, weather_tool, component_tool):
 
         assert serialized_agent["type"] == "haystack_experimental.components.agents.agent.Agent"
         assert init_parameters["chat_generator"]["type"] == "haystack.components.generators.chat.openai.OpenAIChatGenerator"
+        assert init_parameters["streaming_callback"] is None
         assert init_parameters["tools"][0]["data"]["function"] == serialize_callable(weather_function)
         assert init_parameters["tools"][1]["data"]["component"]["type"] == "haystack.components.builders.prompt_builder.PromptBuilder"
 
@@ -68,4 +117,104 @@ def test_serde(self, weather_tool, component_tool):
         assert deserialized_agent.tools[0].function is weather_function
         assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
 
+    def test_serde_with_streaming_callback(self, weather_tool, component_tool):
+        os.environ["FAKE_OPENAI_KEY"] = "fake-key"
+        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
+        agent = Agent(
+            chat_generator=generator,
+            tools=[weather_tool, component_tool],
+            streaming_callback=streaming_callback_for_serde,
+        )
+
+        serialized_agent = agent.to_dict()
+
+        init_parameters = serialized_agent["init_parameters"]
+        assert init_parameters["streaming_callback"] == "test.components.agents.test_agent.streaming_callback_for_serde"
+
+        deserialized_agent = Agent.from_dict(serialized_agent)
+        assert deserialized_agent.streaming_callback is streaming_callback_for_serde
+
+    def test_run_with_params_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
+        chat_generator = OpenAIChatGenerator(
+            api_key=Secret.from_token("test-api-key")
+        )
+
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+
+        agent = Agent(chat_generator=chat_generator, streaming_callback=streaming_callback, tools=[weather_tool])
+        agent.warm_up()
+        response = agent.run([ChatMessage.from_user("Hello")])
+
+        # check we called the streaming callback
+        assert streaming_callback_called is True
+
+        # check that the component still returns the correct response
+        assert isinstance(response, dict)
+        assert "messages" in response
+        assert isinstance(response["messages"], list)
+        assert len(response["messages"]) == 2
+        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
+        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk
+
+
+    def test_run_with_run_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
+        chat_generator = OpenAIChatGenerator(
+            api_key=Secret.from_token("test-api-key")
+        )
+
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+
+        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
+        agent.warm_up()
+        response = agent.run([ChatMessage.from_user("Hello")], streaming_callback=streaming_callback)
+
+        # check we called the streaming callback
+        assert streaming_callback_called is True
+
+        # check that the component still returns the correct response
+        assert isinstance(response, dict)
+        assert "messages" in response
+        assert isinstance(response["messages"], list)
+        assert len(response["messages"]) == 2
+        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
+        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk
+
+
+    def test_keep_generator_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
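+        # Unlike the previous tests, the streaming callback is configured on the generator itself rather than on the Agent.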
+        chat_generator = OpenAIChatGenerator(
+            api_key=Secret.from_token("test-api-key"),
+            streaming_callback=streaming_callback,
+        )
+
+        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
+        agent.warm_up()
+        response = agent.run([ChatMessage.from_user("Hello")])
+
+        # check we called the streaming callback
+        assert streaming_callback_called is True
+
+        # check that the component still returns the correct response
+        assert isinstance(response, dict)
+        assert "messages" in response
+        assert isinstance(response["messages"], list)
+        assert len(response["messages"]) == 2
+        assert all(isinstance(reply, ChatMessage) for reply in response["messages"])
+        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk
+
 