Skip to content

Commit dfd75c1

Browse files
committed
[feature]: Add ReAct pattern to DevRel agent
1 parent 2cf8d29 commit dfd75c1

File tree

13 files changed

+387
-94
lines changed

13 files changed

+387
-94
lines changed

backend/app/agents/devrel/agent.py

Lines changed: 34 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,15 @@
55
from langchain_google_genai import ChatGoogleGenerativeAI
66
from langgraph.checkpoint.memory import InMemorySaver
77
from ..base_agent import BaseAgent, AgentState
8-
from ..classification_router import MessageCategory
98
from .tools.search_tool import TavilySearchTool
109
from .tools.faq_tool import FAQTool
10+
from .github.github_toolkit import GitHubToolkit
1111
from app.core.config import settings
1212
from .nodes.gather_context import gather_context_node
13-
from .nodes.handlers.faq import handle_faq_node
14-
from .nodes.handlers.web_search import handle_web_search_node
15-
from .nodes.handlers.technical_support import handle_technical_support_node
16-
from .nodes.handlers.onboarding import handle_onboarding_node
17-
from .generate_response_node import generate_response_node
1813
from .nodes.summarization import check_summarization_needed, summarize_conversation_node, store_summary_to_database
14+
from .nodes.react_supervisor import react_supervisor_node, supervisor_decision_router
15+
from .tool_wrappers import web_search_tool_node, faq_handler_tool_node, onboarding_tool_node, github_toolkit_tool_node
16+
from .nodes.generate_response import generate_response_node
1917

2018
logger = logging.getLogger(__name__)
2119

@@ -31,48 +29,55 @@ def __init__(self, config: Dict[str, Any] = None):
3129
)
3230
self.search_tool = TavilySearchTool()
3331
self.faq_tool = FAQTool()
32+
self.github_toolkit = GitHubToolkit()
3433
self.checkpointer = InMemorySaver()
3534
super().__init__("DevRelAgent", self.config)
3635

3736
def _build_graph(self):
3837
"""Build the DevRel agent workflow graph"""
3938
workflow = StateGraph(AgentState)
4039

41-
# Add nodes
40+
# Phase 1: Gather Context
4241
workflow.add_node("gather_context", gather_context_node)
43-
workflow.add_node("handle_faq", partial(handle_faq_node, faq_tool=self.faq_tool))
44-
workflow.add_node("handle_web_search", partial(
45-
handle_web_search_node, search_tool=self.search_tool, llm=self.llm))
46-
workflow.add_node("handle_technical_support", handle_technical_support_node)
47-
workflow.add_node("handle_onboarding", handle_onboarding_node)
42+
43+
# Phase 2: ReAct Supervisor - Decide what to do next
44+
workflow.add_node("react_supervisor", partial(react_supervisor_node, llm=self.llm))
45+
workflow.add_node("web_search_tool", partial(web_search_tool_node, search_tool=self.search_tool, llm=self.llm))
46+
workflow.add_node("faq_handler_tool", partial(faq_handler_tool_node, faq_tool=self.faq_tool))
47+
workflow.add_node("onboarding_tool", onboarding_tool_node)
48+
workflow.add_node("github_toolkit_tool", partial(github_toolkit_tool_node, github_toolkit=self.github_toolkit))
49+
50+
# Phase 3: Generate Response
4851
workflow.add_node("generate_response", partial(generate_response_node, llm=self.llm))
52+
53+
# Phase 4: Summarization
4954
workflow.add_node("check_summarization", check_summarization_needed)
5055
workflow.add_node("summarize_conversation", partial(summarize_conversation_node, llm=self.llm))
5156

52-
# Add edges
57+
# Entry point
58+
workflow.set_entry_point("gather_context")
59+
workflow.add_edge("gather_context", "react_supervisor")
60+
61+
# ReAct supervisor routing
5362
workflow.add_conditional_edges(
54-
"gather_context",
55-
self._route_to_handler,
63+
"react_supervisor",
64+
supervisor_decision_router,
5665
{
57-
MessageCategory.FAQ: "handle_faq",
58-
MessageCategory.WEB_SEARCH: "handle_web_search",
59-
MessageCategory.ONBOARDING: "handle_onboarding",
60-
MessageCategory.TECHNICAL_SUPPORT: "handle_technical_support",
61-
MessageCategory.COMMUNITY_ENGAGEMENT: "handle_technical_support",
62-
MessageCategory.DOCUMENTATION: "handle_technical_support",
63-
MessageCategory.BUG_REPORT: "handle_technical_support",
64-
MessageCategory.FEATURE_REQUEST: "handle_technical_support",
65-
MessageCategory.NOT_DEVREL: "handle_technical_support"
66+
"web_search": "web_search_tool",
67+
"faq_handler": "faq_handler_tool",
68+
"onboarding": "onboarding_tool",
69+
"github_toolkit": "github_toolkit_tool",
70+
"complete": "generate_response"
6671
}
6772
)
6873

69-
# All handlers lead to response generation
70-
for node in ["handle_faq", "handle_web_search", "handle_technical_support", "handle_onboarding"]:
71-
workflow.add_edge(node, "generate_response")
74+
# All tools return to supervisor
75+
for tool in ["web_search_tool", "faq_handler_tool", "onboarding_tool", "github_toolkit_tool"]:
76+
workflow.add_edge(tool, "react_supervisor")
7277

7378
workflow.add_edge("generate_response", "check_summarization")
7479

75-
# Conditional edge for summarization
80+
# Summarization routing
7681
workflow.add_conditional_edges(
7782
"check_summarization",
7883
self._should_summarize,
@@ -82,42 +87,11 @@ def _build_graph(self):
8287
}
8388
)
8489

85-
# End after summarization
8690
workflow.add_edge("summarize_conversation", END)
8791

88-
# Set entry point
89-
workflow.set_entry_point("gather_context")
90-
91-
# Compile with InMemorySaver checkpointer
92+
# Compile with checkpointer
9293
self.graph = workflow.compile(checkpointer=self.checkpointer)
9394

94-
def _route_to_handler(self, state: AgentState) -> str:
95-
"""Route to the appropriate handler based on intent"""
96-
classification = state.context.get("classification", {})
97-
intent = classification.get("category")
98-
99-
if isinstance(intent, str):
100-
try:
101-
intent = MessageCategory(intent.lower())
102-
except ValueError:
103-
logger.warning(f"Unknown intent string '{intent}', defaulting to TECHNICAL_SUPPORT")
104-
intent = MessageCategory.TECHNICAL_SUPPORT
105-
106-
logger.info(f"Routing based on intent: {intent} for session {state.session_id}")
107-
108-
# Mapping from MessageCategory enum to string keys used in add_conditional_edges
109-
if intent in [MessageCategory.FAQ, MessageCategory.WEB_SEARCH,
110-
MessageCategory.ONBOARDING, MessageCategory.TECHNICAL_SUPPORT,
111-
MessageCategory.COMMUNITY_ENGAGEMENT, MessageCategory.DOCUMENTATION,
112-
MessageCategory.BUG_REPORT, MessageCategory.FEATURE_REQUEST,
113-
MessageCategory.NOT_DEVREL]:
114-
logger.info(f"Routing to handler for: {intent}")
115-
return intent
116-
117-
# Later to be changed to handle anomalies
118-
logger.info(f"Unknown intent '{intent}', routing to technical support")
119-
return MessageCategory.TECHNICAL_SUPPORT
120-
12195
def _should_summarize(self, state: AgentState) -> str:
12296
"""Determine if conversation should be summarized"""
12397
if state.summarization_needed:

backend/app/agents/devrel/generate_response_node.py renamed to backend/app/agents/devrel/generate_response.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Dict, Any
33
from app.agents.state import AgentState
44
from langchain_core.messages import HumanMessage
5-
from .prompts.base_prompt import GENERAL_LLM_RESPONSE_PROMPT
5+
from .prompts.response_prompt import RESPONSE_PROMPT
66
from .nodes.handlers.web_search import create_search_response
77

88
logger = logging.getLogger(__name__)
@@ -46,7 +46,7 @@ async def _create_llm_response(state: AgentState, task_result: Dict[str, Any], l
4646
current_context_str = "\n".join(context_parts)
4747

4848
try:
49-
prompt = GENERAL_LLM_RESPONSE_PROMPT.format(
49+
prompt = RESPONSE_PROMPT.format(
5050
conversation_summary=conversation_summary,
5151
latest_message=latest_message,
5252
conversation_history=conversation_history_str,

backend/app/agents/devrel/nodes/gather_context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import logging
22
from datetime import datetime
3+
from typing import Dict, Any
34
from app.agents.state import AgentState
45

56
logger = logging.getLogger(__name__)
67

7-
async def gather_context_node(state: AgentState) -> AgentState:
8+
async def gather_context_node(state: AgentState) -> Dict[str, Any]:
89
"""Gather additional context for the user and their request"""
910
logger.info(f"Gathering context for session {state.session_id}")
1011

@@ -31,5 +32,6 @@ async def gather_context_node(state: AgentState) -> AgentState:
3132
return {
3233
"messages": [new_message],
3334
"context": updated_context,
34-
"current_task": "context_gathered"
35+
"current_task": "context_gathered",
36+
"last_interaction_time": datetime.now()
3537
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import logging
2+
import json
3+
from typing import Dict, Any
4+
from app.agents.state import AgentState
5+
from langchain_core.messages import HumanMessage
6+
from ..prompts.response_prompt import RESPONSE_PROMPT
7+
8+
logger = logging.getLogger(__name__)
9+
10+
async def generate_response_node(state: AgentState, llm) -> Dict[str, Any]:
    """Final response generation node.

    Delegates the heavy lifting to ``_create_response`` and wraps the reply
    in a partial state update. On any failure it substitutes a generic
    apology and appends the error to ``errors`` rather than propagating.
    """
    logger.info(f"Generating response for session {state.session_id}")

    try:
        reply = await _create_response(state, llm)
    except Exception as exc:
        logger.error(f"Error generating response: {str(exc)}")
        return {
            "final_response": "I apologize, but I encountered an error while generating my response. Please try asking your question again.",
            "errors": state.errors + [str(exc)],
            "current_task": "response_error",
        }

    return {
        "final_response": reply,
        "current_task": "response_generated",
    }
31+
32+
async def _create_response(state: AgentState, llm) -> str:
    """Build the final-response prompt from conversation state and invoke the LLM.

    Gathers the latest message, running summary, recent history (capped at
    10 messages), session context, the supervisor's reasoning trace and
    accumulated tool results, formats them into RESPONSE_PROMPT, and
    returns the model's reply text.

    Raises:
        KeyError: if RESPONSE_PROMPT references a placeholder not supplied
            here. Re-raised so generate_response_node's handler emits its
            generic apology instead of leaking template internals to the
            user (the previous behavior returned the raw error string as
            the final response).
    """
    logger.info(f"Creating response for session {state.session_id}")

    latest_message = _get_latest_message(state)
    conversation_summary = state.conversation_summary or "This is the beginning of our conversation."

    # Show at most the 10 most recent messages, with a truncation banner
    # when older messages are elided.
    recent_messages_count = min(10, len(state.messages))
    if state.messages:
        conversation_history = "\n".join(
            f"{msg.get('role', 'user')}: {msg.get('content', '')}"
            for msg in state.messages[-recent_messages_count:]
        )
        if len(state.messages) > recent_messages_count:
            conversation_history = (
                f"[Showing last {recent_messages_count} of {len(state.messages)} messages]\n"
                + conversation_history
            )
    else:
        conversation_history = "No previous conversation"

    # NOTE(review): assumes session_start_time and last_interaction_time are
    # always populated datetimes by the time this node runs — confirm the
    # upstream nodes (gather_context) set them before response generation.
    context_parts = [
        f"Platform: {state.platform}",
        f"Total interactions: {state.interaction_count}",
        f"Session duration: {(state.last_interaction_time - state.session_start_time).total_seconds() / 60:.1f} minutes",
    ]
    if state.key_topics:
        context_parts.append(f"Key topics discussed: {', '.join(state.key_topics)}")
    if state.user_profile:
        context_parts.append(f"User profile: {state.user_profile}")
    current_context = "\n".join(context_parts)

    supervisor_thinking = state.context.get("supervisor_thinking", "No reasoning process available")

    tool_results = state.context.get("tool_results", [])
    tool_results_str = json.dumps(tool_results, indent=2) if tool_results else "No tool results"

    task_result = state.task_result or {}
    task_result_str = json.dumps(task_result, indent=2) if task_result else "No task result"

    try:
        prompt = RESPONSE_PROMPT.format(
            latest_message=latest_message,
            conversation_summary=conversation_summary,
            conversation_history=conversation_history,
            current_context=current_context,
            supervisor_thinking=supervisor_thinking,
            tool_results=tool_results_str,
            task_result=task_result_str,
        )
    except KeyError as e:
        logger.error(f"Missing key in RESPONSE_PROMPT: {e}")
        # Surface the template bug to the caller's error handler rather than
        # returning an internal error string as the user-visible reply.
        raise

    # Plain string: the original used an f-string with no placeholders.
    logger.info("Generated response prompt using existing RESPONSE_PROMPT")

    response = await llm.ainvoke([HumanMessage(content=prompt)])
    return response.content.strip()
96+
97+
def _get_latest_message(state: AgentState) -> str:
    """Return the newest message's content, or the original request as fallback."""
    if not state.messages:
        return state.context.get("original_message", "")
    return state.messages[-1].get("content", "")

backend/app/agents/devrel/nodes/handlers/user_support.py

Whitespace-only changes.
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import logging
2+
import json
3+
from typing import Dict, Any, Literal
4+
from app.agents.state import AgentState
5+
from langchain_core.messages import HumanMessage
6+
from ..prompts.react_prompt import REACT_SUPERVISOR_PROMPT
7+
8+
logger = logging.getLogger(__name__)
9+
10+
async def react_supervisor_node(state: AgentState, llm) -> Dict[str, Any]:
    """ReAct Supervisor: Think -> Act -> Observe"""
    logger.info(f"ReAct Supervisor thinking for session {state.session_id}")

    # Pull everything the supervisor prompt needs out of the current state.
    tool_results = state.context.get("tool_results", [])
    iteration_count = state.context.get("iteration_count", 0)
    tool_results_text = (
        json.dumps(tool_results, indent=2) if tool_results else "No previous tool results"
    )

    prompt = REACT_SUPERVISOR_PROMPT.format(
        latest_message=_get_latest_message(state),
        platform=state.platform,
        interaction_count=state.interaction_count,
        iteration_count=iteration_count,
        conversation_history=_get_conversation_history(state),
        tool_results=tool_results_text,
    )

    reply = await llm.ainvoke([HumanMessage(content=prompt)])
    decision = _parse_supervisor_decision(reply.content)
    logger.info(f"ReAct Supervisor decision: {decision['action']}")

    # Record the raw reasoning, the parsed decision, and bump the iteration
    # counter so the router can enforce its loop limit.
    updated_context = dict(state.context)
    updated_context.update(
        supervisor_thinking=reply.content,
        supervisor_decision=decision,
        iteration_count=iteration_count + 1,
    )
    return {
        "context": updated_context,
        "current_task": f"supervisor_decided_{decision['action']}",
    }
44+
45+
def _parse_supervisor_decision(response: str) -> Dict[str, Any]:
46+
"""Parse the supervisor's decision from LLM response"""
47+
try:
48+
lines = response.strip().split('\n')
49+
decision = {"action": "complete", "reasoning": "", "thinking": ""}
50+
51+
for line in lines:
52+
if line.startswith("THINK:"):
53+
decision["thinking"] = line.replace("THINK:", "").strip()
54+
elif line.startswith("ACT:"):
55+
action = line.replace("ACT:", "").strip().lower()
56+
if action in ["web_search", "faq_handler", "onboarding", "github_toolkit", "complete"]:
57+
decision["action"] = action
58+
elif line.startswith("REASON:"):
59+
decision["reasoning"] = line.replace("REASON:", "").strip()
60+
61+
return decision
62+
except Exception as e:
63+
logger.error(f"Error parsing supervisor decision: {e}")
64+
return {"action": "complete", "reasoning": "Error in decision parsing", "thinking": ""}
65+
66+
def supervisor_decision_router(state: AgentState) -> Literal["web_search", "faq_handler", "onboarding", "github_toolkit", "complete"]:
    """Route based on supervisor's decision"""
    # Hard stop once the supervisor has looped more than 10 times, so a
    # confused model cannot spin the graph forever.
    if state.context.get("iteration_count", 0) > 10:
        logger.warning(f"Max iterations reached for session {state.session_id}")
        return "complete"

    decision = state.context.get("supervisor_decision", {})
    return decision.get("action", "complete")
78+
79+
def add_tool_result(state: AgentState, tool_name: str, result: Dict[str, Any]) -> Dict[str, Any]:
    """Record a tool invocation's result as a partial state update.

    Returns a dict suitable for merging into AgentState: the tool's result
    (tagged with the supervisor iteration that ran it) appended to
    ``context["tool_results"]``, the tool name appended to ``tools_used``,
    and ``current_task`` advanced to ``completed_<tool>``.
    """
    # Copy before appending: the original appended to the list object held
    # in state.context in place, mutating the incoming state as a side
    # effect of merely building the update dict.
    tool_results = list(state.context.get("tool_results", []))
    tool_results.append({
        "tool": tool_name,
        "result": result,
        "iteration": state.context.get("iteration_count", 0),
    })

    return {
        "context": {
            **state.context,
            "tool_results": tool_results,
        },
        "tools_used": state.tools_used + [tool_name],
        "current_task": f"completed_{tool_name}",
    }
96+
97+
def _get_latest_message(state: AgentState) -> str:
    """Most recent message content; falls back to the original request text."""
    messages = state.messages
    if messages:
        return messages[-1].get("content", "")
    return state.context.get("original_message", "")
102+
103+
def _get_conversation_history(state: AgentState, max_messages: int = 5) -> str:
    """Format up to the last ``max_messages`` messages as "role: content" lines."""
    if not state.messages:
        return "No previous conversation"

    lines = []
    for msg in state.messages[-max_messages:]:
        lines.append(f"{msg.get('role', 'user')}: {msg.get('content', '')}")
    return "\n".join(lines)

0 commit comments

Comments
 (0)