Skip to content

Commit df96c56

Browse files
committed
Add custom prompt
1 parent 097aaaf commit df96c56

File tree

4 files changed

+205
-51
lines changed

4 files changed

+205
-51
lines changed

haystack_experimental/components/memory_agents/agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,5 +254,6 @@ def run( # noqa: PLR0915
254254
result["last_message"] = msgs[-1]
255255

256256
# Add the new conversation as memories to the memory store
257-
self.memory_store.add_memories(result["messages"])
257+
user_messages = [message for message in result["messages"] if message.role == "user"]
258+
self.memory_store.add_memories(user_messages)
258259
return result

haystack_experimental/memory/example.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,25 @@
22
from haystack.dataclasses import ChatMessage
33

44
from haystack_experimental.components.memory_agents.agent import Agent
5-
from haystack_experimental.memory.src.mem0.memory_store import Mem0MemoryStore
5+
from haystack_experimental.memory.src.mem0.memory_store import Mem0MemoryStore, Mem0Scope
66

7-
memory_store = Mem0MemoryStore(user_id="haystack_mem0")
7+
memory_store = Mem0MemoryStore(scope=Mem0Scope(user_id="haystack_mem0"))
88

9-
chat_generator = OpenAIChatGenerator()
10-
agent = Agent(chat_generator=chat_generator, memory_store=memory_store)
119

12-
answer = agent.run(messages=[ChatMessage.from_user(" suggest me some music and a drink with it to relax.")])
13-
print(answer)
10+
messages = [
11+
ChatMessage.from_user("I like to listen to Russian pop music"),
12+
ChatMessage.from_user("I liked cold spanish latte with oat milk"),
13+
ChatMessage.from_user("I live in Florence Italy and I love mountains"),
14+
ChatMessage.from_user("""I am a software engineer and I like building application in python.
15+
Most of my projects are related to NLP and LLM agents.
16+
I find it easier to use Haystack framework to build my projects."""),
17+
ChatMessage.from_user("""I work in a startup and I am the CEO of the company.
18+
I have a team of 10 people and we are building a
19+
platform for small businesses to manage their customers and sales."""),
20+
]
21+
22+
memory_store.add_memories(messages)
23+
24+
result = memory_store.search_memories()
25+
26+
print(result)

haystack_experimental/memory/src/mem0/memory_store.py

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,42 +8,65 @@
88
from haystack.dataclasses.chat_message import ChatMessage
99
from haystack.lazy_imports import LazyImport
1010

11+
from .utils import Mem0Scope
12+
1113
with LazyImport(message="Run 'pip install mem0ai'") as mem0_import:
12-
from mem0 import MemoryClient
14+
from mem0 import Memory, MemoryClient
1315

1416

1517
class Mem0MemoryStore:
1618
"""
1719
A memory store implementation using Mem0 as the backend.
1820
19-
:param api_key: The Mem0 API key (if not provided, uses MEM0_API_KEY environment variable)
20-
:param user_id: The user ID for the memory store.
21-
:param memory_config: Configuration dictionary for Mem0 client
2221
"""
2322

24-
def __init__(self, user_id: str, api_key: Optional[str] = None, memory_config: Optional[dict[str, Any]] = None):
23+
def __init__(
24+
self,
25+
scope: Mem0Scope,
26+
api_key: Optional[str] = None,
27+
memory_config: Optional[dict[str, Any]] = None,
28+
search_criteria: Optional[dict[str, Any]] = None,
29+
):
30+
"""
31+
Initialize the Mem0 memory store.
32+
33+
:param scope: The scope for the memory store. This defines the identifiers to retrieve or update memories.
34+
:param api_key: The Mem0 API key (if not provided, uses MEM0_API_KEY environment variable)
35+
:param memory_config: Configuration dictionary for Mem0 client
36+
:param search_criteria: Set the search configuration for the memory store.
37+
This can include query, filters, and top_k.
38+
"""
39+
2540
mem0_import.check()
2641
self.api_key = api_key or os.getenv("MEM0_API_KEY")
2742
if not self.api_key:
2843
raise ValueError("Mem0 API key must be provided either as parameter or MEM0_API_KEY environment variable.")
2944

30-
self.user_id = user_id
45+
self.scope = scope
3146
self.memory_config = memory_config
3247

3348
# If a memory config is provided, use it to initialize the Mem0 client
3449
if self.memory_config:
35-
self.client = MemoryClient.from_config(self.memory_config)
50+
self.client = Memory.from_config(self.memory_config)
3651
else:
37-
self.client = MemoryClient(api_key=self.api_key)
52+
self.client = MemoryClient(
53+
api_key=self.api_key,
54+
)
3855

3956
# Fall back to default search criteria (no query, no filters, top_k=10) when none are provided
40-
self.search_criteria = {}
57+
self.search_criteria = search_criteria
58+
if not self.search_criteria:
59+
self.search_criteria = {
60+
"query": None,
61+
"filters": None,
62+
"top_k": 10,
63+
}
4164

4265
def to_dict(self) -> dict[str, Any]:
4366
"""Serialize the store configuration to a dictionary."""
4467
return default_to_dict(
4568
self,
46-
user_id=self.user_id,
69+
scope=self.scope,
4770
api_key=self.api_key,
4871
memory_config=self.memory_config,
4972
search_criteria=self.search_criteria,
@@ -54,7 +77,7 @@ def from_dict(cls, data: dict[str, Any]) -> "Mem0MemoryStore":
5477
"""Deserialize the store from a dictionary."""
5578
return default_from_dict(cls, data)
5679

57-
def add_memories(self, messages: list[ChatMessage]) -> list[str]:
80+
def add_memories(self, messages: list[ChatMessage], infer: bool = True) -> list[str]:
5881
"""
5982
Add ChatMessage memories to Mem0.
6083
@@ -69,53 +92,52 @@ def add_memories(self, messages: list[ChatMessage]) -> list[str]:
6992
mem0_message = [{"role": "user", "content": message.text}]
7093

7194
try:
72-
# Mem0 primarily uses user_id as the main identifier
73-
# org_id and session_id are stored in metadata for filtering
74-
result = self.client.add(messages=mem0_message, user_id=self.user_id, metadata=message.meta)
95+
result = self.client.add(
96+
messages=mem0_message, metadata=message.meta, infer=infer, **self.scope.get_scope()
97+
)
7598
# Mem0 returns different response formats, handle both
7699
memory_id = result.get("id") or result.get("memory_id") or str(result)
77100
added_ids.append(memory_id)
78101
except Exception as e:
79102
raise RuntimeError(f"Failed to add memory message: {e}") from e
80103

81-
return added_ids
82-
83-
def set_search_criteria(
84-
self, query: Optional[str] = None, filters: Optional[dict[str, Any]] = None, top_k: Optional[int] = None
85-
):
86-
"""
87-
Set the memory configuration for the memory store.
88-
"""
89-
self.search_criteria = {"query": query, "filters": filters, "top_k": top_k}
104+
return list(added_ids)
90105

91106
def search_memories(
92-
self, query: Optional[str] = None, filters: Optional[dict[str, Any]] = None, top_k: int = 8
107+
self,
108+
query: Optional[str] = None,
109+
filters: Optional[dict[str, Any]] = None,
110+
top_k: int = 5,
111+
search_criteria: Optional[dict[str, Any]] = None,
93112
) -> list[ChatMessage]:
94113
"""
95114
Search for memories in Mem0.
96115
97116
:param query: Text query to search for. If not provided, all memories will be returned.
98117
:param filters: Additional filters to apply on search. For more details on mem0 filters, see https://mem0.ai/docs/search/
99118
:param top_k: Maximum number of results to return
119+
:param search_criteria: Search criteria to search memories from the store.
120+
This can include query, filters, and top_k.
100121
:returns: List of ChatMessage memories matching the criteria
101122
"""
102123
# Prepare filters for Mem0
124+
search_criteria = search_criteria or self.search_criteria
103125

104-
search_query = query or self.search_criteria.get("query", None)
105-
search_filters = filters or self.search_criteria.get("filters", {})
106-
search_top_k = top_k or self.search_criteria.get("top_k", 10)
126+
search_query = query or search_criteria.get("query", None)
127+
search_filters = filters or search_criteria.get("filters", {})
128+
search_top_k = top_k or search_criteria.get("top_k", 10)
107129

108130
if search_filters:
109-
mem0_filters = {"AND": [{"user_id": self.user_id}, search_filters]}
131+
mem0_filters = search_filters
110132
else:
111-
mem0_filters = {"user_id": self.user_id}
133+
mem0_filters = self.scope.get_scope()
112134

113135
try:
114136
if not search_query:
115-
memories = self.client.get_all(filters=mem0_filters, top_k=search_top_k)
137+
memories = self.client.get_all(filters=mem0_filters, **self.scope.get_scope())
116138
else:
117139
memories = self.client.search(
118-
query=search_query, limit=search_top_k, filters=mem0_filters, user_id=self.user_id
140+
query=search_query, limit=search_top_k, filters=mem0_filters, **self.scope.get_scope()
119141
)
120142
messages = [
121143
ChatMessage.from_user(text=memory["memory"], meta=memory["metadata"]) for memory in memories["results"]
@@ -135,7 +157,7 @@ def delete_all_memories(self, user_id: Optional[str] = None):
135157
:param user_id: User identifier; NOTE(review): the call below scopes the deletion via the store's scope and does not use this value — confirm intent
136158
"""
137159
try:
138-
self.client.delete_all(user_id=user_id or self.user_id)
160+
self.client.delete_all(**self.scope.get_scope())
139161
except Exception as e:
140162
raise RuntimeError(f"Failed to delete memories for user {user_id}: {e}") from e
141163

@@ -149,15 +171,3 @@ def delete_memory(self, memory_id: str):
149171
self.client.delete(memory_id=memory_id)
150172
except Exception as e:
151173
raise RuntimeError(f"Failed to delete memory {memory_id}: {e}") from e
152-
153-
def get_memory(self, memory_id: str) -> ChatMessage:
154-
"""
155-
Get memory from Mem0.
156-
157-
:param memory_id: The ID of the memory to get.
158-
:returns: The memory.
159-
"""
160-
try:
161-
return self.client.get(memory_id=memory_id)
162-
except Exception as e:
163-
raise RuntimeError(f"Failed to get memory {memory_id}: {e}") from e
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from typing import Any, Optional
2+
3+
from haystack import default_from_dict
4+
5+
6+
class Mem0Scope:
    """
    Identifiers used to scope reads and writes against the memory store.

    At least one of ``user_id``, ``run_id`` or ``session_id`` must be set.
    """

    def __init__(self, user_id: Optional[str] = None, run_id: Optional[str] = None, session_id: Optional[str] = None):
        """
        Create a scope from the given identifiers.

        :param user_id: Optional user identifier.
        :param run_id: Optional run identifier.
        :param session_id: Optional session identifier.
        :raises ValueError: If none of the identifiers is set.
        """
        self.user_id = user_id
        self.run_id = run_id
        self.session_id = session_id

        self._check_one_id_is_set()

    def _check_one_id_is_set(self) -> None:
        """Raise ``ValueError`` unless at least one identifier is truthy."""
        # Empty strings are treated the same as None here (truthiness check).
        if not self.user_id and not self.run_id and not self.session_id:
            raise ValueError("At least one of user_id, run_id, or session_id must be set")

    def to_dict(self) -> dict[str, Any]:
        """
        Return all three identifiers as a flat dictionary.

        :returns: Dictionary with the keys ``user_id``, ``run_id`` and ``session_id``;
            identifiers that are not set map to ``None``.
        """
        return {
            "user_id": self.user_id,
            "run_id": self.run_id,
            "session_id": self.session_id,
        }

    def _to_dict(self) -> dict[str, Any]:
        """Underscore-prefixed alias of :meth:`to_dict`, kept for existing callers."""
        return self.to_dict()

    @classmethod
    def _from_dict(cls, data: dict[str, Any]) -> "Mem0Scope":
        """
        Build a scope from a dictionary via ``default_from_dict``.

        NOTE(review): the expected layout of ``data`` is defined by ``default_from_dict``
        and may differ from the flat dictionary produced by ``to_dict`` — confirm.
        """
        return default_from_dict(cls, data)

    def get_scope(self) -> dict[str, Any]:
        """
        Return only the identifiers that are set.

        The result can be unpacked as keyword arguments, which is how the memory
        store passes the scope to the client.

        :returns: Dictionary of the identifiers whose value is not ``None``.
        """
        return {key: value for key, value in self.to_dict().items() if value is not None}

    def _get_scope(self) -> dict[str, Any]:
        """Underscore-prefixed alias of :meth:`get_scope`, kept for existing callers."""
        return self.get_scope()
36+
37+
38+
SEMANTIC_MEMORY_EXTRACTION_PROMPT = """
39+
You are extracting semantic memories from a conversation chain. Semantic memories are general
40+
preferences or personality traits or facts about the user that occurred during the conversation.
41+
42+
43+
WHAT TO EXTRACT:
44+
- General preferences or personality traits
45+
- Facts about the user
46+
47+
48+
WHAT NOT TO EXTRACT:
49+
- Vague statements without concrete facts
50+
- Specific events or actions that occurred during the conversation
51+
- Contextual details from the conversation flow
52+
- Problem-solution pairs that emerged during the conversation
53+
- Questions without answers
54+
- Speculative or hypothetical statements
55+
56+
57+
OUTPUT FORMAT:
58+
You MUST return a JSON object with ONLY a "facts" key containing an array of strings.
59+
Each fact should be a string that contains a fact about the user.
60+
61+
{
62+
"facts": [
63+
"User likes to listen to Russian pop music"
64+
]
65+
}
66+
67+
IMPORTANT:
68+
- Return JSON with ONLY the "facts" key
69+
- Each fact must be a STRING, not an object
70+
- If no semantic memories can be extracted, return {"facts": []}
71+
72+
EXAMPLES:
73+
74+
Example 1:
75+
Conversation:
76+
USER: I like to listen to Russian pop music
77+
ASSISTANT: That's interesting! What other types of music do you like?
78+
USER: I also like to listen to jazz music
79+
ASSISTANT: Jazz is great! Do you have a favorite jazz artist?
80+
USER: Yes, I like to listen to John Coltrane
81+
82+
Extracted semantic memory:
83+
{
84+
"facts": [
85+
"User likes to listen to Russian pop music",
86+
"User likes to listen to jazz music",
87+
"User likes to listen to John Coltrane"
88+
]
89+
}
90+
91+
Example 2:
92+
Conversation:
93+
USER: I live in Florence Italy and I love mountains.
94+
ASSISTANT: That sounds exciting! Which countries are you planning to visit?
95+
USER: I'm thinking of France, Italy, and Spain. I've already booked flights to Paris for next month.
96+
ASSISTANT: Great itinerary! Paris is beautiful. Have you been to Europe before?
97+
USER: Yes, I visited London last year for a conference. It was my first time in Europe.
98+
99+
Extracted semantic memory:
100+
{
101+
"facts": [
102+
"User lives in Florence Italy and loves mountains"
103+
]
104+
}
105+
106+
Example 3 (Empty Output):
107+
Conversation:
108+
USER: Hello, how are you?
109+
ASSISTANT: I'm doing well, thank you! How can I help you today?
110+
USER: Just checking in, nothing specific.
111+
112+
Extracted semantic memory:
113+
{
114+
"facts": []
115+
}
116+
117+
118+
"""
119+
120+
semantic_memory_config = {
121+
"llm": {
122+
"provider": "openai",
123+
"config": {
124+
"model": "gpt-4.1-nano-2025-04-14",
125+
"temperature": 0.2,
126+
"max_tokens": 2000,
127+
},
128+
},
129+
"custom_fact_extraction_prompt": SEMANTIC_MEMORY_EXTRACTION_PROMPT,
130+
}

0 commit comments

Comments
 (0)