Skip to content

Commit 097aaaf

Browse files
committed
Add mem0 integration
1 parent b11f1de commit 097aaaf

File tree

5 files changed

+450
-0
lines changed

5 files changed

+450
-0
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import sys
6+
from typing import TYPE_CHECKING
7+
8+
from lazy_imports import LazyImporter
9+
10+
_import_structure = {"agent": ["Agent"]}
11+
12+
if TYPE_CHECKING:
13+
from .agent import Agent as Agent
14+
15+
else:
16+
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import Any, Optional, Union
6+
7+
from haystack import logging
8+
from haystack.components.agents.agent import Agent as HaystackAgent
9+
from haystack.components.agents.agent import _schema_from_dict
10+
from haystack.components.agents.state import replace_values
11+
from haystack.components.generators.chat.types import ChatGenerator
12+
from haystack.core.errors import PipelineRuntimeError
13+
from haystack.core.pipeline import AsyncPipeline, Pipeline
14+
from haystack.core.pipeline.breakpoint import (
15+
_create_pipeline_snapshot_from_chat_generator,
16+
_create_pipeline_snapshot_from_tool_invoker,
17+
)
18+
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
19+
from haystack.core.serialization import default_from_dict, import_class_by_name
20+
from haystack.dataclasses import ChatMessage
21+
from haystack.dataclasses.breakpoints import AgentBreakpoint, AgentSnapshot, ToolBreakpoint
22+
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
23+
from haystack.tools import Tool, Toolset, ToolsType, deserialize_tools_or_toolset_inplace
24+
from haystack.utils.callable_serialization import deserialize_callable
25+
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
26+
27+
from haystack_experimental.memory.src.mem0.memory_store import Mem0MemoryStore
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
class Agent(HaystackAgent):
    """
    A Haystack Agent extended with long-term memory backed by a Mem0 memory store.

    Before each run, memories relevant to the latest user message are retrieved from
    the memory store and appended to the input messages. After the run, the resulting
    conversation is written back to the store.

    :param memory_store: The memory store to use for the agent.
    :param user_id: The user ID for the agent.
    """

    def __init__(
        self,
        *,
        chat_generator: ChatGenerator,
        tools: Optional[ToolsType] = None,
        memory_store: Optional[Mem0MemoryStore] = None,
        system_prompt: Optional[str] = None,
        exit_conditions: Optional[list[str]] = None,
        state_schema: Optional[dict[str, Any]] = None,
        max_agent_steps: int = 100,
        streaming_callback: Optional[StreamingCallbackT] = None,
        raise_on_tool_invocation_failure: bool = False,
        tool_invoker_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Initialize the agent component.

        :param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
        :param tools: List of Tool objects or a Toolset that the agent can use.
        :param memory_store: The memory store to use for the agent. When None, the agent behaves exactly
            like the base Haystack Agent (no memory retrieval or persistence).
        :param system_prompt: System prompt for the agent.
        :param exit_conditions: List of conditions that will cause the agent to return.
            Can include "text" if the agent should return when it generates a message without tool calls,
            or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"].
        :param state_schema: The schema for the runtime state used by the tools.
        :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
            If the agent exceeds this number of steps, it will stop and return the current state.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            The same callback can be configured to emit tool results when a tool is called.
        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
            If set to False, the exception will be turned into a chat message and passed to the LLM.
        :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
        :raises TypeError: If the chat_generator does not support tools parameter in its run method.
        :raises ValueError: If the exit_conditions are not valid.
        """
        # Idiomatic zero-argument super() (Python 3); behavior identical to super(Agent, self).
        super().__init__(
            chat_generator=chat_generator,
            tools=tools,
            system_prompt=system_prompt,
            exit_conditions=exit_conditions,
            state_schema=state_schema,
            max_agent_steps=max_agent_steps,
            streaming_callback=streaming_callback,
            raise_on_tool_invocation_failure=raise_on_tool_invocation_failure,
            tool_invoker_kwargs=tool_invoker_kwargs,
        )
        self.memory_store = memory_store

    def run(  # noqa: PLR0915
        self,
        messages: list[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        break_point: Optional[AgentBreakpoint] = None,
        snapshot: Optional[AgentSnapshot] = None,
        system_prompt: Optional[str] = None,
        tools: Optional[Union[ToolsType, list[str]]] = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Process messages and execute tools until an exit condition is met.

        :param messages: List of Haystack ChatMessage objects to process.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            The same callback can be configured to emit tool results when a tool is called.
        :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
            for "tool_invoker".
        :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains
            the relevant information to restart the Agent execution from where it left off.
        :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
        :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
            When passing tool names, tools are selected from the Agent's originally configured tools.
        :param kwargs: Additional data to pass to the State schema used by the Agent.
            The keys must match the schema defined in the Agent's `state_schema`.
        :returns:
            A dictionary with the following keys:
            - "messages": List of all messages exchanged during the agent's run.
            - "last_message": The last message exchanged during the agent's run.
            - Any additional keys defined in the `state_schema`.
        :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
        :raises BreakpointException: If an agent breakpoint is triggered.
        """
        # Retrieve memories relevant to the latest message, if a store is configured.
        # Guarded against an empty message list and a text-less last message (e.g. a
        # tool-call-only message, whose .text is None).
        agent_memory: list[ChatMessage] = []
        if self.memory_store and messages:
            query_text = messages[-1].text
            if query_text:
                agent_memory = self.memory_store.search_memories(query=query_text)

        combined_messages = messages + agent_memory

        # We pop parent_snapshot from kwargs to avoid passing it into State.
        parent_snapshot = kwargs.pop("parent_snapshot", None)
        agent_inputs = {
            "messages": combined_messages,
            "streaming_callback": streaming_callback,
            "break_point": break_point,
            "snapshot": snapshot,
            **kwargs,
        }
        self._runtime_checks(break_point=break_point, snapshot=snapshot)

        if snapshot:
            exe_context = self._initialize_from_snapshot(
                snapshot=snapshot,
                streaming_callback=streaming_callback,
                requires_async=False,
                tools=tools,
            )
        else:
            exe_context = self._initialize_fresh_execution(
                messages=combined_messages,
                streaming_callback=streaming_callback,
                requires_async=False,
                system_prompt=system_prompt,
                tools=tools,
                **kwargs,
            )

        with self._create_agent_span() as span:
            span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs))

            while exe_context.counter < self.max_agent_steps:
                # Handle breakpoint and ChatGenerator call
                Agent._check_chat_generator_breakpoint(
                    execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
                )
                # We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
                if exe_context.skip_chat_generator:
                    llm_messages = exe_context.state.get("messages", [])[-1:]
                    # Set to False so the next iteration will call the chat generator
                    exe_context.skip_chat_generator = False
                else:
                    try:
                        result = Pipeline._run_component(
                            component_name="chat_generator",
                            component={"instance": self.chat_generator},
                            inputs={
                                "messages": exe_context.state.data["messages"],
                                **exe_context.chat_generator_inputs,
                            },
                            component_visits=exe_context.component_visits,
                            parent_span=span,
                        )
                    except PipelineRuntimeError as e:
                        # Attach a snapshot so callers can resume the run from this point.
                        pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator(
                            agent_name=getattr(self, "__component_name__", None),
                            execution_context=exe_context,
                            parent_snapshot=parent_snapshot,
                        )
                        e.pipeline_snapshot = pipeline_snapshot
                        raise e

                    llm_messages = result["replies"]
                    exe_context.state.set("messages", llm_messages)

                # Check if any of the LLM responses contain a tool call or if the LLM is not using tools
                if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
                    exe_context.counter += 1
                    break

                # Handle breakpoint and ToolInvoker call
                Agent._check_tool_invoker_breakpoint(
                    execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
                )
                try:
                    # We only send the messages from the LLM to the tool invoker
                    tool_invoker_result = Pipeline._run_component(
                        component_name="tool_invoker",
                        component={"instance": self._tool_invoker},
                        inputs={
                            "messages": llm_messages,
                            "state": exe_context.state,
                            **exe_context.tool_invoker_inputs,
                        },
                        component_visits=exe_context.component_visits,
                        parent_span=span,
                    )
                except PipelineRuntimeError as e:
                    # Access the original Tool Invoker exception
                    original_error = e.__cause__
                    tool_name = getattr(original_error, "tool_name", None)

                    pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker(
                        tool_name=tool_name,
                        agent_name=getattr(self, "__component_name__", None),
                        execution_context=exe_context,
                        parent_snapshot=parent_snapshot,
                    )
                    e.pipeline_snapshot = pipeline_snapshot
                    raise e

                tool_messages = tool_invoker_result["tool_messages"]
                exe_context.state = tool_invoker_result["state"]
                exe_context.state.set("messages", tool_messages)

                # Check if any LLM message's tool call name matches an exit condition
                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
                    exe_context.counter += 1
                    break

                # Increment the step counter
                exe_context.counter += 1

            if exe_context.counter >= self.max_agent_steps:
                logger.warning(
                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
                    max_agent_steps=self.max_agent_steps,
                )
            span.set_content_tag("haystack.agent.output", exe_context.state.data)
            span.set_tag("haystack.agent.steps_taken", exe_context.counter)

        result = {**exe_context.state.data}
        if msgs := result.get("messages"):
            result["last_message"] = msgs[-1]

        # Persist the conversation as memories. Guarded: memory_store is Optional
        # (the unguarded call previously raised AttributeError on every run without
        # a store) and "messages" may be absent from the state.
        # NOTE(review): this also re-stores the memories retrieved at the start of
        # the run — confirm Mem0 deduplicates, or persist only the new messages.
        if self.memory_store and result.get("messages"):
            self.memory_store.add_memories(result["messages"])
        return result
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

from haystack_experimental.components.memory_agents.agent import Agent
from haystack_experimental.memory.src.mem0.memory_store import Mem0MemoryStore

# Demo: a memory-enabled Agent backed by Mem0 and an OpenAI chat model.
# Memories are stored under this fixed user id in Mem0.
memory_store = Mem0MemoryStore(user_id="haystack_mem0")

# Requires OPENAI_API_KEY in the environment (OpenAIChatGenerator default).
chat_generator = OpenAIChatGenerator()

# The experimental Agent retrieves relevant memories before each run and
# writes the conversation back to the store afterwards.
agent = Agent(chat_generator=chat_generator, memory_store=memory_store)

user_message = ChatMessage.from_user(" suggest me some music and a drink with it to relax.")
answer = agent.run(messages=[user_message])
print(answer)

haystack_experimental/memory/src/mem0/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)