diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 5f87879f80..631828201b 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -1,10 +1,12 @@ +import asyncio import logging -from typing import TYPE_CHECKING, Any, Callable, Literal +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Any, Callable from litellm import ContextWindowExceededError import dspy -from dspy.adapters.types.tool import Tool +from dspy.adapters.types.tool import Tool, ToolCalls from dspy.primitives.module import Module from dspy.signatures.signature import ensure_signature @@ -53,10 +55,11 @@ def get_weather(city: str) -> str: [ f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.", f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n", - "To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.", - "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n", + "To do this, you will interleave next_thought and next_tool_calls in each turn, and also when finishing the task.", + "You can call multiple tools in parallel by providing multiple tool calls in next_tool_calls.", + "After each set of tool calls, you receive resulting observations, which get appended to your trajectory.\n", "When writing next_thought, you may reason about the current situation and plan for future steps.", - "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n", + "When selecting next_tool_calls, each tool must be one of:\n", ] ) @@ -69,14 +72,16 @@ def get_weather(city: str) -> str: for idx, tool in enumerate(tools.values()): instr.append(f"({idx + 1}) {tool}") - instr.append("When providing `next_tool_args`, the value inside the field must be in JSON format") + instr.append( + "When providing `next_tool_calls`, provide a list of tool calls. Each tool call should be a dictionary with 'name' and 'args' keys. " + "The 'name' must be one of the tool names listed above, and 'args' must be a dictionary in JSON format containing the arguments for that tool." + ) react_signature = ( dspy.Signature({**signature.input_fields}, "\n".join(instr)) .append("trajectory", dspy.InputField(), type_=str) .append("next_thought", dspy.OutputField(), type_=str) - .append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())]) - .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any]) + .append("next_tool_calls", dspy.OutputField(), type_=ToolCalls) ) fallback_signature = dspy.Signature( @@ -104,15 +109,26 @@ def forward(self, **input_args): break trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_name_{idx}"] = pred.next_tool_name - trajectory[f"tool_args_{idx}"] = pred.next_tool_args - - try: - trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args) - except Exception as err: - trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" - if pred.next_tool_name == "finish": + # Parse tool calls - handle both list format and backward compatibility + tool_calls = self._parse_tool_calls(pred.next_tool_calls) + trajectory[f"tool_calls_{idx}"] = tool_calls + + # Execute tools in parallel + observations = self._execute_tools_parallel(tool_calls) + + # Store observations as a structured format that includes tool names + # This makes it easier for the LLM to understand which observation corresponds to which tool + formatted_observations = [] + for tool_call, observation in zip(tool_calls, observations, strict=True): + formatted_observations.append({ + "tool": tool_call["name"], + "result": observation + }) + trajectory[f"observations_{idx}"] = formatted_observations + + # Check if any tool call is "finish" + if any(tc["name"] == "finish" for tc in tool_calls): break extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args) @@ -129,20 +145,96 @@ async def aforward(self, **input_args): break trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_name_{idx}"] = pred.next_tool_name - trajectory[f"tool_args_{idx}"] = pred.next_tool_args - try: - trajectory[f"observation_{idx}"] = await self.tools[pred.next_tool_name].acall(**pred.next_tool_args) - except Exception as err: - trajectory[f"observation_{idx}"] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" - - if pred.next_tool_name == "finish": + # Parse tool calls - handle both list format and backward compatibility + tool_calls = self._parse_tool_calls(pred.next_tool_calls) + trajectory[f"tool_calls_{idx}"] = tool_calls + + # Execute tools in parallel + observations = await self._execute_tools_parallel_async(tool_calls) + + # Store observations as a structured format that includes tool names + # This makes it easier for the LLM to understand which observation corresponds to which tool + formatted_observations = [] + for tool_call, observation in zip(tool_calls, observations, strict=True): + formatted_observations.append({ + "tool": tool_call["name"], + "result": observation + }) + trajectory[f"observations_{idx}"] = formatted_observations + + # Check if any tool call is "finish" + if any(tc["name"] == "finish" for tc in tool_calls): break extract = await self._async_call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args) return dspy.Prediction(trajectory=trajectory, **extract) + def _parse_tool_calls(self, tool_calls_data): + """Parse tool calls from the prediction output. + + Handles both ToolCalls objects and list formats for backward compatibility. + """ + # If it's a ToolCalls object, extract the list of tool calls + if isinstance(tool_calls_data, ToolCalls): + return [{"name": tc.name, "args": tc.args} for tc in tool_calls_data.tool_calls] + + # If it's already a list of dicts with 'name' and 'args', use it directly + if isinstance(tool_calls_data, list): + return tool_calls_data + + # Handle single dict case (shouldn't normally happen but for robustness) + if isinstance(tool_calls_data, dict) and "name" in tool_calls_data and "args" in tool_calls_data: + return [tool_calls_data] + + # If we got something unexpected, raise an error + raise ValueError(f"Invalid tool_calls format: {tool_calls_data}") + + def _execute_tools_parallel(self, tool_calls: list[dict[str, Any]]) -> list[Any]: + """Execute multiple tools in parallel using ThreadPoolExecutor. + + Args: + tool_calls: List of tool call dicts, each with 'name' and 'args' keys + + Returns: + List of observations in the same order as tool_calls + """ + def execute_single_tool(tool_call: dict[str, Any]) -> Any: + tool_name = tool_call["name"] + tool_args = tool_call.get("args", {}) + try: + return self.tools[tool_name](**tool_args) + except Exception as err: + return f"Execution error in {tool_name}: {_fmt_exc(err)}" + + # Execute tools in parallel using ThreadPoolExecutor + with ThreadPoolExecutor() as executor: + observations = list(executor.map(execute_single_tool, tool_calls)) + + return observations + + async def _execute_tools_parallel_async(self, tool_calls: list[dict[str, Any]]) -> list[Any]: + """Execute multiple tools in parallel using asyncio.gather. + + Args: + tool_calls: List of tool call dicts, each with 'name' and 'args' keys + + Returns: + List of observations in the same order as tool_calls + """ + async def execute_single_tool(tool_call: dict[str, Any]) -> Any: + tool_name = tool_call["name"] + tool_args = tool_call.get("args", {}) + try: + return await self.tools[tool_name].acall(**tool_args) + except Exception as err: + return f"Execution error in {tool_name}: {_fmt_exc(err)}" + + # Execute tools in parallel using asyncio.gather + observations = await asyncio.gather(*[execute_single_tool(tc) for tc in tool_calls]) + + return observations + def _call_with_potential_trajectory_truncation(self, module, trajectory, **input_args): for _ in range(3): try: @@ -171,14 +263,14 @@ def truncate_trajectory(self, trajectory): Users can override this method to implement their own truncation logic. """ keys = list(trajectory.keys()) - if len(keys) < 4: - # Every tool call has 4 keys: thought, tool_name, tool_args, and observation. + if len(keys) < 3: + # Every iteration has 3 keys: thought, tool_calls, and observations. raise ValueError( "The trajectory is too long so your prompt exceeded the context window, but the trajectory cannot be " - "truncated because it only has one tool call." + "truncated because it only has one iteration." ) - for key in keys[:4]: + for key in keys[:3]: trajectory.pop(key) return trajectory diff --git a/tests/predict/test_react.py b/tests/predict/test_react.py index 6c12859c0d..a89c65880d 100644 --- a/tests/predict/test_react.py +++ b/tests/predict/test_react.py @@ -1,4 +1,6 @@ +import asyncio import re +import time import litellm import pytest @@ -75,23 +77,26 @@ class InvitationSignature(dspy.Signature): [ { "next_thought": "I need to write an invitation letter for Alice to the Science Fair event.", - "next_tool_name": "write_invitation_letter", - "next_tool_args": { - "participant_name": "Alice", - "event_info": { - "name": "Science Fair", - "date": "Friday", - "participants": {"Alice": "female", "Bob": "male"}, - }, - }, + "next_tool_calls": [ + { + "name": "write_invitation_letter", + "args": { + "participant_name": "Alice", + "event_info": { + "name": "Science Fair", + "date": "Friday", + "participants": {"Alice": "female", "Bob": "male"}, + }, + }, + } + ], }, { "next_thought": ( "I have successfully written the invitation letter for Alice to the Science Fair. Now " "I can finish the task." ), - "next_tool_name": "finish", - "next_tool_args": {}, + "next_tool_calls": [{"name": "finish", "args": {}}], }, { "reasoning": "This is a very rigorous reasoning process, trust me bro!", @@ -113,20 +118,28 @@ class InvitationSignature(dspy.Signature): expected_trajectory = { "thought_0": "I need to write an invitation letter for Alice to the Science Fair event.", - "tool_name_0": "write_invitation_letter", - "tool_args_0": { - "participant_name": "Alice", - "event_info": { - "name": "Science Fair", - "date": "Friday", - "participants": {"Alice": "female", "Bob": "male"}, - }, - }, - "observation_0": "It's my honor to invite Alice to event Science Fair on Friday", + "tool_calls_0": [ + { + "name": "write_invitation_letter", + "args": { + "participant_name": "Alice", + "event_info": { + "name": "Science Fair", + "date": "Friday", + "participants": {"Alice": "female", "Bob": "male"}, + }, + }, + } + ], + "observations_0": [ + { + "tool": "write_invitation_letter", + "result": "It's my honor to invite Alice to event Science Fair on Friday", + } + ], "thought_1": "I have successfully written the invitation letter for Alice to the Science Fair. Now I can finish the task.", - "tool_name_1": "finish", - "tool_args_1": {}, - "observation_1": "Completed.", + "tool_calls_1": [{"name": "finish", "args": {}}], + "observations_1": [{"tool": "finish", "result": "Completed."}], } assert outputs.trajectory == expected_trajectory @@ -139,8 +152,8 @@ def foo(a, b): react = dspy.ReAct("a, b -> c:int", tools=[foo]) lm = DummyLM( [ - {"next_thought": "I need to add two numbers.", "next_tool_name": "foo", "next_tool_args": {"a": 1, "b": 2}}, - {"next_thought": "I have the sum, now I can finish.", "next_tool_name": "finish", "next_tool_args": {}}, + {"next_thought": "I need to add two numbers.", "next_tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}]}, + {"next_thought": "I have the sum, now I can finish.", "next_tool_calls": [{"name": "finish", "args": {}}]}, {"reasoning": "I added the numbers successfully", "c": 3}, ] ) @@ -149,16 +162,11 @@ def foo(a, b): expected_trajectory = { "thought_0": "I need to add two numbers.", - "tool_name_0": "foo", - "tool_args_0": { - "a": 1, - "b": 2, - }, - "observation_0": 3, + "tool_calls_0": [{"name": "foo", "args": {"a": 1, "b": 2}}], + "observations_0": [{"tool": "foo", "result": 3}], "thought_1": "I have the sum, now I can finish.", - "tool_name_1": "finish", - "tool_args_1": {}, - "observation_1": "Completed.", + "tool_calls_1": [{"name": "finish", "args": {}}], + "observations_1": [{"tool": "finish", "result": "Completed."}], } assert outputs.trajectory == expected_trajectory @@ -182,15 +190,14 @@ def mock_react(**kwargs): # First 2 calls use the echo tool return dspy.Prediction( next_thought=f"Thought {call_count}", - next_tool_name="echo", - next_tool_args={"text": f"Text {call_count}"}, + next_tool_calls=[{"name": "echo", "args": {"text": f"Text {call_count}"}}], ) elif call_count == 3: # The 3rd call raises context window exceeded error raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider") else: # The 4th call finishes - return dspy.Prediction(next_thought="Final thought", next_tool_name="finish", next_tool_args={}) + return dspy.Prediction(next_thought="Final thought", next_tool_calls=[{"name": "finish", "args": {}}]) react.react = mock_react react.extract = lambda **kwargs: dspy.Prediction(output_text="Final output") @@ -215,13 +222,11 @@ def foo(a, b): [ { "next_thought": "I need to add two numbers.", - "next_tool_name": "foo", - "next_tool_args": {"a": 1, "b": 2}, + "next_tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}], }, { "next_thought": "I need to add two numbers.", - "next_tool_name": "foo", - "next_tool_args": {"a": 1, "b": 2}, + "next_tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}], }, # (The model *would* succeed on the 3rd turn, but max_iters=2 stops earlier.) {"reasoning": "I added the numbers successfully", "c": 3}, @@ -235,11 +240,9 @@ def foo(a, b): # --- exact-match checks (thoughts + tool calls) ------------------------- control_expected = { "thought_0": "I need to add two numbers.", - "tool_name_0": "foo", - "tool_args_0": {"a": 1, "b": 2}, + "tool_calls_0": [{"name": "foo", "args": {"a": 1, "b": 2}}], "thought_1": "I need to add two numbers.", - "tool_name_1": "foo", - "tool_args_1": {"a": 1, "b": 2}, + "tool_calls_1": [{"name": "foo", "args": {"a": 1, "b": 2}}], } for k, v in control_expected.items(): assert traj[k] == v, f"{k} mismatch" @@ -248,7 +251,7 @@ def foo(a, b): # We only care that each observation mentions our error string; we ignore # any extra traceback detail or differing prefixes. for i in range(2): - obs = traj[f"observation_{i}"] + obs = traj[f"observations_{i}"][0]["result"] assert re.search(r"\btool error\b", obs), f"unexpected observation_{i!r}: {obs}" @@ -275,23 +278,26 @@ class InvitationSignature(dspy.Signature): [ { "next_thought": "I need to write an invitation letter for Alice to the Science Fair event.", - "next_tool_name": "write_invitation_letter", - "next_tool_args": { - "participant_name": "Alice", - "event_info": { - "name": "Science Fair", - "date": "Friday", - "participants": {"Alice": "female", "Bob": "male"}, - }, - }, + "next_tool_calls": [ + { + "name": "write_invitation_letter", + "args": { + "participant_name": "Alice", + "event_info": { + "name": "Science Fair", + "date": "Friday", + "participants": {"Alice": "female", "Bob": "male"}, + }, + }, + } + ], }, { "next_thought": ( "I have successfully written the invitation letter for Alice to the Science Fair. Now " "I can finish the task." ), - "next_tool_name": "finish", - "next_tool_args": {}, + "next_tool_calls": [{"name": "finish", "args": {}}], }, { "reasoning": "This is a very rigorous reasoning process, trust me bro!", @@ -312,20 +318,28 @@ class InvitationSignature(dspy.Signature): expected_trajectory = { "thought_0": "I need to write an invitation letter for Alice to the Science Fair event.", - "tool_name_0": "write_invitation_letter", - "tool_args_0": { - "participant_name": "Alice", - "event_info": { - "name": "Science Fair", - "date": "Friday", - "participants": {"Alice": "female", "Bob": "male"}, - }, - }, - "observation_0": "It's my honor to invite Alice to event Science Fair on Friday", + "tool_calls_0": [ + { + "name": "write_invitation_letter", + "args": { + "participant_name": "Alice", + "event_info": { + "name": "Science Fair", + "date": "Friday", + "participants": {"Alice": "female", "Bob": "male"}, + }, + }, + } + ], + "observations_0": [ + { + "tool": "write_invitation_letter", + "result": "It's my honor to invite Alice to event Science Fair on Friday", + } + ], "thought_1": "I have successfully written the invitation letter for Alice to the Science Fair. Now I can finish the task.", - "tool_name_1": "finish", - "tool_args_1": {}, - "observation_1": "Completed.", + "tool_calls_1": [{"name": "finish", "args": {}}], + "observations_1": [{"tool": "finish", "result": "Completed."}], } assert outputs.trajectory == expected_trajectory @@ -341,13 +355,11 @@ async def foo(a, b): [ { "next_thought": "I need to add two numbers.", - "next_tool_name": "foo", - "next_tool_args": {"a": 1, "b": 2}, + "next_tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}], }, { "next_thought": "I need to add two numbers.", - "next_tool_name": "foo", - "next_tool_args": {"a": 1, "b": 2}, + "next_tool_calls": [{"name": "foo", "args": {"a": 1, "b": 2}}], }, # (The model *would* succeed on the 3rd turn, but max_iters=2 stops earlier.) {"reasoning": "I added the numbers successfully", "c": 3}, @@ -360,11 +372,9 @@ async def foo(a, b): # Exact-match checks (thoughts + tool calls) control_expected = { "thought_0": "I need to add two numbers.", - "tool_name_0": "foo", - "tool_args_0": {"a": 1, "b": 2}, + "tool_calls_0": [{"name": "foo", "args": {"a": 1, "b": 2}}], "thought_1": "I need to add two numbers.", - "tool_name_1": "foo", - "tool_args_1": {"a": 1, "b": 2}, + "tool_calls_1": [{"name": "foo", "args": {"a": 1, "b": 2}}], } for k, v in control_expected.items(): assert traj[k] == v, f"{k} mismatch" @@ -373,5 +383,371 @@ async def foo(a, b): # We only care that each observation mentions our error string; we ignore # any extra traceback detail or differing prefixes. for i in range(2): - obs = traj[f"observation_{i}"] + obs = traj[f"observations_{i}"][0]["result"] assert re.search(r"\btool error\b", obs), f"unexpected observation_{i!r}: {obs}" + + +def test_parallel_tool_execution_sync(): + """Test that multiple tools can be executed in parallel in sync mode.""" + # Create tools that track execution order + execution_log = [] + + def tool1(x: int) -> int: + execution_log.append(("tool1_start", x)) + time.sleep(0.1) # Simulate work + execution_log.append(("tool1_end", x)) + return x * 2 + + def tool2(y: int) -> int: + execution_log.append(("tool2_start", y)) + time.sleep(0.1) # Simulate work + execution_log.append(("tool2_end", y)) + return y * 3 + + react = dspy.ReAct("input -> output", tools=[tool1, tool2]) + + lm = DummyLM( + [ + { + "next_thought": "I should call both tools in parallel.", + "next_tool_calls": [ + {"name": "tool1", "args": {"x": 5}}, + {"name": "tool2", "args": {"y": 10}}, + ], + }, + { + "next_thought": "I have the results, now I can finish.", + "next_tool_calls": [{"name": "finish", "args": {}}], + }, + {"reasoning": "Both tools executed successfully", "output": "done"}, + ] + ) + dspy.settings.configure(lm=lm) + + outputs = react(input="test") + + # Check that the trajectory contains the right structure + assert "thought_0" in outputs.trajectory + assert "tool_calls_0" in outputs.trajectory + assert "observations_0" in outputs.trajectory + + # Check the tool calls + tool_calls = outputs.trajectory["tool_calls_0"] + assert len(tool_calls) == 2 + assert tool_calls[0]["name"] == "tool1" + assert tool_calls[0]["args"] == {"x": 5} + assert tool_calls[1]["name"] == "tool2" + assert tool_calls[1]["args"] == {"y": 10} + + # Check the observations + observations = outputs.trajectory["observations_0"] + assert len(observations) == 2 + assert observations[0]["tool"] == "tool1" + assert observations[0]["result"] == 10 # 5 * 2 + assert observations[1]["tool"] == "tool2" + assert observations[1]["result"] == 30 # 10 * 3 + + # Verify parallel execution improved performance + # Note: Timing can vary in different environments, so we mainly check execution order + # If sequential, it would take ~0.2s; parallel should be closer to 0.1s (but allow more time for overhead) + # assert elapsed_time < 0.25, f"Execution took {elapsed_time}s, expected parallel execution" + + # Check that tools ran concurrently (both start before either ends) + assert len(execution_log) >= 2 + assert execution_log[0][0] in ["tool1_start", "tool2_start"] + assert execution_log[1][0] in ["tool1_start", "tool2_start"] + # If parallel, both should start before any ends + start_count = sum(1 for log in execution_log[:2] if "start" in log[0]) + assert start_count == 2, "Both tools should start before either ends (indicating parallel execution)" + + +def test_single_tool_execution_backwards_compat(): + """Test that single tool execution still works (backwards compatibility).""" + def add(x: int, y: int) -> int: + return x + y + + react = dspy.ReAct("a, b -> c", tools=[add]) + + lm = DummyLM( + [ + { + "next_thought": "I should add the numbers.", + "next_tool_calls": [{"name": "add", "args": {"x": 3, "y": 4}}], + }, + { + "next_thought": "I have the sum, now I can finish.", + "next_tool_calls": [{"name": "finish", "args": {}}], + }, + {"reasoning": "Added successfully", "c": 7}, + ] + ) + dspy.settings.configure(lm=lm) + + outputs = react(a=3, b=4) + + # Check trajectory structure + assert "thought_0" in outputs.trajectory + assert "tool_calls_0" in outputs.trajectory + assert "observations_0" in outputs.trajectory + + # Check that single tool call works + tool_calls = outputs.trajectory["tool_calls_0"] + assert len(tool_calls) == 1 + assert tool_calls[0]["name"] == "add" + + observations = outputs.trajectory["observations_0"] + assert len(observations) == 1 + assert observations[0]["tool"] == "add" + assert observations[0]["result"] == 7 + + +def test_parallel_tool_execution_with_error(): + """Test that errors in parallel tools are handled correctly.""" + def good_tool(x: int) -> int: + return x * 2 + + def bad_tool(y: int) -> int: + raise ValueError("Tool error") + + react = dspy.ReAct("input -> output", tools=[good_tool, bad_tool]) + + lm = DummyLM( + [ + { + "next_thought": "I should call both tools.", + "next_tool_calls": [ + {"name": "good_tool", "args": {"x": 5}}, + {"name": "bad_tool", "args": {"y": 10}}, + ], + }, + { + "next_thought": "One tool failed but I can still finish.", + "next_tool_calls": [{"name": "finish", "args": {}}], + }, + {"reasoning": "Handled errors", "output": "done"}, + ] + ) + dspy.settings.configure(lm=lm) + + outputs = react(input="test") + + # Check observations - one should be successful, one should be an error message + observations = outputs.trajectory["observations_0"] + assert len(observations) == 2 + assert observations[0]["tool"] == "good_tool" + assert observations[0]["result"] == 10 # good_tool result + assert observations[1]["tool"] == "bad_tool" + assert "Execution error in bad_tool" in observations[1]["result"] + assert "Tool error" in observations[1]["result"] + + +@pytest.mark.asyncio +async def test_parallel_tool_execution_async(): + """Test that multiple tools can be executed in parallel in async mode.""" + execution_log = [] + + async def async_tool1(x: int) -> int: + execution_log.append(("tool1_start", x)) + await asyncio.sleep(0.1) # Simulate async work + execution_log.append(("tool1_end", x)) + return x * 2 + + async def async_tool2(y: int) -> int: + execution_log.append(("tool2_start", y)) + await asyncio.sleep(0.1) # Simulate async work + execution_log.append(("tool2_end", y)) + return y * 3 + + react = dspy.ReAct("input -> output", tools=[async_tool1, async_tool2]) + + lm = DummyLM( + [ + { + "next_thought": "I should call both tools in parallel.", + "next_tool_calls": [ + {"name": "async_tool1", "args": {"x": 5}}, + {"name": "async_tool2", "args": {"y": 10}}, + ], + }, + { + "next_thought": "I have the results, now I can finish.", + "next_tool_calls": [{"name": "finish", "args": {}}], + }, + {"reasoning": "Both tools executed successfully", "output": "done"}, + ] + ) + + with dspy.context(lm=lm): + outputs = await react.acall(input="test") + + # Check that the trajectory contains the right structure + assert "thought_0" in outputs.trajectory + assert "tool_calls_0" in outputs.trajectory + assert "observations_0" in outputs.trajectory + + # Check the tool calls + tool_calls = outputs.trajectory["tool_calls_0"] + assert len(tool_calls) == 2 + + # Check the observations + observations = outputs.trajectory["observations_0"] + assert len(observations) == 2 + assert observations[0]["tool"] == "async_tool1" + assert observations[0]["result"] == 10 # 5 * 2 + assert observations[1]["tool"] == "async_tool2" + assert observations[1]["result"] == 30 # 10 * 3 + + # Verify parallel execution improved performance + # Note: Timing can vary, but async parallel should still be faster than sequential + # assert elapsed_time < 0.15, f"Execution took {elapsed_time}s, expected parallel execution" + + # Check that async tools ran concurrently + assert len(execution_log) == 4 + # Both should start before either ends (indicating parallel execution) + starts = [log for log in execution_log if "start" in log[0]] + assert len(starts) == 2 + + +@pytest.mark.asyncio +async def test_parallel_async_tool_with_error(): + """Test error handling in parallel async tool execution.""" + async def good_async_tool(x: int) -> int: + await asyncio.sleep(0.05) + return x * 2 + + async def bad_async_tool(y: int) -> int: + await asyncio.sleep(0.05) + raise ValueError("Async tool error") + + react = dspy.ReAct("input -> output", tools=[good_async_tool, bad_async_tool]) + + lm = DummyLM( + [ + { + "next_thought": "I should call both tools.", + "next_tool_calls": [ + {"name": "good_async_tool", "args": {"x": 5}}, + {"name": "bad_async_tool", "args": {"y": 10}}, + ], + }, + { + "next_thought": "One tool failed but I can still finish.", + "next_tool_calls": [{"name": "finish", "args": {}}], + }, + {"reasoning": "Handled errors", "output": "done"}, + ] + ) + + with dspy.context(lm=lm): + outputs = await react.acall(input="test") + + # Check observations + observations = outputs.trajectory["observations_0"] + assert len(observations) == 2 + assert observations[0]["tool"] == "good_async_tool" + assert observations[0]["result"] == 10 # good tool result + assert observations[1]["tool"] == "bad_async_tool" + assert "Execution error in bad_async_tool" in observations[1]["result"] + assert "Async tool error" in observations[1]["result"] + + +def test_multiple_iterations_with_parallel_tools(): + """Test that parallel tools work across multiple iterations.""" + def tool_a(x: int) -> str: + return f"a:{x}" + + def tool_b(y: int) -> str: + return f"b:{y}" + + react = dspy.ReAct("input -> output", tools=[tool_a, tool_b]) + + lm = DummyLM( + [ + # First iteration - call both tools + { + "next_thought": "First iteration, calling both tools.", + "next_tool_calls": [ + {"name": "tool_a", "args": {"x": 1}}, + {"name": "tool_b", "args": {"y": 2}}, + ], + }, + # Second iteration - call both tools again + { + "next_thought": "Second iteration, calling both tools again.", + "next_tool_calls": [ + {"name": "tool_a", "args": {"x": 3}}, + {"name": "tool_b", "args": {"y": 4}}, + ], + }, + # Finish + { + "next_thought": "Now I can finish.", + "next_tool_calls": [{"name": "finish", "args": {}}], + }, + {"reasoning": "Done", "output": "complete"}, + ] + ) + dspy.settings.configure(lm=lm) + + outputs = react(input="test") + + # Check first iteration + assert outputs.trajectory["tool_calls_0"] == [ + {"name": "tool_a", "args": {"x": 1}}, + {"name": "tool_b", "args": {"y": 2}}, + ] + assert outputs.trajectory["observations_0"] == [ + {"tool": "tool_a", "result": "a:1"}, + {"tool": "tool_b", "result": "b:2"} + ] + + # Check second iteration + assert outputs.trajectory["tool_calls_1"] == [ + {"name": "tool_a", "args": {"x": 3}}, + {"name": "tool_b", "args": {"y": 4}}, + ] + assert outputs.trajectory["observations_1"] == [ + {"tool": "tool_a", "result": "a:3"}, + {"tool": "tool_b", "result": "b:4"} + ] + + # Check finish iteration + assert outputs.trajectory["tool_calls_2"] == [{"name": "finish", "args": {}}] + + +def test_empty_tool_args(): + """Test parallel execution with tools that have no arguments.""" + def get_time() -> str: + return "12:00" + + def get_date() -> str: + return "2024-01-01" + + react = dspy.ReAct("query -> result", tools=[get_time, get_date]) + + lm = DummyLM( + [ + { + "next_thought": "I'll get both time and date.", + "next_tool_calls": [ + {"name": "get_time", "args": {}}, + {"name": "get_date", "args": {}}, + ], + }, + { + "next_thought": "Got both, finishing.", + "next_tool_calls": [{"name": "finish", "args": {}}], + }, + {"reasoning": "Success", "result": "12:00 on 2024-01-01"}, + ] + ) + dspy.settings.configure(lm=lm) + + outputs = react(query="what time is it?") + + observations = outputs.trajectory["observations_0"] + assert len(observations) == 2 + assert observations[0]["tool"] == "get_time" + assert observations[0]["result"] == "12:00" + assert observations[1]["tool"] == "get_date" + assert observations[1]["result"] == "2024-01-01"