|
11 | 11 | from typing import Any, Optional |
12 | 12 | import importlib |
13 | 13 | from config.agent_prompts.base_agent_prompt import BaseAgentPrompt |
14 | | - |
| 14 | +from collections import defaultdict |
15 | 15 | from omegaconf import DictConfig |
16 | 16 |
|
17 | 17 |
|
@@ -124,6 +124,29 @@ def __init__( |
124 | 124 | ): |
125 | 125 | self.sub_agent_llm_client.task_log = task_log |
126 | 126 |
|
| 127 | + # Record used subtask / q / Query |
| 128 | + self.max_repeat_queries = 3 |
| 129 | + self.used_queries = {} |
| 130 | + |
| 131 | + def _get_query_from_tool_call( |
| 132 | + self, tool_name: str, arguments: dict |
| 133 | + ) -> Optional[str]: |
| 134 | + """ |
| 135 | + Extracts the query from tool call arguments based on tool_name. |
| 136 | + Supports google_search, wiki_get_page_content, search_wiki_revision, search_archived_webpage, and scrape_website. |
| 137 | + """ |
| 138 | + if tool_name == "google_search": |
| 139 | + return "q:" + arguments.get("q") |
| 140 | + elif tool_name == "wiki_get_page_content": |
| 141 | + return "entity:" + arguments.get("entity") |
| 142 | + elif tool_name == "search_wiki_revision": |
| 143 | + return "entity:" + arguments.get("entity") + "_year:" + str(arguments.get("year")) + "_month:" + str(arguments.get("month")) |
| 144 | + elif tool_name == "search_archived_webpage": |
| 145 | + return "url:" + arguments.get("url") + "_year:" + str(arguments.get("year")) + "_month:" + str(arguments.get("month")) + "_day:" + str(arguments.get("day")) |
| 146 | + elif tool_name == "scrape_website": |
| 147 | + return "url:" + arguments.get("url") |
| 148 | + return None |
| 149 | + |
127 | 150 | async def _handle_llm_call_with_logging( |
128 | 151 | self, |
129 | 152 | system_prompt, |
@@ -430,6 +453,7 @@ async def run_sub_agent( |
430 | 453 | turn_count = 0 |
431 | 454 | all_tool_results_content_with_id = [] |
432 | 455 | task_failed = False # Track whether task failed |
| 456 | + should_hard_stop = False |
433 | 457 |
|
434 | 458 | while turn_count < max_turns: |
435 | 459 | turn_count += 1 |
@@ -521,9 +545,32 @@ async def run_sub_agent( |
521 | 545 |
|
522 | 546 | call_start_time = time.time() |
523 | 547 | try: |
524 | | - tool_result = await self.sub_agent_tool_managers[ |
525 | | - sub_agent_name |
526 | | - ].execute_tool_call(server_name, tool_name, arguments) |
| 548 | + query_str = self._get_query_from_tool_call(tool_name, arguments) |
| 549 | + if query_str: |
| 550 | + cache_name = sub_agent_name + "_" + tool_name |
| 551 | + self.used_queries.setdefault(cache_name, defaultdict(lambda: [0, ""])) |
| 552 | + count = self.used_queries[cache_name][query_str][0] |
| 553 | + cache_result = self.used_queries[cache_name][query_str][1] |
| 554 | + if count > 0: |
| 555 | + tool_result = { |
| 556 | + "server_name": server_name, |
| 557 | + "tool_name": tool_name, |
| 558 | + "result": f"{cache_result}.\nNotice: This query has already been used in previous {tool_name}. Please try a different query or keyword.", |
| 559 | + } |
| 560 | + if count >= self.max_repeat_queries: |
| 561 | + should_hard_stop = True |
| 562 | + self.used_queries[cache_name][query_str][0] += 1 |
| 563 | + else: |
| 564 | + tool_result = await self.sub_agent_tool_managers[ |
| 565 | + sub_agent_name |
| 566 | + ].execute_tool_call(server_name, tool_name, arguments) |
| 567 | + if "error" not in tool_result: |
| 568 | + self.used_queries[cache_name][query_str][1] = tool_result["result"] |
| 569 | + self.used_queries[cache_name][query_str][0] += 1 |
| 570 | + else: |
| 571 | + tool_result = await self.sub_agent_tool_managers[ |
| 572 | + sub_agent_name |
| 573 | + ].execute_tool_call(server_name, tool_name, arguments) |
527 | 574 |
|
528 | 575 | call_end_time = time.time() |
529 | 576 | call_duration_ms = int((call_end_time - call_start_time) * 1000) |
@@ -603,6 +650,15 @@ async def run_sub_agent( |
603 | 650 | message_history, all_tool_results_content_with_id, tool_calls_exceeded |
604 | 651 | ) |
605 | 652 |
|
| 653 | + if should_hard_stop: |
| 654 | + task_failed = True |
| 655 | + self.task_log.log_step( |
| 656 | + "too_many_repeated_queries_in_sub_agent", |
| 657 | + f"{self.max_repeat_queries} repeated queries in sub agent {sub_agent_name}, stopping the task", |
| 658 | + "warning", |
| 659 | + ) |
| 660 | + break |
| 661 | + |
606 | 662 | # Continue execution |
607 | 663 | logger.debug( |
608 | 664 | f"\n=== Sub Agent {sub_agent_name} Completed ({turn_count} turns) ===" |
@@ -793,6 +849,7 @@ async def run_main_agent( |
793 | 849 | max_turns = sys.maxsize |
794 | 850 | turn_count = 0 |
795 | 851 | task_failed = False # Track whether task failed |
| 852 | + should_hard_stop = False |
796 | 853 | while turn_count < max_turns: |
797 | 854 | turn_count += 1 |
798 | 855 | logger.debug(f"\n--- Main Agent Turn {turn_count} ---") |
@@ -877,13 +934,42 @@ async def run_main_agent( |
877 | 934 | "result": sub_agent_result, |
878 | 935 | } |
879 | 936 | else: |
880 | | - tool_result = ( |
881 | | - await self.main_agent_tool_manager.execute_tool_call( |
882 | | - server_name=server_name, |
883 | | - tool_name=tool_name, |
884 | | - arguments=arguments, |
885 | | - ) |
| 937 | + query_str = self._get_query_from_tool_call( |
| 938 | + tool_name, arguments |
886 | 939 | ) |
| 940 | + if query_str: |
| 941 | + cache_name = "main_" + tool_name |
| 942 | + self.used_queries.setdefault(cache_name, defaultdict(lambda: [0, ""])) |
| 943 | + count = self.used_queries[cache_name][query_str][0] |
| 944 | + cache_result = self.used_queries[cache_name][query_str][1] |
| 945 | + if count > 0: |
| 946 | + tool_result = { |
| 947 | + "server_name": server_name, |
| 948 | + "tool_name": tool_name, |
| 949 | + "result": f"{cache_result}.\nNotice: This query has already been used in previous {tool_name}. Please try a different query or keyword.", |
| 950 | + } |
| 951 | + if count >= self.max_repeat_queries: |
| 952 | + should_hard_stop = True |
| 953 | + self.used_queries[cache_name][query_str][0] += 1 |
| 954 | + else: |
| 955 | + tool_result = ( |
| 956 | + await self.main_agent_tool_manager.execute_tool_call( |
| 957 | + server_name=server_name, |
| 958 | + tool_name=tool_name, |
| 959 | + arguments=arguments, |
| 960 | + ) |
| 961 | + ) |
| 962 | + if "error" not in tool_result: |
| 963 | + self.used_queries[cache_name][query_str][1] = tool_result["result"] |
| 964 | + self.used_queries[cache_name][query_str][0] += 1 |
| 965 | + else: |
| 966 | + tool_result = ( |
| 967 | + await self.main_agent_tool_manager.execute_tool_call( |
| 968 | + server_name=server_name, |
| 969 | + tool_name=tool_name, |
| 970 | + arguments=arguments, |
| 971 | + ) |
| 972 | + ) |
887 | 973 |
|
888 | 974 | call_end_time = time.time() |
889 | 975 | call_duration_ms = int((call_end_time - call_start_time) * 1000) |
@@ -959,6 +1045,15 @@ async def run_main_agent( |
959 | 1045 | message_history, all_tool_results_content_with_id, tool_calls_exceeded |
960 | 1046 | ) |
961 | 1047 |
|
| 1048 | + if should_hard_stop: |
| 1049 | + task_failed = True |
| 1050 | + self.task_log.log_step( |
| 1051 | + "too_many_repeated_queries", |
| 1052 | + f"{self.max_repeat_queries} repeated queries, stopping the task", |
| 1053 | + "warning", |
| 1054 | + ) |
| 1055 | + break |
| 1056 | + |
962 | 1057 | # Record main loop end |
963 | 1058 | if turn_count >= max_turns: |
964 | 1059 | if ( |
|
0 commit comments