235 changes: 24 additions & 211 deletions dapr_agents/agents/base.py
@@ -1,9 +1,9 @@
from dapr_agents.agents.memory_store import MemoryStore
from dapr_agents.agents.utils.text_printer import ColorTextFormatter
from dapr_agents.types import MessagePlaceHolder, BaseMessage, ToolExecutionRecord
from dapr_agents.types import BaseMessage, ToolExecutionRecord
from dapr_agents.tool.executor import AgentToolExecutor
from dapr_agents.prompt.base import PromptTemplateBase
from dapr_agents.prompt import ChatPromptTemplate
from dapr_agents.prompt.agent_prompt import Prompt
from dapr_agents.prompt.agent_prompt_context import Context
from dapr_agents.tool.base import AgentTool
import re
import json
@@ -19,8 +19,6 @@
Any,
Union,
Callable,
Literal,
ClassVar,
)
from pydantic import BaseModel, Field, PrivateAttr, model_validator, ConfigDict
from dapr_agents.llm.chat import ChatClientBase
@@ -59,17 +57,10 @@ class AgentBase(BaseModel, ABC):
instructions: Optional[List[str]] = Field(
default=None, description="Instructions guiding the agent's tasks."
)
system_prompt: Optional[str] = Field(
default=None,
description="A custom system prompt, overriding name, role, goal, and instructions.",
)
llm: Optional[ChatClientBase] = Field(
default=None,
description="Language model client for generating responses.",
)
prompt_template: Optional[PromptTemplateBase] = Field(
default=None, description="The prompt template for the agent."
)
# TODO: we need to add RBAC to tools to define what users and/or agents can use what tool(s).
tools: List[Union[AgentTool, Callable]] = Field(
default_factory=list,
@@ -86,13 +77,6 @@ class AgentBase(BaseModel, ABC):
max_iterations: int = Field(
default=10, description="Max iterations for conversation cycles."
)
# TODO: we should have a system_template, prompt_template, and response_template, or better separation here.
# If we have something like a customer service agent, we want diff templates for different types of interactions.
# In future, we could also have a way to dynamically change the template based on the context of the interaction.
template_format: Literal["f-string", "jinja2"] = Field(
default="jinja2",
description="The format used for rendering the prompt template.",
)
memory_store: Optional["MemoryStore"] = Field(
default=None,
description=(
@@ -101,29 +85,11 @@
"For persistent storage, specify the name of the Dapr State Store to use. "
),
)
registry_store: Optional[str] = Field(
default=None,
description="Agent registry store name for storing static agent information. Defaults to memory_store state store name if not provided.",
prompt: Optional[Prompt] = Field(
default_factory=Prompt,
description="Prompt handles how agent prompts (system messages and full chat context) are built and formatted before being sent to the LLM.",
)

DEFAULT_SYSTEM_PROMPT: ClassVar[str]
"""Default f-string template; placeholders will be swapped to Jinja if needed."""
DEFAULT_SYSTEM_PROMPT = """
# Today's date is: {date}

## Name
Your name is {name}.

## Role
Your role is {role}.

## Goal
{goal}.

## Instructions
{instructions}.
""".strip()

_tool_executor: AgentToolExecutor = PrivateAttr()
_text_formatter: ColorTextFormatter = PrivateAttr(
default_factory=ColorTextFormatter
@@ -173,6 +139,11 @@ def model_post_init(self, __context: Any) -> None:
# Initialize LLM if not provided
if self.llm is None:
self.llm = get_default_llm()
elif getattr(self.llm, "prompt_template", None):
# Agent owns the prompt; warn if LLM has its own template set
logger.warning(
"LLM prompt template is set, but agent will use its own prompt template."
)

# Initialize storage if not provided (in-memory by default)
if self.memory_store is None:
@@ -214,84 +185,21 @@ def model_post_init(self, __context: Any) -> None:
agent_metadata=self._serialize_metadata(agent_metadata),
)

# Centralize prompt template selection logic
self.prompt_template = self._initialize_prompt_template()
# Ensure LLM client and agent both reference the same template
if self.llm is not None:
self.llm.prompt_template = self.prompt_template

self._validate_prompt_template()
self.prefill_agent_attributes()
if self.prompt is None:
self.prompt = Prompt()
if getattr(self.prompt, "context", None) is None:
self.prompt.context = Context()
self.prompt.context.name = self.name
self.prompt.context.role = self.role
self.prompt.context.goal = self.goal
self.prompt.context.instructions = self.instructions

# Set up graceful shutdown
self._shutdown_event = asyncio.Event()
self._setup_signal_handlers()

super().model_post_init(__context)

def _initialize_prompt_template(self) -> PromptTemplateBase:
"""
Determines which prompt template to use for the agent:
1. If the user supplied one, use it.
2. Else if the LLM client already has one, adopt that.
3. Else generate a system_prompt and ChatPromptTemplate from agent attributes.

Returns:
PromptTemplateBase: The selected or constructed prompt template.
"""
# 1) User provided one?
if self.prompt_template:
logger.debug("🛠️ Using provided agent.prompt_template")
return self.prompt_template

# 2) LLM client has one?
if (
self.llm
and hasattr(self.llm, "prompt_template")
and self.llm.prompt_template
):
logger.debug("🔄 Syncing from llm.prompt_template")
return self.llm.prompt_template

# 3) Build from system_prompt or attributes
if not self.system_prompt:
logger.debug("⚙️ Constructing system_prompt from attributes")
self.system_prompt = self.construct_system_prompt()

logger.debug("⚙️ Building ChatPromptTemplate from system_prompt")
return self.construct_prompt_template()

def _collect_template_attrs(self) -> tuple[Dict[str, str], List[str]]:
"""
Collect agent attributes for prompt template pre-filling and warn about unused ones.
- valid: attributes set on self and declared in prompt_template.input_variables.
- unused: attributes set on self but not present in the template.
Returns:
(valid, unused): Tuple of dict of valid attrs and list of unused attr names.
"""
attrs = ["name", "role", "goal", "instructions"]
valid: Dict[str, str] = {}
unused: List[str] = []
if not self.prompt_template or not hasattr(
self.prompt_template, "input_variables"
):
return valid, attrs # No template, all attrs are unused
original = set(self.prompt_template.input_variables)

for attr in attrs:
val = getattr(self, attr, None)
if val is None:
continue
if attr in original:
# Only join instructions if it's a list and the template expects it
if attr == "instructions" and isinstance(val, list):
valid[attr] = "\n".join(val)
else:
valid[attr] = str(val)
else:
unused.append(attr)
return valid, unused

def _setup_signal_handlers(self):
"""Set up signal handlers for graceful shutdown"""
try:
@@ -306,28 +214,6 @@ def _signal_handler(self, signum, frame):
print(f"\nReceived signal {signum}. Shutting down gracefully...")
self._shutdown_event.set()

def _validate_prompt_template(self) -> None:
"""
Ensures chat_history is always available, injects any declared attributes,
and warns if the user set attributes that aren't in the template.
"""
if not self.prompt_template:
return

# Always make chat_history available
vars_set = set(self.prompt_template.input_variables) | {"chat_history"}

# Inject any attributes the template declares
valid_attrs, unused_attrs = self._collect_template_attrs()
vars_set |= set(valid_attrs.keys())
self.prompt_template.input_variables = list(vars_set)

if unused_attrs:
logger.warning(
"Agent attributes set but not referenced in prompt_template: "
f"{', '.join(unused_attrs)}. Consider adding them to input_variables."
)

@property
def tool_executor(self) -> AgentToolExecutor:
"""Returns the client to execute and manage tools, ensuring it's accessible but read-only."""
@@ -370,101 +256,28 @@ def run(self, input_data: Union[str, Dict[str, Any]]) -> Any:
"""
pass

def prefill_agent_attributes(self) -> None:
"""
Pre-fill prompt_template with agent attributes if specified in `input_variables`.
Uses _collect_template_attrs to avoid duplicate logic and ensure consistency.
"""
if not self.prompt_template:
return

# Re-use our helper to split valid vs. unused
valid_attrs, unused_attrs = self._collect_template_attrs()

if unused_attrs:
logger.warning(
"Agent attributes set but not used in prompt_template: "
f"{', '.join(unused_attrs)}. Consider adding them to input_variables."
)

if valid_attrs:
self.prompt_template = self.prompt_template.pre_fill_variables(
**valid_attrs
)
logger.debug(f"Pre-filled template with: {list(valid_attrs.keys())}")
else:
logger.debug("No prompt_template variables needed pre-filling.")

def construct_system_prompt(self) -> str:
"""
Build the system prompt for the agent using a single template string.
- Fills in the current date.
- Leaves placeholders for name, role, goal, and instructions as variables (instructions only if set).
- Converts placeholders to Jinja2 syntax if requested.

Returns:
str: The formatted system prompt string.
"""
# Only fill in the date; leave all other placeholders as variables
instructions_placeholder = "{instructions}" if self.instructions else ""
filled = self.DEFAULT_SYSTEM_PROMPT.format(
date=datetime.now().strftime("%B %d, %Y"),
name="{name}",
role="{role}",
goal="{goal}",
instructions=instructions_placeholder,
)

# If using Jinja2, swap braces for all placeholders
if self.template_format == "jinja2":
# Replace every {foo} with {{foo}}
return re.sub(r"\{(\w+)\}", r"{{\1}}", filled)
else:
return filled

def construct_prompt_template(self) -> ChatPromptTemplate:
"""
Constructs a ChatPromptTemplate that includes the system prompt and a placeholder for chat history.
Ensures that the template is flexible and adaptable to dynamically handle pre-filled variables.

Returns:
ChatPromptTemplate: A formatted prompt template for the agent.
"""
# Construct the system prompt if not provided
system_prompt = self.system_prompt or self.construct_system_prompt()

# Create the template with placeholders for system message and chat history
return ChatPromptTemplate.from_messages(
messages=[
("system", system_prompt),
MessagePlaceHolder(variable_name="chat_history"),
],
template_format=self.template_format,
)

def construct_messages(
self, input_data: Union[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Constructs and formats initial messages based on input type, passing chat_history as a list, without mutating self.prompt_template.
Constructs and formats initial messages based on input type, passing chat_history as a list, without mutating the prompt template.

Args:
input_data (Union[str, Dict[str, Any]]): User input, either as a string or dictionary.

Returns:
List[Dict[str, Any]]: List of formatted messages, including the user message if input_data is a string.
"""
if not self.prompt_template:
has_prompt = bool(self.prompt and getattr(self.prompt, "template", None))
if not has_prompt:
raise ValueError(
"Prompt template must be initialized before constructing messages."
)

chat_history = self.get_chat_history() # List[Dict[str, Any]]

if isinstance(input_data, str):
formatted_messages = self.prompt_template.format_prompt(
chat_history=chat_history
)
formatted_messages = self.prompt.template.format(chat_history=chat_history)
if isinstance(formatted_messages, list):
user_message = {"role": "user", "content": input_data}
return formatted_messages + [user_message]
@@ -478,7 +291,7 @@ def construct_messages(
input_vars = dict(input_data)
if "chat_history" not in input_vars:
input_vars["chat_history"] = chat_history
formatted_messages = self.prompt_template.format_prompt(**input_vars)
formatted_messages = self.prompt.template.format(**input_vars)
if isinstance(formatted_messages, list):
return formatted_messages
else:
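The base-agent refactor above replaces the `system_prompt`, `prompt_template`, and `template_format` fields (and their `construct_system_prompt`/`prefill_agent_attributes` plumbing) with a single `prompt: Prompt` field whose `Context` is populated from the agent's attributes in `model_post_init`. A minimal usage sketch, assuming a concrete `Agent` subclass of `AgentBase`; the `Prompt`/`Context` attribute names come from this diff, the rest is illustrative:

```python
from dapr_agents import Agent  # assumed concrete subclass of AgentBase

agent = Agent(
    name="WeatherBot",
    role="Weather Assistant",
    goal="Answer weather questions concisely.",
    instructions=["Always include the temperature unit."],
)

# model_post_init() now mirrors these attributes onto the prompt context,
# replacing the old system_prompt/prompt_template pre-filling:
print(agent.prompt.context.name)  # "WeatherBot"
print(agent.prompt.context.goal)  # "Answer weather questions concisely."

# construct_messages() delegates to prompt.template.format(...) and raises
# ValueError if prompt.template has not been initialized.
messages = agent.construct_messages("What's the weather in Berlin?")
```

Note that the agent now owns the prompt: if the LLM client carries its own `prompt_template`, `model_post_init` only logs a warning instead of syncing the two templates as it did before.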
8 changes: 3 additions & 5 deletions dapr_agents/agents/durableagent/agent.py
@@ -938,7 +938,7 @@ def _construct_messages_with_instance_history(
List of formatted messages with proper sequence
"""
additional_context_messages: List[Dict[str, Any]] = []
if not self.prompt_template:
if not self.prompt.template:
raise ValueError(
"Prompt template must be initialized before constructing messages."
)
@@ -1015,9 +1015,7 @@ def _construct_messages_with_instance_history(
chat_history.extend(additional_context_messages)

if isinstance(input_data, str):
formatted_messages = self.prompt_template.format_prompt(
chat_history=chat_history
)
formatted_messages = self.prompt.template.format(chat_history=chat_history)
if isinstance(formatted_messages, list):
user_message = {"role": "user", "content": input_data}
return formatted_messages + [user_message]
@@ -1030,7 +1028,7 @@ def _construct_messages_with_instance_history(
input_vars = dict(input_data)
if "chat_history" not in input_vars:
input_vars["chat_history"] = chat_history
formatted_messages = self.prompt_template.format_prompt(**input_vars)
formatted_messages = self.prompt.template.format(**input_vars)
if isinstance(formatted_messages, list):
return formatted_messages
else:
2 changes: 1 addition & 1 deletion dapr_agents/llm/dapr/chat.py
@@ -311,7 +311,7 @@ def generate(
if input_data:
if not self.prompt_template:
raise ValueError("input_data provided but no prompt_template is set.")
messages = self.prompt_template.format_prompt(**input_data)
messages = self.prompt_template.format(**input_data)

if not messages:
raise ValueError("Either 'messages' or 'input_data' must be provided.")
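The change to `dapr_agents/llm/dapr/chat.py` is a one-line rename of the template call from `format_prompt` to `format`; the same substitution repeats in the HuggingFace and NVIDIA chat clients below. A hedged before/after sketch (the `ChatPromptTemplate` construction mirrors the one removed from `AgentBase` and is illustrative only):

```python
from dapr_agents.prompt import ChatPromptTemplate

template = ChatPromptTemplate.from_messages(
    messages=[("system", "You are {role}.")],
    template_format="f-string",
)

# Before this PR:
#   messages = template.format_prompt(role="a helpful assistant")
# After this PR:
messages = template.format(role="a helpful assistant")
```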
2 changes: 1 addition & 1 deletion dapr_agents/llm/huggingface/chat.py
@@ -150,7 +150,7 @@ def generate(
if not self.prompt_template:
raise ValueError("No prompt_template set for input_data usage.")
logger.info("Formatting messages via prompt_template.")
messages = self.prompt_template.format_prompt(**input_data)
messages = self.prompt_template.format(**input_data)

if not messages:
raise ValueError("Either messages or input_data must be provided.")
2 changes: 1 addition & 1 deletion dapr_agents/llm/nvidia/chat.py
@@ -154,7 +154,7 @@ def generate(
if not self.prompt_template:
raise ValueError("input_data provided but no prompt_template is set.")
logger.info("Formatting messages via prompt_template.")
messages = self.prompt_template.format_prompt(**input_data)
messages = self.prompt_template.format(**input_data)

if not messages:
raise ValueError("Either 'messages' or 'input_data' must be provided.")