diff --git a/veadk/agent.py b/veadk/agent.py index b3f14365..cfa20260 100644 --- a/veadk/agent.py +++ b/veadk/agent.py @@ -15,7 +15,7 @@ from __future__ import annotations import os -from typing import Optional, Union, AsyncGenerator +from typing import AsyncGenerator, Optional, Union # If user didn't set LITELLM_LOCAL_MODEL_COST_MAP, set it to True # to enable local model cost map. @@ -24,12 +24,15 @@ if not os.getenv("LITELLM_LOCAL_MODEL_COST_MAP"): os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" -from google.adk.agents import LlmAgent, RunConfig, InvocationContext +from google.adk.agents import InvocationContext, LlmAgent, RunConfig from google.adk.agents.base_agent import BaseAgent from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.agents.llm_agent import InstructionProvider, ToolUnion from google.adk.agents.run_config import StreamingMode from google.adk.events import Event, EventActions +from google.adk.flows.llm_flows.auto_flow import AutoFlow +from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow +from google.adk.flows.llm_flows.single_flow import SingleFlow from google.adk.models.lite_llm import LiteLlm from google.adk.runners import Runner from google.genai import types @@ -53,8 +56,8 @@ from veadk.prompts.prompt_manager import BasePromptManager from veadk.tracing.base_tracer import BaseTracer from veadk.utils.logger import get_logger -from veadk.utils.patches import patch_asyncio, patch_tracer from veadk.utils.misc import check_litellm_version +from veadk.utils.patches import patch_asyncio, patch_tracer from veadk.version import VERSION patch_tracer() @@ -118,6 +121,8 @@ class Agent(LlmAgent): enable_responses: bool = False + enable_supervisor_flow: bool = False + context_cache_config: Optional[ContextCacheConfig] = None run_processor: Optional[BaseRunProcessor] = Field(default=None, exclude=True) @@ -292,6 +297,28 @@ def model_post_init(self, __context: Any) -> None: f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}" ) + @property + def _llm_flow(self) -> BaseLlmFlow: + if ( + self.disallow_transfer_to_parent + and self.disallow_transfer_to_peers + and not self.sub_agents + ): + from veadk.flows.supervisor_single_flow import SupervisorSingleFlow + + if self.enable_supervisor_flow: + logger.debug(f"Enable supervisor flow for agent: {self.name}") + return SupervisorSingleFlow(supervised_agent=self) + else: + return SingleFlow() + else: + from veadk.flows.supervisor_auto_flow import SupervisorAutoFlow + + if self.enable_supervisor_flow: + logger.debug(f"Enable supervisor flow for agent: {self.name}") + return SupervisorAutoFlow(supervised_agent=self) + return AutoFlow() + async def _run_async_impl( self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: diff --git a/veadk/agents/supervise_agent.py b/veadk/agents/supervise_agent.py new file mode 100644 index 00000000..ec7fcded --- /dev/null +++ b/veadk/agents/supervise_agent.py @@ -0,0 +1,51 @@ +from google.adk.models.llm_request import LlmRequest +from jinja2 import Template + +from veadk import Agent, Runner +from veadk.utils.logger import get_logger + +logger = get_logger(__name__) + +instruction = Template("""You are a supervisor of an agent system. The system prompt of worker agent is: + +```system prompt +{{ system_prompt }} +``` + +You should guide the agent to finish task. If you think the history execution is not correct, you should give your advice to the worker agent. If you think the history execution is correct, you should output an empty string. +""") + + +def build_supervisor(supervised_agent: Agent) -> Agent: + custom_instruction = instruction.render(system_prompt=supervised_agent.instruction) + agent = Agent( + name="supervisor", + description="A supervisor for agent execution", + instruction=custom_instruction, + ) + + return agent + + +async def generate_advice(agent: Agent, llm_request: LlmRequest) -> str: + runner = Runner(agent=agent) + + messages = "" + for content in llm_request.contents: + if content and content.parts: + for part in content.parts: + if part.text: + messages += f"{content.role}: {part.text}" + if part.function_call: + messages += f"{content.role}: {part.function_call}" + if part.function_response: + messages += f"{content.role}: {part.function_response}" + + prompt = ( + f"Tools of agent is {llm_request.tools_dict}. History trajectory is: " + + messages + ) + + logger.debug(f"Prompt for supervisor: {prompt}") + + return await runner.run(messages=prompt) diff --git a/veadk/flows/supervisor_auto_flow.py b/veadk/flows/supervisor_auto_flow.py new file mode 100644 index 00000000..4a3c6d0e --- /dev/null +++ b/veadk/flows/supervisor_auto_flow.py @@ -0,0 +1,45 @@ +from typing import AsyncGenerator + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events import Event +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.genai.types import Content, Part +from typing_extensions import override + +from veadk import Agent +from veadk.agents.supervise_agent import generate_advice +from veadk.flows.supervisor_single_flow import SupervisorSingleFlow +from veadk.utils.logger import get_logger + +logger = get_logger(__name__) + + +class SupervisorAutoFlow(SupervisorSingleFlow): + def __init__(self, supervised_agent: Agent): + super().__init__(supervised_agent) + + @override + async def _call_llm_async( + self, + invocation_context: InvocationContext, + llm_request: LlmRequest, + model_response_event: Event, + ) -> AsyncGenerator[LlmResponse, None]: + advice = await generate_advice(self._supervisor, llm_request) + logger.debug(f"Advice from supervisor: {advice}") + + llm_request.contents.append( + Content( + parts=[Part(text=f"Message from your supervisor: {advice}")], + role="user", + ) + ) + + print("====") + print(llm_request) + + async for llm_response in super()._call_llm_async( + invocation_context, llm_request, model_response_event + ): + yield llm_response diff --git a/veadk/flows/supervisor_single_flow.py b/veadk/flows/supervisor_single_flow.py new file mode 100644 index 00000000..b5feb2f8 --- /dev/null +++ b/veadk/flows/supervisor_single_flow.py @@ -0,0 +1,30 @@ +from typing import AsyncGenerator + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events import Event +from google.adk.flows.llm_flows.single_flow import SingleFlow +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from typing_extensions import override + +from veadk import Agent +from veadk.agents.supervise_agent import build_supervisor + + +class SupervisorSingleFlow(SingleFlow): + def __init__(self, supervised_agent: Agent): + self._supervisor = build_supervisor(supervised_agent) + + super().__init__() + + @override + async def _call_llm_async( + self, + invocation_context: InvocationContext, + llm_request: LlmRequest, + model_response_event: Event, + ) -> AsyncGenerator[LlmResponse, None]: + async for llm_response in super()._call_llm_async( + invocation_context, llm_request, model_response_event + ): + yield llm_response