Skip to content

Commit ce2bb0d

Browse files
committed
Preventing the model from looping output and repeatedly searching web pages
1 parent 119b2d9 commit ce2bb0d

File tree

2 files changed

+121
-10
lines changed

2 files changed

+121
-10
lines changed

src/core/orchestrator.py

Lines changed: 105 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import Any, Optional
1212
import importlib
1313
from config.agent_prompts.base_agent_prompt import BaseAgentPrompt
14-
14+
from collections import defaultdict
1515
from omegaconf import DictConfig
1616

1717

@@ -124,6 +124,29 @@ def __init__(
124124
):
125125
self.sub_agent_llm_client.task_log = task_log
126126

127+
# Track queries already issued, keyed per agent/tool, with a hit count and the cached result
128+
self.max_repeat_queries = 3
129+
self.used_queries = {}
130+
131+
def _get_query_from_tool_call(
    self, tool_name: str, arguments: dict
) -> Optional[str]:
    """Build a de-duplication key for a search/scrape tool call.

    The key identifies "the same query" so callers can detect when a tool
    is invoked repeatedly with identical arguments.

    Supported tools and the arguments that make up their key:
        google_search            -> q
        wiki_get_page_content    -> entity
        search_wiki_revision     -> entity, year, month
        search_archived_webpage  -> url, year, month, day
        scrape_website           -> url

    Args:
        tool_name: Name of the tool being called.
        arguments: Arguments passed to the tool call.

    Returns:
        A string such as ``"q:foo"`` or ``"entity:Mars_year:2020_month:5"``,
        or ``None`` when the tool is not tracked or its primary argument
        (``q`` / ``entity`` / ``url``) is missing.
    """
    # Maps each tracked tool to (primary argument, optional extra arguments).
    key_specs = {
        "google_search": ("q", ()),
        "wiki_get_page_content": ("entity", ()),
        "search_wiki_revision": ("entity", ("year", "month")),
        "search_archived_webpage": ("url", ("year", "month", "day")),
        "scrape_website": ("url", ()),
    }

    spec = key_specs.get(tool_name)
    if spec is None or not arguments:
        return None

    primary_key, extra_keys = spec
    primary_value = arguments.get(primary_key)
    if primary_value is None:
        # A malformed call without its primary argument cannot be keyed.
        # Returning None avoids a TypeError from "str + None" and lets the
        # caller run the tool normally so the tool reports the real error.
        return None

    # Missing optional parts are rendered as the text "None", exactly as the
    # str() conversion did before, so keys stay stable for valid calls.
    parts = [f"{primary_key}:{primary_value}"]
    parts.extend(f"{key}:{arguments.get(key)}" for key in extra_keys)
    return "_".join(parts)
149+
127150
async def _handle_llm_call_with_logging(
128151
self,
129152
system_prompt,
@@ -430,6 +453,7 @@ async def run_sub_agent(
430453
turn_count = 0
431454
all_tool_results_content_with_id = []
432455
task_failed = False # Track whether task failed
456+
should_hard_stop = False
433457

434458
while turn_count < max_turns:
435459
turn_count += 1
@@ -521,9 +545,32 @@ async def run_sub_agent(
521545

522546
call_start_time = time.time()
523547
try:
524-
tool_result = await self.sub_agent_tool_managers[
525-
sub_agent_name
526-
].execute_tool_call(server_name, tool_name, arguments)
548+
query_str = self._get_query_from_tool_call(tool_name, arguments)
549+
if query_str:
550+
cache_name = sub_agent_name + "_" + tool_name
551+
self.used_queries.setdefault(cache_name, defaultdict(lambda: [0, ""]))
552+
count = self.used_queries[cache_name][query_str][0]
553+
cache_result = self.used_queries[cache_name][query_str][1]
554+
if count > 0:
555+
tool_result = {
556+
"server_name": server_name,
557+
"tool_name": tool_name,
558+
"result": f"{cache_result}.\nNotice: This query has already been used in previous {tool_name}. Please try a different query or keyword.",
559+
}
560+
if count >= self.max_repeat_queries:
561+
should_hard_stop = True
562+
self.used_queries[cache_name][query_str][0] += 1
563+
else:
564+
tool_result = await self.sub_agent_tool_managers[
565+
sub_agent_name
566+
].execute_tool_call(server_name, tool_name, arguments)
567+
if "error" not in tool_result:
568+
self.used_queries[cache_name][query_str][1] = tool_result["result"]
569+
self.used_queries[cache_name][query_str][0] += 1
570+
else:
571+
tool_result = await self.sub_agent_tool_managers[
572+
sub_agent_name
573+
].execute_tool_call(server_name, tool_name, arguments)
527574

528575
call_end_time = time.time()
529576
call_duration_ms = int((call_end_time - call_start_time) * 1000)
@@ -603,6 +650,15 @@ async def run_sub_agent(
603650
message_history, all_tool_results_content_with_id, tool_calls_exceeded
604651
)
605652

653+
if should_hard_stop:
654+
task_failed = True
655+
self.task_log.log_step(
656+
"too_many_repeated_queries_in_sub_agent",
657+
f"{self.max_repeat_queries} repeated queries in sub agent {sub_agent_name}, stopping the task",
658+
"warning",
659+
)
660+
break
661+
606662
# Continue execution
607663
logger.debug(
608664
f"\n=== Sub Agent {sub_agent_name} Completed ({turn_count} turns) ==="
@@ -793,6 +849,7 @@ async def run_main_agent(
793849
max_turns = sys.maxsize
794850
turn_count = 0
795851
task_failed = False # Track whether task failed
852+
should_hard_stop = False
796853
while turn_count < max_turns:
797854
turn_count += 1
798855
logger.debug(f"\n--- Main Agent Turn {turn_count} ---")
@@ -877,13 +934,42 @@ async def run_main_agent(
877934
"result": sub_agent_result,
878935
}
879936
else:
880-
tool_result = (
881-
await self.main_agent_tool_manager.execute_tool_call(
882-
server_name=server_name,
883-
tool_name=tool_name,
884-
arguments=arguments,
885-
)
937+
query_str = self._get_query_from_tool_call(
938+
tool_name, arguments
886939
)
940+
if query_str:
941+
cache_name = "main_" + tool_name
942+
self.used_queries.setdefault(cache_name, defaultdict(lambda: [0, ""]))
943+
count = self.used_queries[cache_name][query_str][0]
944+
cache_result = self.used_queries[cache_name][query_str][1]
945+
if count > 0:
946+
tool_result = {
947+
"server_name": server_name,
948+
"tool_name": tool_name,
949+
"result": f"{cache_result}.\nNotice: This query has already been used in previous {tool_name}. Please try a different query or keyword.",
950+
}
951+
if count >= self.max_repeat_queries:
952+
should_hard_stop = True
953+
self.used_queries[cache_name][query_str][0] += 1
954+
else:
955+
tool_result = (
956+
await self.main_agent_tool_manager.execute_tool_call(
957+
server_name=server_name,
958+
tool_name=tool_name,
959+
arguments=arguments,
960+
)
961+
)
962+
if "error" not in tool_result:
963+
self.used_queries[cache_name][query_str][1] = tool_result["result"]
964+
self.used_queries[cache_name][query_str][0] += 1
965+
else:
966+
tool_result = (
967+
await self.main_agent_tool_manager.execute_tool_call(
968+
server_name=server_name,
969+
tool_name=tool_name,
970+
arguments=arguments,
971+
)
972+
)
887973

888974
call_end_time = time.time()
889975
call_duration_ms = int((call_end_time - call_start_time) * 1000)
@@ -959,6 +1045,15 @@ async def run_main_agent(
9591045
message_history, all_tool_results_content_with_id, tool_calls_exceeded
9601046
)
9611047

1048+
if should_hard_stop:
1049+
task_failed = True
1050+
self.task_log.log_step(
1051+
"too_many_repeated_queries",
1052+
f"{self.max_repeat_queries} repeated queries, stopping the task",
1053+
"warning",
1054+
)
1055+
break
1056+
9621057
# Record main loop end
9631058
if turn_count >= max_turns:
9641059
if (

src/llm/providers/mirothinker_sglang_client.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,22 @@ async def _create_message(
141141
"LLM finish_reason is 'stop', but content is empty, triggering Error"
142142
)
143143
raise Exception("LLM finish_reason is 'stop', but content is empty")
144+
145+
# identify repeated messages and retry
146+
# Check if the last 50 characters of the response appear more than 5 times in the response content.
147+
# If so, treat it as a severe repeat and trigger a retry.
148+
resp_content = response.choices[0].message.content or ""
149+
150+
if resp_content and len(resp_content) >= 50:
151+
tail_50 = resp_content[-50:]
152+
repeat_count = resp_content.count(tail_50)
153+
if repeat_count > 5:
154+
self.task_log.log_step(
155+
"warning",
156+
"LLM | Repeat Detected",
157+
"Severe repeat: the last 50 chars appeared over 5 times, retrying...",
158+
)
159+
raise Exception("Severe repeat detected in response, please retry.")
144160

145161
logger.debug(
146162
f"LLM call finish_reason: {getattr(response.choices[0], 'finish_reason', 'N/A')}"

0 commit comments

Comments
 (0)