Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 115 additions & 27 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Any, Callable, Literal
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Callable

from litellm import ContextWindowExceededError

Expand Down Expand Up @@ -53,10 +55,11 @@ def get_weather(city: str) -> str:
[
f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.",
f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n",
"To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.",
"After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n",
"To do this, you will interleave next_thought and next_tool_calls in each turn, and also when finishing the task.",
"You can call multiple tools in parallel by providing multiple tool calls in next_tool_calls.",
"After each set of tool calls, you receive resulting observations, which get appended to your trajectory.\n",
"When writing next_thought, you may reason about the current situation and plan for future steps.",
"When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n",
"When selecting next_tool_calls, each tool must be one of:\n",
]
)

Expand All @@ -69,14 +72,16 @@ def get_weather(city: str) -> str:

for idx, tool in enumerate(tools.values()):
instr.append(f"({idx + 1}) {tool}")
instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format")
instr.append(
"When providing `next_tool_calls`, provide a list of tool calls. Each tool call should be a dictionary with 'name' and 'args' keys. "
"The 'name' must be one of the tool names listed above, and 'args' must be a dictionary in JSON format containing the arguments for that tool."
)

react_signature = (
dspy.Signature({**signature.input_fields}, "\n".join(instr))
.append("trajectory", dspy.InputField(), type_=str)
.append("next_thought", dspy.OutputField(), type_=str)
.append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())])
.append("next_tool_args", dspy.OutputField(), type_=dict[str, Any])
.append("next_tool_calls", dspy.OutputField(), type_=list[dict[str, Any]])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot can we use ToolCalls class to define the schema explicitly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! Changed the type annotation from list[dict[str, Any]] to ToolCalls and updated the _parse_tool_calls method to handle ToolCalls objects. All tests passing. (commit e0833e3)

)

fallback_signature = dspy.Signature(
Expand Down Expand Up @@ -104,15 +109,26 @@ def forward(self, **input_args):
break

trajectory[f"thought_{idx}"] = pred.next_thought
trajectory[f"tool_name_{idx}"] = pred.next_tool_name
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
except Exception as err:
trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"

if pred.next_tool_name == "finish":
# Parse tool calls - handle both list format and backward compatibility
tool_calls = self._parse_tool_calls(pred.next_tool_calls)
trajectory[f"tool_calls_{idx}"] = tool_calls

# Execute tools in parallel
observations = self._execute_tools_parallel(tool_calls)

# Store observations as a structured format that includes tool names
# This makes it easier for the LLM to understand which observation corresponds to which tool
formatted_observations = []
for tool_call, observation in zip(tool_calls, observations, strict=True):
formatted_observations.append({
"tool": tool_call["name"],
"result": observation
})
trajectory[f"observations_{idx}"] = formatted_observations

# Check if any tool call is "finish"
if any(tc["name"] == "finish" for tc in tool_calls):
break

extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
Expand All @@ -129,20 +145,92 @@ async def aforward(self, **input_args):
break

trajectory[f"thought_{idx}"] = pred.next_thought
trajectory[f"tool_name_{idx}"] = pred.next_tool_name
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args)
except Exception as err:
trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}"

if pred.next_tool_name == "finish":
# Parse tool calls - handle both list format and backward compatibility
tool_calls = self._parse_tool_calls(pred.next_tool_calls)
trajectory[f"tool_calls_{idx}"] = tool_calls

# Execute tools in parallel
observations = await self._execute_tools_parallel_async(tool_calls)

# Store observations as a structured format that includes tool names
# This makes it easier for the LLM to understand which observation corresponds to which tool
formatted_observations = []
for tool_call, observation in zip(tool_calls, observations, strict=True):
formatted_observations.append({
"tool": tool_call["name"],
"result": observation
})
trajectory[f"observations_{idx}"] = formatted_observations

# Check if any tool call is "finish"
if any(tc["name"] == "finish" for tc in tool_calls):
break

extract = await self._async_call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
return dspy.Prediction(trajectory=trajectory, **extract)

def _parse_tool_calls(self, tool_calls_data):
"""Parse tool calls from the prediction output.

Handles both the new list format and provides backward compatibility.
"""
# If it's already a list of dicts with 'name' and 'args', use it directly
if isinstance(tool_calls_data, list):
return tool_calls_data

# Handle single dict case (shouldn't normally happen but for robustness)
if isinstance(tool_calls_data, dict) and "name" in tool_calls_data and "args" in tool_calls_data:
return [tool_calls_data]

# If we got something unexpected, raise an error
raise ValueError(f"Invalid tool_calls format: {tool_calls_data}")

def _execute_tools_parallel(self, tool_calls: list[dict[str, Any]]) -> list[Any]:
"""Execute multiple tools in parallel using ThreadPoolExecutor.

Args:
tool_calls: List of tool call dicts, each with 'name' and 'args' keys

Returns:
List of observations in the same order as tool_calls
"""
def execute_single_tool(tool_call: dict[str, Any]) -> Any:
tool_name = tool_call["name"]
tool_args = tool_call.get("args", {})
try:
return self.tools[tool_name](**tool_args)
except Exception as err:
return f"Execution error in {tool_name}: {_fmt_exc(err)}"

# Execute tools in parallel using ThreadPoolExecutor
with ThreadPoolExecutor() as executor:
observations = list(executor.map(execute_single_tool, tool_calls))

return observations

async def _execute_tools_parallel_async(self, tool_calls: list[dict[str, Any]]) -> list[Any]:
"""Execute multiple tools in parallel using asyncio.gather.

Args:
tool_calls: List of tool call dicts, each with 'name' and 'args' keys

Returns:
List of observations in the same order as tool_calls
"""
async def execute_single_tool(tool_call: dict[str, Any]) -> Any:
tool_name = tool_call["name"]
tool_args = tool_call.get("args", {})
try:
return await self.tools[tool_name].acall(**tool_args)
except Exception as err:
return f"Execution error in {tool_name}: {_fmt_exc(err)}"

# Execute tools in parallel using asyncio.gather
observations = await asyncio.gather(*[execute_single_tool(tc) for tc in tool_calls])

return observations

def _call_with_potential_trajectory_truncation(self, module, trajectory, **input_args):
for _ in range(3):
try:
Expand Down Expand Up @@ -171,14 +259,14 @@ def truncate_trajectory(self, trajectory):
Users can override this method to implement their own truncation logic.
"""
keys = list(trajectory.keys())
if len(keys) < 4:
# Every tool call has 4 keys: thought, tool_name, tool_args, and observation.
if len(keys) < 3:
# Every iteration has 3 keys: thought, tool_calls, and observations.
raise ValueError(
"The trajectory is too long so your prompt exceeded the context window, but the trajectory cannot be "
"truncated because it only has one tool call."
"truncated because it only has one iteration."
)

for key in keys[:4]:
for key in keys[:3]:
trajectory.pop(key)

return trajectory
Expand Down
Loading