Skip to content

Commit 148f281

Browse files
committed
move conversion functions to a utility module + rename _get_agent_card function + reduce use of Any type
1 parent 4786b2d commit 148f281

File tree

4 files changed

+243
-220
lines changed

4 files changed

+243
-220
lines changed

src/strands/agent/a2a_agent.py

Lines changed: 8 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,15 @@
55
"""
66

77
import logging
8-
from typing import Any, AsyncIterator, cast
9-
from uuid import uuid4
8+
from typing import Any, AsyncIterator
109

1110
import httpx
1211
from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
13-
from a2a.types import AgentCard, Part, Role, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart
14-
from a2a.types import Message as A2AMessage
12+
from a2a.types import AgentCard
1513

1614
from .._async import run_async
17-
from ..telemetry.metrics import EventLoopMetrics
15+
from ..multiagent.a2a.converters import convert_input_to_message, convert_response_to_agent_result
1816
from ..types.agent import AgentInput
19-
from ..types.content import ContentBlock, Message
2017
from .agent_result import AgentResult
2118

2219
logger = logging.getLogger(__name__)
@@ -79,7 +76,7 @@ def _get_client_factory(self, streaming: bool = False) -> ClientFactory:
7976
)
8077
return ClientFactory(config)
8178

82-
async def _discover_agent_card(self) -> AgentCard:
79+
async def _get_agent_card(self) -> AgentCard:
8380
"""Discover and cache the agent card from the remote endpoint.
8481
8582
Returns:
@@ -94,109 +91,7 @@ async def _discover_agent_card(self) -> AgentCard:
9491
logger.info("endpoint=<%s> | discovered agent card", self.endpoint)
9592
return self._agent_card
9693

97-
def _convert_input_to_message(self, prompt: AgentInput) -> A2AMessage:
98-
"""Convert AgentInput to A2A Message.
99-
100-
Args:
101-
prompt: Input in various formats (string, message list, or content blocks).
102-
103-
Returns:
104-
A2AMessage ready to send to the remote agent.
105-
106-
Raises:
107-
ValueError: If prompt format is unsupported.
108-
"""
109-
message_id = uuid4().hex
110-
111-
if isinstance(prompt, str):
112-
return A2AMessage(
113-
kind="message",
114-
role=Role.user,
115-
parts=[Part(TextPart(kind="text", text=prompt))],
116-
message_id=message_id,
117-
)
118-
119-
if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)):
120-
if "role" in prompt[0]:
121-
# Message list - extract last user message
122-
for msg in reversed(prompt):
123-
if msg.get("role") == "user":
124-
content = cast(list[ContentBlock], msg.get("content", []))
125-
parts = self._convert_content_blocks_to_parts(content)
126-
return A2AMessage(
127-
kind="message",
128-
role=Role.user,
129-
parts=parts,
130-
message_id=message_id,
131-
)
132-
else:
133-
# ContentBlock list
134-
parts = self._convert_content_blocks_to_parts(cast(list[ContentBlock], prompt))
135-
return A2AMessage(
136-
kind="message",
137-
role=Role.user,
138-
parts=parts,
139-
message_id=message_id,
140-
)
141-
142-
raise ValueError(f"Unsupported input type: {type(prompt)}")
143-
144-
def _convert_content_blocks_to_parts(self, content_blocks: list[ContentBlock]) -> list[Part]:
145-
"""Convert Strands ContentBlocks to A2A Parts.
146-
147-
Args:
148-
content_blocks: List of Strands content blocks.
149-
150-
Returns:
151-
List of A2A Part objects.
152-
"""
153-
parts = []
154-
for block in content_blocks:
155-
if "text" in block:
156-
parts.append(Part(TextPart(kind="text", text=block["text"])))
157-
return parts
158-
159-
def _convert_response_to_agent_result(self, response: Any) -> AgentResult:
160-
"""Convert A2A response to AgentResult.
161-
162-
Args:
163-
response: A2A response (either A2AMessage or tuple of task and update event).
164-
165-
Returns:
166-
AgentResult with extracted content and metadata.
167-
"""
168-
content: list[ContentBlock] = []
169-
170-
if isinstance(response, tuple) and len(response) == 2:
171-
task, update_event = response
172-
if update_event is None and task and hasattr(task, "artifacts"):
173-
# Non-streaming response: extract from task artifacts
174-
for artifact in task.artifacts:
175-
if hasattr(artifact, "parts"):
176-
for part in artifact.parts:
177-
if hasattr(part, "root") and hasattr(part.root, "text"):
178-
content.append({"text": part.root.text})
179-
elif isinstance(response, A2AMessage):
180-
# Direct message response
181-
for part in response.parts:
182-
if hasattr(part, "root") and hasattr(part.root, "text"):
183-
content.append({"text": part.root.text})
184-
185-
message: Message = {
186-
"role": "assistant",
187-
"content": content,
188-
}
189-
190-
return AgentResult(
191-
stop_reason="end_turn",
192-
message=message,
193-
metrics=EventLoopMetrics(),
194-
state={},
195-
)
196-
197-
async def _send_message(
198-
self, prompt: AgentInput, streaming: bool
199-
) -> AsyncIterator[tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | A2AMessage]:
94+
async def _send_message(self, prompt: AgentInput, streaming: bool) -> AsyncIterator[Any]:
20095
"""Send message to A2A agent.
20196
20297
Args:
@@ -212,9 +107,9 @@ async def _send_message(
212107
if prompt is None:
213108
raise ValueError("prompt is required for A2AAgent")
214109

215-
agent_card = await self._discover_agent_card()
110+
agent_card = await self._get_agent_card()
216111
client = self._get_client_factory(streaming=streaming).create(agent_card)
217-
message = self._convert_input_to_message(prompt)
112+
message = convert_input_to_message(prompt)
218113

219114
logger.info("endpoint=<%s> | %s message", self.endpoint, "streaming" if streaming else "sending")
220115
return client.send_message(message)
@@ -238,7 +133,7 @@ async def invoke_async(
238133
RuntimeError: If no response received from agent.
239134
"""
240135
async for event in await self._send_message(prompt, streaming=False):
241-
return self._convert_response_to_agent_result(event)
136+
return convert_response_to_agent_result(event)
242137

243138
raise RuntimeError("No response received from A2A agent")
244139

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""Conversion functions between Strands and A2A types."""
2+
3+
from typing import TypeAlias, cast
4+
from uuid import uuid4
5+
6+
from a2a.types import Message as A2AMessage
7+
from a2a.types import Part, Role, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart
8+
9+
from ...agent.agent_result import AgentResult
10+
from ...telemetry.metrics import EventLoopMetrics
11+
from ...types.agent import AgentInput
12+
from ...types.content import ContentBlock, Message
13+
14+
A2AResponse: TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | A2AMessage
15+
16+
17+
def convert_input_to_message(prompt: AgentInput) -> A2AMessage:
18+
"""Convert AgentInput to A2A Message.
19+
20+
Args:
21+
prompt: Input in various formats (string, message list, or content blocks).
22+
23+
Returns:
24+
A2AMessage ready to send to the remote agent.
25+
26+
Raises:
27+
ValueError: If prompt format is unsupported.
28+
"""
29+
message_id = uuid4().hex
30+
31+
if isinstance(prompt, str):
32+
return A2AMessage(
33+
kind="message",
34+
role=Role.user,
35+
parts=[Part(TextPart(kind="text", text=prompt))],
36+
message_id=message_id,
37+
)
38+
39+
if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)):
40+
if "role" in prompt[0]:
41+
for msg in reversed(prompt):
42+
if msg.get("role") == "user":
43+
content = cast(list[ContentBlock], msg.get("content", []))
44+
parts = convert_content_blocks_to_parts(content)
45+
return A2AMessage(
46+
kind="message",
47+
role=Role.user,
48+
parts=parts,
49+
message_id=message_id,
50+
)
51+
else:
52+
parts = convert_content_blocks_to_parts(cast(list[ContentBlock], prompt))
53+
return A2AMessage(
54+
kind="message",
55+
role=Role.user,
56+
parts=parts,
57+
message_id=message_id,
58+
)
59+
60+
raise ValueError(f"Unsupported input type: {type(prompt)}")
61+
62+
63+
def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[Part]:
64+
"""Convert Strands ContentBlocks to A2A Parts.
65+
66+
Args:
67+
content_blocks: List of Strands content blocks.
68+
69+
Returns:
70+
List of A2A Part objects.
71+
"""
72+
parts = []
73+
for block in content_blocks:
74+
if "text" in block:
75+
parts.append(Part(TextPart(kind="text", text=block["text"])))
76+
return parts
77+
78+
79+
def convert_response_to_agent_result(response: A2AResponse) -> AgentResult:
80+
"""Convert A2A response to AgentResult.
81+
82+
Args:
83+
response: A2A response (either A2AMessage or tuple of task and update event).
84+
85+
Returns:
86+
AgentResult with extracted content and metadata.
87+
"""
88+
content: list[ContentBlock] = []
89+
90+
if isinstance(response, tuple) and len(response) == 2:
91+
task, update_event = response
92+
if update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None:
93+
for artifact in task.artifacts:
94+
if hasattr(artifact, "parts"):
95+
for part in artifact.parts:
96+
if hasattr(part, "root") and hasattr(part.root, "text"):
97+
content.append({"text": part.root.text})
98+
elif isinstance(response, A2AMessage):
99+
for part in response.parts:
100+
if hasattr(part, "root") and hasattr(part.root, "text"):
101+
content.append({"text": part.root.text})
102+
103+
message: Message = {
104+
"role": "assistant",
105+
"content": content,
106+
}
107+
108+
return AgentResult(
109+
stop_reason="end_turn",
110+
message=message,
111+
metrics=EventLoopMetrics(),
112+
state={},
113+
)

0 commit comments

Comments
 (0)