Skip to content

Commit ecc109f

Browse files
committed
change AgentBase from abstract class to protocol + remove invocation_state and structured_output_model
1 parent b12d327 commit ecc109f

File tree

6 files changed

+49
-91
lines changed

6 files changed

+49
-91
lines changed

src/strands/agent/agent.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
from ..types.tools import ToolResult, ToolUse
6565
from ..types.traces import AttributeValue
6666
from .agent_result import AgentResult
67-
from .base import AgentBase
6867
from .conversation_manager import (
6968
ConversationManager,
7069
SlidingWindowConversationManager,
@@ -89,7 +88,7 @@ class _DefaultCallbackHandlerSentinel:
8988
_DEFAULT_AGENT_ID = "default"
9089

9190

92-
class Agent(AgentBase):
91+
class Agent:
9392
"""Core Agent implementation.
9493
9594
An agent orchestrates the following workflow:

src/strands/agent/base.py

Lines changed: 17 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,115 +3,74 @@
33
Defines the minimal interface that all agent types must implement.
44
"""
55

6-
from abc import ABC, abstractmethod
7-
from typing import Any, AsyncIterator, Type
8-
9-
from pydantic import BaseModel
6+
from typing import Any, AsyncIterator, Protocol, runtime_checkable
107

118
from ..types.agent import AgentInput
129
from .agent_result import AgentResult
1310
from .state import AgentState
1411

1512

16-
class AgentBase(ABC):
17-
"""Abstract interface for all agent types in Strands.
13+
@runtime_checkable
14+
class AgentBase(Protocol):
15+
"""Protocol defining the interface for all agent types in Strands.
1816
19-
This interface defines the minimal contract that all agent implementations
17+
This protocol defines the minimal contract that all agent implementations
2018
must satisfy.
2119
"""
2220

23-
@property
24-
@abstractmethod
25-
def agent_id(self) -> str:
26-
"""Unique identifier for the agent.
27-
28-
Returns:
29-
Unique string identifier for this agent instance.
30-
"""
31-
pass
32-
33-
@property
34-
@abstractmethod
35-
def name(self) -> str:
36-
"""Human-readable name of the agent.
21+
agent_id: str
22+
"""Unique identifier for the agent."""
3723

38-
Returns:
39-
Display name for the agent.
40-
"""
41-
pass
42-
43-
@property
44-
@abstractmethod
45-
def state(self) -> AgentState:
46-
"""Current state of the agent.
24+
name: str
25+
"""Human-readable name of the agent."""
4726

48-
Returns:
49-
AgentState object containing stateful information.
50-
"""
51-
pass
27+
state: AgentState
28+
"""Current state of the agent."""
5229

53-
@abstractmethod
5430
async def invoke_async(
5531
self,
5632
prompt: AgentInput = None,
57-
*,
58-
invocation_state: dict[str, Any] | None = None,
59-
structured_output_model: Type[BaseModel] | None = None,
6033
**kwargs: Any,
6134
) -> AgentResult:
6235
"""Asynchronously invoke the agent with the given prompt.
6336
6437
Args:
6538
prompt: Input to the agent.
66-
invocation_state: Optional state to pass to the agent invocation.
67-
structured_output_model: Optional Pydantic model for structured output.
68-
**kwargs: Additional provider-specific arguments.
39+
**kwargs: Additional arguments.
6940
7041
Returns:
7142
AgentResult containing the agent's response.
7243
"""
73-
pass
44+
...
7445

75-
@abstractmethod
7646
def __call__(
7747
self,
7848
prompt: AgentInput = None,
79-
*,
80-
invocation_state: dict[str, Any] | None = None,
81-
structured_output_model: Type[BaseModel] | None = None,
8249
**kwargs: Any,
8350
) -> AgentResult:
8451
"""Synchronously invoke the agent with the given prompt.
8552
8653
Args:
8754
prompt: Input to the agent.
88-
invocation_state: Optional state to pass to the agent invocation.
89-
structured_output_model: Optional Pydantic model for structured output.
90-
**kwargs: Additional provider-specific arguments.
55+
**kwargs: Additional arguments.
9156
9257
Returns:
9358
AgentResult containing the agent's response.
9459
"""
95-
pass
60+
...
9661

97-
@abstractmethod
9862
def stream_async(
9963
self,
10064
prompt: AgentInput = None,
101-
*,
102-
invocation_state: dict[str, Any] | None = None,
103-
structured_output_model: Type[BaseModel] | None = None,
10465
**kwargs: Any,
10566
) -> AsyncIterator[Any]:
10667
"""Stream agent execution asynchronously.
10768
10869
Args:
10970
prompt: Input to the agent.
110-
invocation_state: Optional state to pass to the agent invocation.
111-
structured_output_model: Optional Pydantic model for structured output.
112-
**kwargs: Additional provider-specific arguments.
71+
**kwargs: Additional arguments.
11372
11473
Yields:
11574
Events representing the streaming execution.
11675
"""
117-
pass
76+
...

src/strands/multiagent/graph.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from opentelemetry import trace as trace_api
2525

2626
from .._async import run_async
27-
from ..agent import Agent
27+
from ..agent import Agent, AgentBase
2828
from ..agent.state import AgentState
2929
from ..experimental.hooks.multiagent import (
3030
AfterMultiAgentInvocationEvent,
@@ -154,7 +154,7 @@ class GraphNode:
154154
"""
155155

156156
node_id: str
157-
executor: Agent | MultiAgentBase
157+
executor: AgentBase | MultiAgentBase
158158
dependencies: set["GraphNode"] = field(default_factory=set)
159159
execution_status: Status = Status.PENDING
160160
result: NodeResult | None = None
@@ -199,7 +199,7 @@ def __eq__(self, other: Any) -> bool:
199199

200200

201201
def _validate_node_executor(
202-
executor: Agent | MultiAgentBase, existing_nodes: dict[str, GraphNode] | None = None
202+
executor: AgentBase | MultiAgentBase, existing_nodes: dict[str, GraphNode] | None = None
203203
) -> None:
204204
"""Validate a node executor for graph compatibility.
205205
@@ -238,8 +238,8 @@ def __init__(self) -> None:
238238
self._session_manager: Optional[SessionManager] = None
239239
self._hooks: Optional[list[HookProvider]] = None
240240

241-
def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
242-
"""Add an Agent or MultiAgentBase instance as a node to the graph."""
241+
def add_node(self, executor: AgentBase | MultiAgentBase, node_id: str | None = None) -> GraphNode:
242+
"""Add an AgentBase or MultiAgentBase instance as a node to the graph."""
243243
_validate_node_executor(executor, self.nodes)
244244

245245
# Auto-generate node_id if not provided

src/strands/multiagent/swarm.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from opentelemetry import trace as trace_api
2424

2525
from .._async import run_async
26-
from ..agent import Agent
26+
from ..agent import Agent, AgentBase
2727
from ..agent.state import AgentState
2828
from ..experimental.hooks.multiagent import (
2929
AfterMultiAgentInvocationEvent,
@@ -57,7 +57,7 @@ class SwarmNode:
5757
"""Represents a node (e.g. Agent) in the swarm."""
5858

5959
node_id: str
60-
executor: Agent
60+
executor: AgentBase
6161
_initial_messages: Messages = field(default_factory=list, init=False)
6262
_initial_state: AgentState = field(default_factory=AgentState, init=False)
6363

@@ -212,9 +212,9 @@ class Swarm(MultiAgentBase):
212212

213213
def __init__(
214214
self,
215-
nodes: list[Agent],
215+
nodes: list[AgentBase],
216216
*,
217-
entry_point: Agent | None = None,
217+
entry_point: AgentBase | None = None,
218218
max_handoffs: int = 20,
219219
max_iterations: int = 20,
220220
execution_timeout: float = 900.0,
@@ -229,8 +229,8 @@ def __init__(
229229
230230
Args:
231231
id : Unique swarm id (default: None)
232-
nodes: List of nodes (e.g. Agent) to include in the swarm
233-
entry_point: Agent to start with. If None, uses the first agent (default: None)
232+
nodes: List of nodes (e.g. AgentBase) to include in the swarm
233+
entry_point: AgentBase to start with. If None, uses the first agent (default: None)
234234
max_handoffs: Maximum handoffs to agents and users (default: 20)
235235
max_iterations: Maximum node executions within the swarm (default: 20)
236236
execution_timeout: Total execution timeout in seconds (default: 900.0)
@@ -425,7 +425,7 @@ async def _stream_with_timeout(
425425
except asyncio.TimeoutError as err:
426426
raise Exception(timeout_message) from err
427427

428-
def _setup_swarm(self, nodes: list[Agent]) -> None:
428+
def _setup_swarm(self, nodes: list[AgentBase]) -> None:
429429
"""Initialize swarm configuration."""
430430
# Validate nodes before setup
431431
self._validate_swarm(nodes)
@@ -467,7 +467,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:
467467
first_node = next(iter(self.nodes.keys()))
468468
logger.debug("entry_point=<%s> | using first node as entry point", first_node)
469469

470-
def _validate_swarm(self, nodes: list[Agent]) -> None:
470+
def _validate_swarm(self, nodes: list[AgentBase]) -> None:
471471
"""Validate swarm structure and nodes."""
472472
# Check for duplicate object instances
473473
seen_instances = set()

src/strands/session/repository_session_manager.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .session_repository import SessionRepository
1818

1919
if TYPE_CHECKING:
20-
from ..agent.agent import Agent
20+
from ..agent.base import AgentBase
2121
from ..multiagent.base import MultiAgentBase
2222

2323
logger = logging.getLogger(__name__)
@@ -58,12 +58,12 @@ def __init__(
5858
# Keep track of the latest message of each agent in case we need to redact it.
5959
self._latest_agent_message: dict[str, Optional[SessionMessage]] = {}
6060

61-
def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None:
61+
def append_message(self, message: Message, agent: "AgentBase", **kwargs: Any) -> None:
6262
"""Append a message to the agent's session.
6363
6464
Args:
6565
message: Message to add to the agent in the session
66-
agent: Agent to append the message to
66+
agent: AgentBase to append the message to
6767
**kwargs: Additional keyword arguments for future extensibility.
6868
"""
6969
# Calculate the next index (0 if this is the first message, otherwise increment the previous index)
@@ -77,12 +77,12 @@ def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> Non
7777
self._latest_agent_message[agent.agent_id] = session_message
7878
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
7979

80-
def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None:
80+
def redact_latest_message(self, redact_message: Message, agent: "AgentBase", **kwargs: Any) -> None:
8181
"""Redact the latest message appended to the session.
8282
8383
Args:
8484
redact_message: New message to use that contains the redact content
85-
agent: Agent to apply the message redaction to
85+
agent: AgentBase to apply the message redaction to
8686
**kwargs: Additional keyword arguments for future extensibility.
8787
"""
8888
latest_agent_message = self._latest_agent_message[agent.agent_id]
@@ -91,23 +91,23 @@ def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwarg
9191
latest_agent_message.redact_message = redact_message
9292
return self.session_repository.update_message(self.session_id, agent.agent_id, latest_agent_message)
9393

94-
def sync_agent(self, agent: "Agent", **kwargs: Any) -> None:
94+
def sync_agent(self, agent: "AgentBase", **kwargs: Any) -> None:
9595
"""Serialize and update the agent into the session repository.
9696
9797
Args:
98-
agent: Agent to sync to the session.
98+
agent: AgentBase to sync to the session.
9999
**kwargs: Additional keyword arguments for future extensibility.
100100
"""
101101
self.session_repository.update_agent(
102102
self.session_id,
103103
SessionAgent.from_agent(agent),
104104
)
105105

106-
def initialize(self, agent: "Agent", **kwargs: Any) -> None:
106+
def initialize(self, agent: "AgentBase", **kwargs: Any) -> None:
107107
"""Initialize an agent with a session.
108108
109109
Args:
110-
agent: Agent to initialize from the session
110+
agent: AgentBase to initialize from the session
111111
**kwargs: Additional keyword arguments for future extensibility.
112112
"""
113113
if agent.agent_id in self._latest_agent_message:

src/strands/session/session_manager.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ..types.content import Message
1515

1616
if TYPE_CHECKING:
17-
from ..agent.agent import Agent
17+
from ..agent.base import AgentBase
1818
from ..multiagent.base import MultiAgentBase
1919

2020
logger = logging.getLogger(__name__)
@@ -48,40 +48,40 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
4848
registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source))
4949

5050
@abstractmethod
51-
def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None:
51+
def redact_latest_message(self, redact_message: Message, agent: "AgentBase", **kwargs: Any) -> None:
5252
"""Redact the message most recently appended to the agent in the session.
5353
5454
Args:
5555
redact_message: New message to use that contains the redact content
56-
agent: Agent to apply the message redaction to
56+
agent: AgentBase to apply the message redaction to
5757
**kwargs: Additional keyword arguments for future extensibility.
5858
"""
5959

6060
@abstractmethod
61-
def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None:
61+
def append_message(self, message: Message, agent: "AgentBase", **kwargs: Any) -> None:
6262
"""Append a message to the agent's session.
6363
6464
Args:
6565
message: Message to add to the agent in the session
66-
agent: Agent to append the message to
66+
agent: AgentBase to append the message to
6767
**kwargs: Additional keyword arguments for future extensibility.
6868
"""
6969

7070
@abstractmethod
71-
def sync_agent(self, agent: "Agent", **kwargs: Any) -> None:
71+
def sync_agent(self, agent: "AgentBase", **kwargs: Any) -> None:
7272
"""Serialize and sync the agent with the session storage.
7373
7474
Args:
75-
agent: Agent who should be synchronized with the session storage
75+
agent: AgentBase who should be synchronized with the session storage
7676
**kwargs: Additional keyword arguments for future extensibility.
7777
"""
7878

7979
@abstractmethod
80-
def initialize(self, agent: "Agent", **kwargs: Any) -> None:
80+
def initialize(self, agent: "AgentBase", **kwargs: Any) -> None:
8181
"""Initialize an agent with a session.
8282
8383
Args:
84-
agent: Agent to initialize
84+
agent: AgentBase to initialize
8585
**kwargs: Additional keyword arguments for future extensibility.
8686
"""
8787

0 commit comments

Comments
 (0)