diff --git a/backend/examples/cli_research.py b/backend/examples/cli_research.py
index a086496b..02d1b8f7 100644
--- a/backend/examples/cli_research.py
+++ b/backend/examples/cli_research.py
@@ -37,6 +37,29 @@ def main() -> None:
     messages = result.get("messages", [])
     if messages:
         print(messages[-1].content)
+
+    token_records = result.get("token_usage_records", [])
+    if token_records:
+        print("\n" + "=" * 80)
+        print("TOKEN USAGE SUMMARY")
+        print("=" * 80)
+
+        total_input = 0
+        total_output = 0
+
+        for record in token_records:
+            print(f"\n{record['node_name'].upper():<20} ({record['model']})")
+            print(f"  Input tokens:  {record['input_tokens']:,}")
+            print(f"  Output tokens: {record['output_tokens']:,}")
+            total_input += record['input_tokens']
+            total_output += record['output_tokens']
+
+        print("\n" + "-" * 80)
+        print(f"{'TOTAL':<20}")
+        print(f"  Input tokens:  {total_input:,}")
+        print(f"  Output tokens: {total_output:,}")
+        print(f"  Total tokens:  {(total_input + total_output):,}")
+        print("=" * 80)
 
 
 if __name__ == "__main__":
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 09eb5988..ee9a47cb 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -28,6 +28,14 @@ dev = ["mypy>=1.11.1", "ruff>=0.6.1"]
 requires = ["setuptools>=73.0.0", "wheel"]
 build-backend = "setuptools.build_meta"
 
+# Configure setuptools to find packages in the src/ directory to avoid import conflicts with site-packages.
+# Fixes: ImportError: cannot import name 'extract_token_usage_from_langchain' from 'agent.utils'
+[tool.setuptools]
+package-dir = {"" = "src"}
+
+[tool.setuptools.packages.find]
+where = ["src"]
+
 [tool.ruff]
 lint.select = [
     "E",    # pycodestyle
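Reviewer note on the packaging fix above: after an editable reinstall, the agent package should resolve from backend/src/ rather than a stale copy in site-packages. A minimal sanity check, assuming a standard pip editable install of this repo:

# Run after: pip install -e backend/
import agent.utils

# Expect a path ending in backend/src/agent/utils.py, not site-packages/agent/utils.py.
print(agent.utils.__file__)

# The import from the original error message should now succeed.
from agent.utils import extract_token_usage_from_langchain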
diff --git a/backend/src/agent/configuration.py b/backend/src/agent/configuration.py
index e57122d2..c78bd9dd 100644
--- a/backend/src/agent/configuration.py
+++ b/backend/src/agent/configuration.py
@@ -39,6 +39,13 @@ class Configuration(BaseModel):
         metadata={"description": "The maximum number of research loops to perform."},
     )
 
+    track_token_usage: bool = Field(
+        default=True,
+        metadata={
+            "description": "Enable token usage tracking for cost monitoring and optimization."
+        },
+    )
+
     @classmethod
     def from_runnable_config(
         cls, config: Optional[RunnableConfig] = None
diff --git a/backend/src/agent/graph.py b/backend/src/agent/graph.py
index 0f19c3f2..c3a4ea1b 100644
--- a/backend/src/agent/graph.py
+++ b/backend/src/agent/graph.py
@@ -29,6 +29,7 @@
     get_research_topic,
     insert_citation_markers,
     resolve_urls,
+    create_token_usage_record,
 )
 
 load_dotenv()
@@ -78,7 +79,14 @@ def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
     )
     # Generate the search queries
     result = structured_llm.invoke(formatted_prompt)
-    return {"search_query": result.query}
+
+    update = {"search_query": result.query}
+    if configurable.track_token_usage:
+        update["token_usage_records"] = create_token_usage_record(
+            result, "generate_query", configurable.query_generator_model, is_langchain=True
+        )
+
+    return update
 
 
 def continue_to_web_research(state: QueryGenerationState):
@@ -129,11 +137,18 @@ def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
     modified_text = insert_citation_markers(response.text, citations)
     sources_gathered = [item for citation in citations for item in citation["segments"]]
 
-    return {
+    update = {
         "sources_gathered": sources_gathered,
         "search_query": [state["search_query"]],
         "web_research_result": [modified_text],
     }
+
+    if configurable.track_token_usage:
+        update["token_usage_records"] = create_token_usage_record(
+            response, "web_research", configurable.query_generator_model, is_langchain=False
+        )
+
+    return update
 
 
 def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
@@ -171,13 +186,20 @@ def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
     )
     result = llm.with_structured_output(Reflection).invoke(formatted_prompt)
 
-    return {
+    update = {
         "is_sufficient": result.is_sufficient,
         "knowledge_gap": result.knowledge_gap,
         "follow_up_queries": result.follow_up_queries,
         "research_loop_count": state["research_loop_count"],
         "number_of_ran_queries": len(state["search_query"]),
     }
+
+    if configurable.track_token_usage:
+        update["token_usage_records"] = create_token_usage_record(
+            result, "reflection", reasoning_model, is_langchain=True
+        )
+
+    return update
 
 
 def evaluate_research(
@@ -259,10 +281,17 @@ def finalize_answer(state: OverallState, config: RunnableConfig):
             )
             unique_sources.append(source)
 
-    return {
+    update = {
         "messages": [AIMessage(content=result.content)],
         "sources_gathered": unique_sources,
     }
+
+    if configurable.track_token_usage:
+        update["token_usage_records"] = create_token_usage_record(
+            result, "finalize_answer", reasoning_model, is_langchain=True
+        )
+
+    return update
 
 
 # Create our Agent Graph
diff --git a/backend/src/agent/state.py b/backend/src/agent/state.py
index d5ad4dcd..c9e318c6 100644
--- a/backend/src/agent/state.py
+++ b/backend/src/agent/state.py
@@ -10,11 +10,21 @@
 import operator
 
 
+class TokenUsageRecord(TypedDict):
+    """Record of token usage for a single node execution."""
+
+    node_name: str
+    input_tokens: int
+    output_tokens: int
+    model: str
+
+
 class OverallState(TypedDict):
     messages: Annotated[list, add_messages]
     search_query: Annotated[list, operator.add]
     web_research_result: Annotated[list, operator.add]
     sources_gathered: Annotated[list, operator.add]
+    token_usage_records: Annotated[list, operator.add]
     initial_search_query_count: int
     max_research_loops: int
     research_loop_count: int
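Reviewer note on how the pieces compose: each node returns a one-record list built by create_token_usage_record (added in utils.py below), and the operator.add reducer on token_usage_records concatenates those partial updates over the run. A minimal sketch that also exercises both extraction paths with stub responses instead of live API calls (attribute names mirror the hasattr checks in utils.py; the token counts are illustrative):

import operator
from types import SimpleNamespace

from agent.utils import create_token_usage_record

# Shaped like a LangChain response: usage nested under response_metadata["usage_metadata"].
lc_response = SimpleNamespace(
    response_metadata={
        "usage_metadata": {"prompt_token_count": 10, "candidates_token_count": 5}
    }
)
# Shaped like a google.genai.Client response: a usage_metadata attribute with count fields.
genai_response = SimpleNamespace(
    usage_metadata=SimpleNamespace(prompt_token_count=7, candidates_token_count=3)
)

# Each helper call returns a one-element list; operator.add mirrors how LangGraph
# merges successive node updates into OverallState.token_usage_records.
records = operator.add(
    create_token_usage_record(lc_response, "generate_query", "gemini-2.0-flash"),
    create_token_usage_record(genai_response, "web_research", "gemini-2.0-flash", is_langchain=False),
)
assert [r["node_name"] for r in records] == ["generate_query", "web_research"]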
diff --git a/backend/src/agent/utils.py b/backend/src/agent/utils.py
index d02c8d91..577c7755 100644
--- a/backend/src/agent/utils.py
+++ b/backend/src/agent/utils.py
@@ -2,6 +2,73 @@
 from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
 
 
+def extract_token_usage_from_langchain(response: Any) -> Dict[str, int]:
+    """
+    Extract token usage from a LangChain ChatGoogleGenerativeAI response.
+
+    Args:
+        response: The response object from LangChain's ChatGoogleGenerativeAI
+
+    Returns:
+        Dictionary with 'input_tokens' and 'output_tokens' keys
+    """
+    if hasattr(response, "response_metadata"):
+        usage = response.response_metadata.get("usage_metadata", {})
+        return {
+            "input_tokens": usage.get("prompt_token_count", 0),
+            "output_tokens": usage.get("candidates_token_count", 0),
+        }
+    return {"input_tokens": 0, "output_tokens": 0}
+
+
+def extract_token_usage_from_genai_client(response: Any) -> Dict[str, int]:
+    """
+    Extract token usage from a native google.genai.Client response.
+
+    Args:
+        response: The response object from google.genai.Client
+
+    Returns:
+        Dictionary with 'input_tokens' and 'output_tokens' keys
+    """
+    if hasattr(response, "usage_metadata"):
+        return {
+            "input_tokens": response.usage_metadata.prompt_token_count,
+            "output_tokens": response.usage_metadata.candidates_token_count,
+        }
+    return {"input_tokens": 0, "output_tokens": 0}
+
+
+def create_token_usage_record(
+    response: Any, node_name: str, model: str, is_langchain: bool = True
+) -> List[Dict[str, Any]]:
+    """
+    Create a token usage record for a node execution.
+
+    Args:
+        response: The response object from either LangChain or google.genai.Client
+        node_name: Name of the node that generated the response
+        model: Model name used for the request
+        is_langchain: True if response is from LangChain, False if from google.genai.Client
+
+    Returns:
+        List containing a single token usage record dictionary
+    """
+    if is_langchain:
+        token_usage = extract_token_usage_from_langchain(response)
+    else:
+        token_usage = extract_token_usage_from_genai_client(response)
+
+    return [
+        {
+            "node_name": node_name,
+            "input_tokens": token_usage["input_tokens"],
+            "output_tokens": token_usage["output_tokens"],
+            "model": model,
+        }
+    ]
+
+
 def get_research_topic(messages: List[AnyMessage]) -> str:
     """
     Get the research topic from the messages.
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index d06d4021..8a91b203 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -2,6 +2,7 @@ import { useStream } from "@langchain/langgraph-sdk/react";
 import type { Message } from "@langchain/langgraph-sdk";
 import { useState, useEffect, useRef, useCallback } from "react";
 import { ProcessedEvent } from "@/components/ActivityTimeline";
+import { TokenUsageRecord } from "@/components/TokenUsageDisplay";
 import { WelcomeScreen } from "@/components/WelcomeScreen";
 import { ChatMessagesView } from "@/components/ChatMessagesView";
 import { Button } from "@/components/ui/button";
@@ -13,6 +14,12 @@ export default function App() {
   const [historicalActivities, setHistoricalActivities] = useState<
     Record<string, ProcessedEvent[]>
   >({});
+  const [tokenUsageTimeline, setTokenUsageTimeline] = useState<
+    TokenUsageRecord[]
+  >([]);
+  const [historicalTokenUsage, setHistoricalTokenUsage] = useState<
+    Record<string, TokenUsageRecord[]>
+  >({});
   const scrollAreaRef = useRef<HTMLDivElement>(null);
   const hasFinalizeEventOccurredRef = useRef(false);
   const [error, setError] = useState<string | null>(null);
@@ -21,6 +28,7 @@ export default function App() {
     initial_search_query_count: number;
     max_research_loops: number;
     reasoning_model: string;
+    token_usage_records?: TokenUsageRecord[];
   }>({
     apiUrl: import.meta.env.DEV
       ? "http://localhost:2024"
@@ -65,6 +73,15 @@ export default function App() {
           processedEvent!,
         ]);
       }
+
+      for (const key in event) {
+        if (event[key]?.token_usage_records) {
+          const newRecords = event[key].token_usage_records;
+          if (Array.isArray(newRecords) && newRecords.length > 0) {
+            setTokenUsageTimeline((prev) => [...prev, ...newRecords]);
+          }
+        }
+      }
     },
     onError: (error: any) => {
       setError(error.message);
     },
@@ -94,15 +111,20 @@ export default function App() {
           ...prev,
           [lastMessage.id!]: [...processedEventsTimeline],
         }));
+        setHistoricalTokenUsage((prev) => ({
+          ...prev,
+          [lastMessage.id!]: [...tokenUsageTimeline],
+        }));
       }
       hasFinalizeEventOccurredRef.current = false;
     }
-  }, [thread.messages, thread.isLoading, processedEventsTimeline]);
+  }, [thread.messages, thread.isLoading, processedEventsTimeline, tokenUsageTimeline]);
 
   const handleSubmit = useCallback(
     (submittedInputValue: string, effort: string, model: string) => {
       if (!submittedInputValue.trim()) return;
       setProcessedEventsTimeline([]);
+      setTokenUsageTimeline([]);
       hasFinalizeEventOccurredRef.current = false;
 
       // convert effort to initial_search_query_count and max_research_loops
@@ -181,6 +203,8 @@ export default function App() {
               onCancel={handleCancel}
               liveActivityEvents={processedEventsTimeline}
               historicalActivities={historicalActivities}
+              liveTokenUsage={tokenUsageTimeline}
+              historicalTokenUsage={historicalTokenUsage}
             />
           )}
diff --git a/frontend/src/components/ChatMessagesView.tsx b/frontend/src/components/ChatMessagesView.tsx
index 1a245d88..4c078e9f 100644
--- a/frontend/src/components/ChatMessagesView.tsx
+++ b/frontend/src/components/ChatMessagesView.tsx
@@ -11,7 +11,11 @@ import { Badge } from "@/components/ui/badge";
 import {
   ActivityTimeline,
   ProcessedEvent,
-} from "@/components/ActivityTimeline"; // Assuming ActivityTimeline is in the same dir or adjust path
+} from "@/components/ActivityTimeline";
+import {
+  TokenUsageDisplay,
+  TokenUsageRecord,
+} from "@/components/TokenUsageDisplay";
 
 // Markdown component props type from former ReportView
 type MdComponentProps = {
@@ -163,6 +167,8 @@ interface AiMessageBubbleProps {
   message: Message;
   historicalActivity: ProcessedEvent[] | undefined;
   liveActivity: ProcessedEvent[] | undefined;
+  historicalTokenUsage: TokenUsageRecord[] | undefined;
+  liveTokenUsage: TokenUsageRecord[] | undefined;
   isLastMessage: boolean;
   isOverallLoading: boolean;
   mdComponents: typeof mdComponents;
@@ -175,6 +181,8 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
   message,
   historicalActivity,
   liveActivity,
+  historicalTokenUsage,
+  liveTokenUsage,
   isLastMessage,
   isOverallLoading,
   mdComponents,
@@ -185,6 +193,10 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
   const activityForThisBubble =
     isLastMessage && isOverallLoading ? liveActivity : historicalActivity;
   const isLiveActivityForThisBubble = isLastMessage && isOverallLoading;
+
+  const tokenUsageForThisBubble =
+    isLastMessage && isOverallLoading ? liveTokenUsage : historicalTokenUsage;
+  const isLiveTokenUsageForThisBubble = isLastMessage && isOverallLoading;
 
   return (
     <div className={`relative break-words flex flex-col`}>
@@ -196,6 +208,14 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
           />
         </div>
       )}
+      {tokenUsageForThisBubble && tokenUsageForThisBubble.length > 0 && (
+        <div className="mb-3 border-b border-neutral-700 pb-3 text-xs">
+          <TokenUsageDisplay
+            tokenRecords={tokenUsageForThisBubble}
+            isLoading={isLiveTokenUsageForThisBubble}
+          />
+        </div>
+      )}
       <ReactMarkdown components={mdComponents}>
         {typeof message.content === "string"
           ? message.content
@@ -230,6 +250,8 @@ interface ChatMessagesViewProps {
   onCancel: () => void;
   liveActivityEvents: ProcessedEvent[];
   historicalActivities: Record<string, ProcessedEvent[]>;
+  liveTokenUsage: TokenUsageRecord[];
+  historicalTokenUsage: Record<string, TokenUsageRecord[]>;
 }
 
 export function ChatMessagesView({
@@ -240,6 +262,8 @@ export function ChatMessagesView({
   onCancel,
   liveActivityEvents,
   historicalActivities,
+  liveTokenUsage,
+  historicalTokenUsage,
 }: ChatMessagesViewProps) {
   const [copiedMessageId, setCopiedMessageId] = useState<string | null>(null);
 
@@ -275,6 +299,8 @@ export function ChatMessagesView({
                     message={message}
                     historicalActivity={historicalActivities[message.id!]}
                     liveActivity={liveActivityEvents} // Pass global live events
+                    historicalTokenUsage={historicalTokenUsage[message.id!]}
+                    liveTokenUsage={liveTokenUsage} // Pass global live token usage
                     isLastMessage={isLast}
                     isOverallLoading={isLoading} // Pass global loading state
                     mdComponents={mdComponents}
@@ -294,11 +320,17 @@ export function ChatMessagesView({
             {/* AI message row structure */}
             <div className="flex items-center justify-start h-full">
               {liveActivityEvents.length > 0 ? (
-                <div className="text-xs">
+                <div className="text-xs space-y-2">
                   <ActivityTimeline
                     processedEvents={liveActivityEvents}
                     isLoading={true}
                   />
+                  {liveTokenUsage.length > 0 && (
+                    <TokenUsageDisplay
+                      tokenRecords={liveTokenUsage}
+                      isLoading={true}
+                    />
+                  )}
                 </div>
               ) : (
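The cost figure rendered by the new component below is just (tokens ÷ 1,000,000) × the per-million rate from its MODEL_PRICING table. Worked through in Python with the hard-coded flash rates (the component's own assumptions, not authoritative Gemini pricing):

# USD per 1M tokens, mirroring MODEL_PRICING["gemini-2.0-flash"] below.
input_rate, output_rate = 0.075, 0.30
input_tokens, output_tokens = 20_000, 10_000

cost = (input_tokens / 1_000_000) * input_rate + (output_tokens / 1_000_000) * output_rate
print(f"~${cost:.4f}")  # ~$0.0045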
diff --git a/frontend/src/components/TokenUsageDisplay.tsx b/frontend/src/components/TokenUsageDisplay.tsx
new file mode 100644
index 00000000..231c0302
--- /dev/null
+++ b/frontend/src/components/TokenUsageDisplay.tsx
@@ -0,0 +1,211 @@
+import {
+  Card,
+  CardContent,
+  CardDescription,
+  CardHeader,
+} from "@/components/ui/card";
+import { ScrollArea } from "@/components/ui/scroll-area";
+import {
+  Coins,
+  TrendingUp,
+  ChevronDown,
+  ChevronUp,
+  Sparkles,
+  Search,
+  Brain,
+  Pen,
+} from "lucide-react";
+import { useState } from "react";
+import { Badge } from "@/components/ui/badge";
+
+export interface TokenUsageRecord {
+  node_name: string;
+  input_tokens: number;
+  output_tokens: number;
+  model: string;
+}
+
+interface TokenUsageDisplayProps {
+  tokenRecords: TokenUsageRecord[];
+  isLoading: boolean;
+}
+
+const MODEL_PRICING: Record<string, { input: number; output: number }> = {
+  "gemini-2.0-flash": { input: 0.075, output: 0.3 },
+  "gemini-2.5-flash": { input: 0.075, output: 0.3 },
+  "gemini-2.5-flash-preview-04-17": { input: 0.075, output: 0.3 },
+  "gemini-2.5-pro": { input: 1.25, output: 5.0 },
+  "gemini-2.5-pro-preview-05-06": { input: 1.25, output: 5.0 },
+  default: { input: 0.075, output: 0.3 },
+};
+
+const getNodeIcon = (nodeName: string) => {
+  if (nodeName.toLowerCase().includes("generate")) {
+    return <Sparkles className="h-3 w-3" />;
+  } else if (nodeName.toLowerCase().includes("research")) {
+    return <Search className="h-3 w-3" />;
+  } else if (nodeName.toLowerCase().includes("reflection")) {
+    return <Brain className="h-3 w-3" />;
+  } else if (nodeName.toLowerCase().includes("finalize")) {
+    return <Pen className="h-3 w-3" />;
+  }
+  return <Coins className="h-3 w-3" />;
+};
+
+const calculateCost = (
+  inputTokens: number,
+  outputTokens: number,
+  model: string
+): number => {
+  const pricing = MODEL_PRICING[model] || MODEL_PRICING.default;
+  const inputCost = (inputTokens / 1000000) * pricing.input;
+  const outputCost = (outputTokens / 1000000) * pricing.output;
+  return inputCost + outputCost;
+};
+
+export function TokenUsageDisplay({
+  tokenRecords,
+  isLoading,
+}: TokenUsageDisplayProps) {
+  const [isCollapsed, setIsCollapsed] = useState(false);
+
+  const totals = tokenRecords.reduce(
+    (acc, record) => {
+      acc.inputTokens += record.input_tokens;
+      acc.outputTokens += record.output_tokens;
+      acc.totalTokens += record.input_tokens + record.output_tokens;
+      acc.cost += calculateCost(
+        record.input_tokens,
+        record.output_tokens,
+        record.model
+      );
+      return acc;
+    },
+    { inputTokens: 0, outputTokens: 0, totalTokens: 0, cost: 0 }
+  );
+
+  if (tokenRecords.length === 0 && !isLoading) {
+    return null;
+  }
+
+  return (
+    <Card className="border-neutral-700 bg-neutral-800">
+      <CardHeader className="px-3 py-2">
+        <CardDescription className="flex items-center text-xs">
+          <div
+            className="flex w-full cursor-pointer items-center gap-2 text-neutral-300"
+            onClick={() => setIsCollapsed(!isCollapsed)}
+          >
+            <Coins className="h-4 w-4" />
+            <span className="font-medium">Token Usage</span>
+            {totals.totalTokens > 0 && (
+              <Badge variant="secondary" className="text-xs">
+                {totals.totalTokens.toLocaleString()} tokens
+              </Badge>
+            )}
+            {totals.cost > 0 && (
+              <Badge variant="outline" className="text-xs">
+                ~${totals.cost.toFixed(4)}
+              </Badge>
+            )}
+            {isCollapsed ? (
+              <ChevronDown className="ml-auto h-4 w-4" />
+            ) : (
+              <ChevronUp className="ml-auto h-4 w-4" />
+            )}
+          </div>
+        </CardDescription>
+      </CardHeader>
+      {!isCollapsed && (
+        <CardContent className="px-3 pb-3 pt-0">
+          {tokenRecords.length > 0 ? (
+            <ScrollArea className="max-h-48 space-y-2">
+              {tokenRecords.map((record, index) => {
+                const nodeCost = calculateCost(
+                  record.input_tokens,
+                  record.output_tokens,
+                  record.model
+                );
+                return (
+                  <div
+                    key={index}
+                    className="flex items-start gap-2 text-xs text-neutral-300"
+                  >
+                    {getNodeIcon(record.node_name)}
+                    <div className="flex-1">
+                      <div className="flex items-center gap-2">
+                        <span className="font-medium">
+                          {record.node_name
+                            .split("_")
+                            .map(
+                              (word) =>
+                                word.charAt(0).toUpperCase() + word.slice(1)
+                            )
+                            .join(" ")}
+                        </span>
+                        <Badge variant="outline" className="text-[10px]">
+                          {record.model.replace("gemini-", "")}
+                        </Badge>
+                      </div>
+                      <div className="flex items-center gap-3 text-neutral-400">
+                        <span>
+                          ↓ {record.input_tokens.toLocaleString()} in
+                        </span>
+                        <span>
+                          ↑ {record.output_tokens.toLocaleString()} out
+                        </span>
+                        {nodeCost > 0 && (
+                          <span>
+                            ${nodeCost.toFixed(4)}
+                          </span>
+                        )}
+                      </div>
+                    </div>
+                  </div>
+                );
+              })}
+
+              {tokenRecords.length > 1 && (
+                <div className="mt-2 border-t border-neutral-700 pt-2">
+                  <div className="flex items-center justify-between text-xs text-neutral-300">
+                    <div className="flex items-center gap-2">
+                      <TrendingUp className="h-3 w-3" />
+                      <span className="font-medium">Total</span>
+                    </div>
+                    <div className="flex items-center gap-3">
+                      <span>{totals.totalTokens.toLocaleString()} tokens</span>
+                      {totals.cost > 0 && (
+                        <span>
+                          ${totals.cost.toFixed(4)}
+                        </span>
+                      )}
+                    </div>
+                  </div>
+                </div>
+              )}
+            </ScrollArea>
+          ) : (
+            <div className="text-xs text-neutral-400">
+              <p>
+                {isLoading
+                  ? "Calculating token usage..."
+                  : "No token data available"}
+              </p>
+            </div>
+          )}
+        </CardContent>
+      )}
+    </Card>
+  );
+}