33 changes: 30 additions & 3 deletions veadk/agent.py
@@ -15,7 +15,7 @@
 from __future__ import annotations

 import os
-from typing import Optional, Union, AsyncGenerator
+from typing import AsyncGenerator, Optional, Union

 # If user didn't set LITELLM_LOCAL_MODEL_COST_MAP, set it to True
 # to enable local model cost map.
@@ -24,12 +24,15 @@
 if not os.getenv("LITELLM_LOCAL_MODEL_COST_MAP"):
     os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

-from google.adk.agents import LlmAgent, RunConfig, InvocationContext
+from google.adk.agents import InvocationContext, LlmAgent, RunConfig
 from google.adk.agents.base_agent import BaseAgent
 from google.adk.agents.context_cache_config import ContextCacheConfig
 from google.adk.agents.llm_agent import InstructionProvider, ToolUnion
 from google.adk.agents.run_config import StreamingMode
 from google.adk.events import Event, EventActions
+from google.adk.flows.llm_flows.auto_flow import AutoFlow
+from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
+from google.adk.flows.llm_flows.single_flow import SingleFlow
 from google.adk.models.lite_llm import LiteLlm
 from google.adk.runners import Runner
 from google.genai import types
@@ -53,8 +56,8 @@
 from veadk.prompts.prompt_manager import BasePromptManager
 from veadk.tracing.base_tracer import BaseTracer
 from veadk.utils.logger import get_logger
-from veadk.utils.patches import patch_asyncio, patch_tracer
 from veadk.utils.misc import check_litellm_version
+from veadk.utils.patches import patch_asyncio, patch_tracer
 from veadk.version import VERSION

 patch_tracer()
@@ -118,6 +121,8 @@ class Agent(LlmAgent):

     enable_responses: bool = False

+    enable_supervisor_flow: bool = False
+
     context_cache_config: Optional[ContextCacheConfig] = None

     run_processor: Optional[BaseRunProcessor] = Field(default=None, exclude=True)
@@ -292,6 +297,28 @@ def model_post_init(self, __context: Any) -> None:
             f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}"
         )

+    @property
+    def _llm_flow(self) -> BaseLlmFlow:
+        if (
+            self.disallow_transfer_to_parent
+            and self.disallow_transfer_to_peers
+            and not self.sub_agents
+        ):
+            from veadk.flows.supervisor_single_flow import SupervisorSingleFlow
+
+            if self.enable_supervisor_flow:
+                logger.debug(f"Enable supervisor flow for agent: {self.name}")
+                return SupervisorSingleFlow(supervised_agent=self)
+            else:
+                return SingleFlow()
+        else:
+            from veadk.flows.supervisor_auto_flow import SupervisorAutoFlow
+
+            if self.enable_supervisor_flow:
+                logger.debug(f"Enable supervisor flow for agent: {self.name}")
+                return SupervisorAutoFlow(supervised_agent=self)
+            return AutoFlow()
+
     async def _run_async_impl(
         self, ctx: InvocationContext
     ) -> AsyncGenerator[Event, None]:
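For context, a minimal sketch of how the new flag might be exercised. Only `enable_supervisor_flow`, the `name`/`description`/`instruction` keywords, and the awaitable `Runner.run(messages=...)` call are taken from this diff; everything else (model and environment configuration) is assumed to come from VeADK defaults.

```python
from veadk import Agent, Runner

worker = Agent(
    name="worker",
    description="A supervised worker agent",
    instruction="Answer questions about the codebase.",
    enable_supervisor_flow=True,       # new flag from this PR
    disallow_transfer_to_parent=True,  # with no sub_agents, these two flags
    disallow_transfer_to_peers=True,   # make _llm_flow pick SupervisorSingleFlow
)

runner = Runner(agent=worker)
# await runner.run(messages="Summarize veadk/agent.py")
```

With either transfer flag left at its default (or with sub_agents present), `_llm_flow` falls through to the `SupervisorAutoFlow` branch instead.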
51 changes: 51 additions & 0 deletions veadk/agents/supervise_agent.py
@@ -0,0 +1,51 @@
from google.adk.models.llm_request import LlmRequest
from jinja2 import Template

from veadk import Agent, Runner
from veadk.utils.logger import get_logger

logger = get_logger(__name__)

instruction = Template("""You are the supervisor of an agent system. The system prompt of the worker agent is:

```system prompt
{{ system_prompt }}
```

You should guide the agent to finish its task. If you think the execution history is incorrect, give your advice to the worker agent. If you think the execution history is correct, output an empty string.
""")


def build_supervisor(supervised_agent: Agent) -> Agent:
    custom_instruction = instruction.render(system_prompt=supervised_agent.instruction)
    agent = Agent(
        name="supervisor",
        description="A supervisor for agent execution",
        instruction=custom_instruction,
    )

    return agent


async def generate_advice(agent: Agent, llm_request: LlmRequest) -> str:
    runner = Runner(agent=agent)

    # Flatten the request contents into a plain-text trajectory, one
    # "<role>: <payload>" entry per part.
    messages = ""
    for content in llm_request.contents:
        if content and content.parts:
            for part in content.parts:
                if part.text:
                    messages += f"{content.role}: {part.text}\n"
                if part.function_call:
                    messages += f"{content.role}: {part.function_call}\n"
                if part.function_response:
                    messages += f"{content.role}: {part.function_response}\n"

    prompt = (
        f"The agent's tools are: {llm_request.tools_dict}. The execution history is:\n"
        + messages
    )

    logger.debug(f"Prompt for supervisor: {prompt}")

    return await runner.run(messages=prompt)
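Both helpers are callable on their own, e.g. from a test. A rough sketch, assuming `LlmRequest()` can be constructed with no arguments (its fields appear to have pydantic defaults in google-adk; treat that as an assumption):

```python
import asyncio

from google.adk.models.llm_request import LlmRequest

from veadk import Agent
from veadk.agents.supervise_agent import build_supervisor, generate_advice

worker = Agent(
    name="worker",
    description="A worker agent",
    instruction="You answer arithmetic questions step by step.",
)
supervisor = build_supervisor(worker)

# An empty request has no trajectory to critique, so per its instruction
# the supervisor should answer with an empty string.
advice = asyncio.run(generate_advice(supervisor, LlmRequest()))
print(advice or "<no advice>")
```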
45 changes: 45 additions & 0 deletions veadk/flows/supervisor_auto_flow.py
@@ -0,0 +1,45 @@
from typing import AsyncGenerator

from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.genai.types import Content, Part
from typing_extensions import override

from veadk import Agent
from veadk.agents.supervise_agent import generate_advice
from veadk.flows.supervisor_single_flow import SupervisorSingleFlow
from veadk.utils.logger import get_logger

logger = get_logger(__name__)


class SupervisorAutoFlow(SupervisorSingleFlow):
    def __init__(self, supervised_agent: Agent):
        super().__init__(supervised_agent)

    @override
    async def _call_llm_async(
        self,
        invocation_context: InvocationContext,
        llm_request: LlmRequest,
        model_response_event: Event,
    ) -> AsyncGenerator[LlmResponse, None]:
        advice = await generate_advice(self._supervisor, llm_request)
        logger.debug(f"Advice from supervisor: {advice}")

        # Inject the supervisor's advice as a user-role turn so the worker
        # agent sees it before its next model call.
        llm_request.contents.append(
            Content(
                parts=[Part(text=f"Message from your supervisor: {advice}")],
                role="user",
            )
        )

        logger.debug(f"LLM request after injecting supervisor advice: {llm_request}")

        async for llm_response in super()._call_llm_async(
            invocation_context, llm_request, model_response_event
        ):
            yield llm_response
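Note that the advice turn is appended even when the supervisor returns an empty string, which its prompt defines as "the execution is correct". A possible guard (a sketch, not part of this diff) would skip the injection in that case:

```python
# Only forward non-empty advice; an empty string means the supervisor
# approved the trajectory as-is.
if advice.strip():
    llm_request.contents.append(
        Content(
            parts=[Part(text=f"Message from your supervisor: {advice}")],
            role="user",
        )
    )
```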
30 changes: 30 additions & 0 deletions veadk/flows/supervisor_single_flow.py
@@ -0,0 +1,30 @@
from typing import AsyncGenerator

from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event
from google.adk.flows.llm_flows.single_flow import SingleFlow
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from typing_extensions import override

from veadk import Agent
from veadk.agents.supervise_agent import build_supervisor


class SupervisorSingleFlow(SingleFlow):
    def __init__(self, supervised_agent: Agent):
        # Build the supervisor agent once, up front; subclasses reuse it.
        self._supervisor = build_supervisor(supervised_agent)

        super().__init__()

    @override
    async def _call_llm_async(
        self,
        invocation_context: InvocationContext,
        llm_request: LlmRequest,
        model_response_event: Event,
    ) -> AsyncGenerator[LlmResponse, None]:
        # Pass-through as written; SupervisorAutoFlow overrides this hook to
        # inject the supervisor's advice before delegating.
        async for llm_response in super()._call_llm_async(
            invocation_context, llm_request, model_response_event
        ):
            yield llm_response
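For completeness, the flow can also be constructed directly; this sketch only exercises wiring the diff itself establishes (the supervisor built in `__init__` via `build_supervisor`, which names the agent "supervisor"). The worker's constructor keywords are the same assumptions as in the earlier sketches.

```python
from veadk import Agent
from veadk.flows.supervisor_single_flow import SupervisorSingleFlow

worker = Agent(
    name="worker",
    description="A worker agent",
    instruction="Answer user questions.",
)
flow = SupervisorSingleFlow(supervised_agent=worker)

# The supervisor derived from the worker's instruction is kept on the flow;
# SupervisorAutoFlow reuses it in its _call_llm_async override.
assert flow._supervisor.name == "supervisor"
```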