23 changes: 23 additions & 0 deletions backend/examples/cli_research.py
@@ -37,6 +37,29 @@ def main() -> None:
     messages = result.get("messages", [])
     if messages:
         print(messages[-1].content)
+
+    token_records = result.get("token_usage_records", [])
+    if token_records:
+        print("\n" + "=" * 80)
+        print("TOKEN USAGE SUMMARY")
+        print("=" * 80)
+
+        total_input = 0
+        total_output = 0
+
+        for record in token_records:
+            print(f"\n{record['node_name'].upper():<20} ({record['model']})")
+            print(f" Input tokens: {record['input_tokens']:,}")
+            print(f" Output tokens: {record['output_tokens']:,}")
+            total_input += record['input_tokens']
+            total_output += record['output_tokens']
+
+        print("\n" + "-" * 80)
+        print(f"{'TOTAL':<20}")
+        print(f" Input tokens: {total_input:,}")
+        print(f" Output tokens: {total_output:,}")
+        print(f" Total tokens: {(total_input + total_output):,}")
+        print("=" * 80)


 if __name__ == "__main__":
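
A hypothetical run in which generate_query and finalize_answer each executed once would print a summary roughly like the following (node names from this PR; token counts and model names are illustrative only, not real measurements):

================================================================================
TOKEN USAGE SUMMARY
================================================================================

GENERATE_QUERY       (gemini-2.0-flash)
 Input tokens: 482
 Output tokens: 37

FINALIZE_ANSWER      (gemini-2.5-pro)
 Input tokens: 3,120
 Output tokens: 604

--------------------------------------------------------------------------------
TOTAL
 Input tokens: 3,602
 Output tokens: 641
 Total tokens: 4,243
================================================================================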
8 changes: 8 additions & 0 deletions backend/pyproject.toml
@@ -28,6 +28,14 @@ dev = ["mypy>=1.11.1", "ruff>=0.6.1"]
 requires = ["setuptools>=73.0.0", "wheel"]
 build-backend = "setuptools.build_meta"

+# Configure setuptools to find packages in the src/ directory, avoiding import conflicts with site-packages
+# Fixes: ImportError: cannot import name 'extract_token_usage_from_langchain' from 'agent.utils'
+[tool.setuptools]
+package-dir = {"" = "src"}
+
+[tool.setuptools.packages.find]
+where = ["src"]
+
 [tool.ruff]
 lint.select = [
     "E", # pycodestyle
7 changes: 7 additions & 0 deletions backend/src/agent/configuration.py
@@ -39,6 +39,13 @@ class Configuration(BaseModel):
         metadata={"description": "The maximum number of research loops to perform."},
     )

+    track_token_usage: bool = Field(
+        default=True,
+        metadata={
+            "description": "Enable token usage tracking for cost monitoring and optimization."
+        },
+    )
+
     @classmethod
     def from_runnable_config(
         cls, config: Optional[RunnableConfig] = None
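
Because Configuration.from_runnable_config reads field values out of the RunnableConfig, the flag should be toggleable per run without code changes. A minimal sketch under that assumption, with `graph` standing in for the compiled agent graph:

# Disable token tracking for a single invocation (field name from this PR).
result = graph.invoke(
    {"messages": [("user", "What is quantum computing?")]},
    config={"configurable": {"track_token_usage": False}},
)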
37 changes: 33 additions & 4 deletions backend/src/agent/graph.py
@@ -29,6 +29,7 @@
     get_research_topic,
     insert_citation_markers,
     resolve_urls,
+    create_token_usage_record,
 )

 load_dotenv()
@@ -78,7 +79,14 @@ def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
     )
     # Generate the search queries
     result = structured_llm.invoke(formatted_prompt)
-    return {"search_query": result.query}
+
+    update = {"search_query": result.query}
+    if configurable.track_token_usage:
+        update["token_usage_records"] = create_token_usage_record(
+            result, "generate_query", configurable.query_generator_model, is_langchain=True
+        )
+
+    return update


 def continue_to_web_research(state: QueryGenerationState):
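
One caveat: `structured_llm.invoke` returns the parsed Pydantic object rather than a raw AIMessage, so whether `response_metadata` is present on `result` depends on the langchain-google-genai version. Given the extractor defined in utils.py below, the failure mode is graceful — a sketch of the assumed fallback:

# If the structured-output object carries no response_metadata attribute,
# the extractor reports zeros instead of raising.
usage = extract_token_usage_from_langchain(result)
# usage == {"input_tokens": 0, "output_tokens": 0} when metadata is absent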
Expand Down Expand Up @@ -129,11 +137,18 @@ def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
     modified_text = insert_citation_markers(response.text, citations)
     sources_gathered = [item for citation in citations for item in citation["segments"]]

-    return {
+    update = {
         "sources_gathered": sources_gathered,
         "search_query": [state["search_query"]],
         "web_research_result": [modified_text],
     }
+
+    if configurable.track_token_usage:
+        update["token_usage_records"] = create_token_usage_record(
+            response, "web_research", configurable.query_generator_model, is_langchain=False
+        )
+
+    return update


 def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
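
web_research is the one node that talks to Gemini through the native google.genai client rather than LangChain (hence is_langchain=False), so its usage comes from response attributes instead of a metadata dict. A sketch of the response shape the extractor below assumes:

# Assumed attribute layout on the google.genai response (per utils.py below):
#   response.usage_metadata.prompt_token_count      -> input tokens
#   response.usage_metadata.candidates_token_count  -> output tokens
usage = extract_token_usage_from_genai_client(response)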
Expand Down Expand Up @@ -171,13 +186,20 @@ def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
     )
     result = llm.with_structured_output(Reflection).invoke(formatted_prompt)

-    return {
+    update = {
         "is_sufficient": result.is_sufficient,
         "knowledge_gap": result.knowledge_gap,
         "follow_up_queries": result.follow_up_queries,
         "research_loop_count": state["research_loop_count"],
         "number_of_ran_queries": len(state["search_query"]),
     }
+
+    if configurable.track_token_usage:
+        update["token_usage_records"] = create_token_usage_record(
+            result, "reflection", reasoning_model, is_langchain=True
+        )
+
+    return update


 def evaluate_research(
Expand Down Expand Up @@ -259,10 +281,17 @@ def finalize_answer(state: OverallState, config: RunnableConfig):
             )
             unique_sources.append(source)

-    return {
+    update = {
         "messages": [AIMessage(content=result.content)],
         "sources_gathered": unique_sources,
     }
+
+    if configurable.track_token_usage:
+        update["token_usage_records"] = create_token_usage_record(
+            result, "finalize_answer", reasoning_model, is_langchain=True
+        )
+
+    return update


 # Create our Agent Graph
10 changes: 10 additions & 0 deletions backend/src/agent/state.py
@@ -10,11 +10,21 @@
 import operator


+class TokenUsageRecord(TypedDict):
+    """Record of token usage for a single node execution."""
+
+    node_name: str
+    input_tokens: int
+    output_tokens: int
+    model: str
+
+
 class OverallState(TypedDict):
     messages: Annotated[list, add_messages]
     search_query: Annotated[list, operator.add]
     web_research_result: Annotated[list, operator.add]
     sources_gathered: Annotated[list, operator.add]
+    token_usage_records: Annotated[list, operator.add]
     initial_search_query_count: int
     max_research_loops: int
     research_loop_count: int
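
Annotating token_usage_records with operator.add is what lets every node contribute records independently: LangGraph merges each node's returned list into the state by concatenation, which is also why create_token_usage_record returns a one-element list rather than a bare dict. A self-contained sketch of that merge semantics (record values illustrative):

import operator

merged = operator.add(
    [{"node_name": "generate_query", "input_tokens": 482, "output_tokens": 37, "model": "gemini-2.0-flash"}],
    [{"node_name": "reflection", "input_tokens": 910, "output_tokens": 88, "model": "gemini-2.5-flash"}],
)
assert len(merged) == 2  # records from successive node executions are appended in order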
67 changes: 67 additions & 0 deletions backend/src/agent/utils.py
@@ -2,6 +2,73 @@
 from langchain_core.messages import AnyMessage, AIMessage, HumanMessage


+def extract_token_usage_from_langchain(response: Any) -> Dict[str, int]:
+    """
+    Extract token usage from a LangChain ChatGoogleGenerativeAI response.
+
+    Args:
+        response: The response object from LangChain's ChatGoogleGenerativeAI
+
+    Returns:
+        Dictionary with 'input_tokens' and 'output_tokens' keys (zeros if no usage metadata is present)
+    """
+    if hasattr(response, "response_metadata"):
+        usage = response.response_metadata.get("usage_metadata", {})
+        return {
+            "input_tokens": usage.get("prompt_token_count", 0),
+            "output_tokens": usage.get("candidates_token_count", 0),
+        }
+    return {"input_tokens": 0, "output_tokens": 0}
+
+
+def extract_token_usage_from_genai_client(response: Any) -> Dict[str, int]:
+    """
+    Extract token usage from a native google.genai.Client response.
+
+    Args:
+        response: The response object from google.genai.Client
+
+    Returns:
+        Dictionary with 'input_tokens' and 'output_tokens' keys (zeros if the response has no usage_metadata)
+    """
+    if hasattr(response, "usage_metadata"):
+        return {
+            "input_tokens": response.usage_metadata.prompt_token_count,
+            "output_tokens": response.usage_metadata.candidates_token_count,
+        }
+    return {"input_tokens": 0, "output_tokens": 0}
+
+
+def create_token_usage_record(
+    response: Any, node_name: str, model: str, is_langchain: bool = True
+) -> List[Dict[str, Any]]:
+    """
+    Create a token usage record for a node execution.
+
+    Args:
+        response: The response object from either LangChain or google.genai.Client
+        node_name: Name of the node that generated the response
+        model: Model name used for the request
+        is_langchain: True if the response is from LangChain, False if from google.genai.Client
+
+    Returns:
+        List containing a single token usage record dictionary
+    """
+    if is_langchain:
+        token_usage = extract_token_usage_from_langchain(response)
+    else:
+        token_usage = extract_token_usage_from_genai_client(response)
+
+    return [
+        {
+            "node_name": node_name,
+            "input_tokens": token_usage["input_tokens"],
+            "output_tokens": token_usage["output_tokens"],
+            "model": model,
+        }
+    ]
+
+
 def get_research_topic(messages: List[AnyMessage]) -> str:
     """
     Get the research topic from the messages.
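
The zero-token fallback path can be exercised without a real API response, using a bare stand-in object (hypothetical, not a real SDK type):

class _NoMetadata:
    """Stand-in for a structured-output result lacking response_metadata."""

(record,) = create_token_usage_record(_NoMetadata(), "generate_query", "gemini-2.0-flash")
assert record == {
    "node_name": "generate_query",
    "input_tokens": 0,
    "output_tokens": 0,
    "model": "gemini-2.0-flash",
}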
26 changes: 25 additions & 1 deletion frontend/src/App.tsx
@@ -2,6 +2,7 @@ import { useStream } from "@langchain/langgraph-sdk/react";
 import type { Message } from "@langchain/langgraph-sdk";
 import { useState, useEffect, useRef, useCallback } from "react";
 import { ProcessedEvent } from "@/components/ActivityTimeline";
+import { TokenUsageRecord } from "@/components/TokenUsageDisplay";
 import { WelcomeScreen } from "@/components/WelcomeScreen";
 import { ChatMessagesView } from "@/components/ChatMessagesView";
 import { Button } from "@/components/ui/button";
@@ -13,6 +14,12 @@ export default function App() {
   const [historicalActivities, setHistoricalActivities] = useState<
     Record<string, ProcessedEvent[]>
   >({});
+  const [tokenUsageTimeline, setTokenUsageTimeline] = useState<
+    TokenUsageRecord[]
+  >([]);
+  const [historicalTokenUsage, setHistoricalTokenUsage] = useState<
+    Record<string, TokenUsageRecord[]>
+  >({});
   const scrollAreaRef = useRef<HTMLDivElement>(null);
   const hasFinalizeEventOccurredRef = useRef(false);
   const [error, setError] = useState<string | null>(null);
Expand All @@ -21,6 +28,7 @@ export default function App() {
     initial_search_query_count: number;
     max_research_loops: number;
     reasoning_model: string;
+    token_usage_records?: TokenUsageRecord[];
   }>({
     apiUrl: import.meta.env.DEV
       ? "http://localhost:2024"
@@ -65,6 +73,15 @@
           processedEvent!,
         ]);
       }
+
+      for (const key in event) {
+        if (event[key]?.token_usage_records) {
+          const newRecords = event[key].token_usage_records;
+          if (Array.isArray(newRecords) && newRecords.length > 0) {
+            setTokenUsageTimeline((prev) => [...prev, ...newRecords]);
+          }
+        }
+      }
     },
     onError: (error: any) => {
       setError(error.message);
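
The key scan works because LangGraph's updates stream delivers each event keyed by the node that just ran, so any node that emits token_usage_records is picked up without hard-coding node names. Roughly the payload shape being scanned, sketched as a Python literal (values hypothetical):

# One update event from the stream, keyed by node name:
event = {
    "generate_query": {
        "search_query": ["..."],
        "token_usage_records": [
            {"node_name": "generate_query", "input_tokens": 482, "output_tokens": 37, "model": "gemini-2.0-flash"}
        ],
    }
}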
@@ -94,15 +111,20 @@
           ...prev,
           [lastMessage.id!]: [...processedEventsTimeline],
         }));
+        setHistoricalTokenUsage((prev) => ({
+          ...prev,
+          [lastMessage.id!]: [...tokenUsageTimeline],
+        }));
       }
       hasFinalizeEventOccurredRef.current = false;
     }
-  }, [thread.messages, thread.isLoading, processedEventsTimeline]);
+  }, [thread.messages, thread.isLoading, processedEventsTimeline, tokenUsageTimeline]);

   const handleSubmit = useCallback(
     (submittedInputValue: string, effort: string, model: string) => {
       if (!submittedInputValue.trim()) return;
       setProcessedEventsTimeline([]);
+      setTokenUsageTimeline([]);
       hasFinalizeEventOccurredRef.current = false;

       // convert effort to initial_search_query_count and max_research_loops
@@ -181,6 +203,8 @@
             onCancel={handleCancel}
             liveActivityEvents={processedEventsTimeline}
             historicalActivities={historicalActivities}
+            liveTokenUsage={tokenUsageTimeline}
+            historicalTokenUsage={historicalTokenUsage}
           />
         )}
       </main>