Skip to content

Commit 651cf1f

Browse files
committed
fix: enhance agent state management during resume, ensuring correct agent usage and saving tool outputs to session
1 parent d67bea9 commit 651cf1f

File tree

1 file changed

+76
-2
lines changed

1 file changed

+76
-2
lines changed

src/agents/run.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,12 @@ async def run(
828828
tool_output_guardrail_results: list[ToolOutputGuardrailResult] = []
829829

830830
current_span: Span[AgentSpanData] | None = None
831-
current_agent = starting_agent
831+
# When resuming from state, use the current agent from the state (which may be different
832+
# from starting_agent if a handoff occurred). Otherwise use starting_agent.
833+
if is_resumed_state and run_state is not None and run_state._current_agent is not None:
834+
current_agent = run_state._current_agent
835+
else:
836+
current_agent = starting_agent
832837
should_run_agent_start_hooks = True
833838

834839
# save only the new user input to the session, not the combined history
@@ -1435,7 +1440,12 @@ async def _start_streaming(
14351440
streamed_result.trace.start(mark_as_current=True)
14361441

14371442
current_span: Span[AgentSpanData] | None = None
1438-
current_agent = starting_agent
1443+
# When resuming from state, use the current agent from the state (which may be different
1444+
# from starting_agent if a handoff occurred). Otherwise use starting_agent.
1445+
if run_state is not None and run_state._current_agent is not None:
1446+
current_agent = run_state._current_agent
1447+
else:
1448+
current_agent = starting_agent
14391449
current_turn = 0
14401450
should_run_agent_start_hooks = True
14411451
tool_use_tracker = AgentToolUseTracker()
@@ -1498,6 +1508,70 @@ async def _start_streaming(
14981508
run_config=run_config,
14991509
hooks=hooks,
15001510
)
1511+
# Save tool outputs to session immediately after approval
1512+
# This ensures incomplete function calls in the session are completed
1513+
if session is not None and streamed_result.new_items:
1514+
# Save tool_call_output_item items (the outputs)
1515+
tool_output_items: list[RunItem] = [
1516+
item
1517+
for item in streamed_result.new_items
1518+
if item.type == "tool_call_output_item"
1519+
]
1520+
# Also find and save the corresponding function_call items
1521+
# (they might not be in session if the run was interrupted before saving)
1522+
output_call_ids = {
1523+
item.raw_item.get("call_id")
1524+
if isinstance(item.raw_item, dict)
1525+
else getattr(item.raw_item, "call_id", None)
1526+
for item in tool_output_items
1527+
}
1528+
tool_call_items: list[RunItem] = [
1529+
item
1530+
for item in streamed_result.new_items
1531+
if item.type == "tool_call_item"
1532+
and (
1533+
item.raw_item.get("call_id")
1534+
if isinstance(item.raw_item, dict)
1535+
else getattr(item.raw_item, "call_id", None)
1536+
)
1537+
in output_call_ids
1538+
]
1539+
# Check which items are already in the session to avoid duplicates
1540+
# Get existing items from session and extract their call_ids
1541+
existing_items = await session.get_items()
1542+
existing_call_ids: set[str] = set()
1543+
for existing_item in existing_items:
1544+
if isinstance(existing_item, dict):
1545+
item_type = existing_item.get("type")
1546+
if item_type in ("function_call", "function_call_output"):
1547+
existing_call_id = existing_item.get(
1548+
"call_id"
1549+
) or existing_item.get("callId")
1550+
if existing_call_id and isinstance(existing_call_id, str):
1551+
existing_call_ids.add(existing_call_id)
1552+
1553+
# Filter out items that are already in the session
1554+
items_to_save: list[RunItem] = []
1555+
for item in tool_call_items + tool_output_items:
1556+
item_call_id: str | None = None
1557+
if isinstance(item.raw_item, dict):
1558+
raw_call_id = item.raw_item.get("call_id") or item.raw_item.get(
1559+
"callId"
1560+
)
1561+
item_call_id = (
1562+
cast(str | None, raw_call_id) if raw_call_id else None
1563+
)
1564+
elif hasattr(item.raw_item, "call_id"):
1565+
item_call_id = cast(
1566+
str | None, getattr(item.raw_item, "call_id", None)
1567+
)
1568+
1569+
# Only save if not already in session
1570+
if item_call_id is None or item_call_id not in existing_call_ids:
1571+
items_to_save.append(item)
1572+
1573+
if items_to_save:
1574+
await AgentRunner._save_result_to_session(session, [], items_to_save)
15011575
# Clear the current step since we've handled it
15021576
run_state._current_step = None
15031577

0 commit comments

Comments
 (0)