
Commit c3777bc

feat(llm): add token usage tracking and statistics for LLM providers; resolve conflicts and update files under src/llm/providers/archived to stay consistent with the current main branch (#92)
* support token usage stats

Co-authored-by: Einsiedler <[email protected]>

1 parent aa60684 · commit c3777bc

File tree

5 files changed: +150 -1 lines changed

src/core/orchestrator.py
src/llm/provider_client_base.py
src/llm/providers/archived/gpt_openai_response_client.py
src/llm/providers/claude_anthropic_client.py
src/llm/providers/claude_openrouter_client.py

src/core/orchestrator.py

Lines changed: 22 additions & 0 deletions
@@ -391,6 +391,9 @@ async def run_sub_agent(
         # Start new sub-agent session
         self.task_log.start_sub_agent_session(sub_agent_name, task_description)

+        # Reset sub-agent usage stats for independent tracking
+        self.sub_agent_llm_client.reset_usage_stats()
+
         # Simplified initial user content (no file attachments)
         initial_user_content = [{"type": "text", "text": task_description}]
         message_history = [{"role": "user", "content": initial_user_content}]

@@ -667,6 +670,14 @@ async def run_sub_agent(
         ] = {"system_prompt": system_prompt, "message_history": message_history}  # type: ignore
         self.task_log.save()

+        # Record sub-agent cumulative usage
+        usage_log = self.sub_agent_llm_client.get_usage_log()
+        self.task_log.log_step(
+            "usage_calculation",
+            usage_log,
+            metadata={"session_id": self.task_log.current_sub_agent_session_id},
+        )
+
         self.task_log.end_sub_agent_session(sub_agent_name)
         self.task_log.log_step(
             "sub_agent_completed", f"Sub agent {sub_agent_name} completed", "info"

@@ -688,6 +699,9 @@ async def run_main_agent(
         if task_file_name:
             logger.debug(f"Associated File: {task_file_name}")

+        # Reset main agent usage stats for independent tracking
+        self.llm_client.reset_usage_stats()
+
         # 1. Process input
         initial_user_content, task_description = process_input(
             task_description, task_file_name

@@ -1103,6 +1117,14 @@ async def run_main_agent(
             "task_completed", f"Main agent task {task_id} completed successfully"
         )

+        # Record main agent cumulative usage
+        usage_log = self.llm_client.get_usage_log()
+        self.task_log.log_step(
+            "usage_calculation",
+            usage_log,
+            metadata={"session_id": "main_agent"},
+        )
+
         if "browsecomp-zh" in self.cfg.benchmark.name:
             return final_summary, final_summary
         else:
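
The orchestrator changes follow a reset-then-report pattern: the client's cumulative counters are zeroed when an agent session starts, and a single "usage_calculation" step is written to the task log when it ends, so main-agent and sub-agent totals stay independent. A minimal sketch of that lifecycle, assuming a client exposing reset_usage_stats() and get_usage_log() as added in provider_client_base.py below; the wrapper function and its parameter names are hypothetical, not code from this commit:

async def run_with_usage_tracking(llm_client, task_log, session_id, agent_coro):
    # Zero the cumulative counters so this session is tracked independently
    llm_client.reset_usage_stats()
    try:
        return await agent_coro
    finally:
        # One formatted usage line per session
        task_log.log_step(
            "usage_calculation",
            llm_client.get_usage_log(),
            metadata={"session_id": session_id},
        )

Note that the commit itself records the usage step on the normal completion path rather than in a finally block, so a run that raises before that point logs no usage step.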

src/llm/provider_client_base.py

Lines changed: 71 additions & 0 deletions
@@ -34,6 +34,11 @@ class LLMProviderClientBase(ABC):

     # post_init
     client: Any = dataclasses.field(init=False)
+    # Usage tracking - cumulative for each agent session
+    total_input_tokens: int = dataclasses.field(init=False, default=0)
+    total_input_cached_tokens: int = dataclasses.field(init=False, default=0)
+    total_output_tokens: int = dataclasses.field(init=False, default=0)
+    total_output_reasoning_tokens: int = dataclasses.field(init=False, default=0)

     def __post_init__(self):
         # Explicitly assign from cfg object

@@ -196,6 +201,21 @@ async def create_message(
             tool_definitions,
             keep_tool_result=keep_tool_result,
         )
+
+        # Accumulate usage for agent session
+        if response:
+            try:
+                usage = self._extract_usage_from_response(response)
+                if usage:
+                    self.total_input_tokens += usage.get("input_tokens", 0)
+                    self.total_input_cached_tokens += usage.get("cached_tokens", 0)
+                    self.total_output_tokens += usage.get("output_tokens", 0)
+                    self.total_output_reasoning_tokens += usage.get(
+                        "reasoning_tokens", 0
+                    )
+            except Exception as e:
+                logger.warning(f"Failed to accumulate usage: {e}")
+
         return response

     @staticmethod

@@ -315,3 +335,54 @@ def handle_max_turns_reached_summary_prompt(
         self, message_history: list[dict[str, Any]], summary_prompt: str
     ):
         raise NotImplementedError("must implement in subclass")
+
+    def _extract_usage_from_response(self, response):
+        """Default Extract usage - OpenAI Chat Completions format"""
+        if not hasattr(response, "usage"):
+            return {
+                "input_tokens": 0,
+                "cached_tokens": 0,
+                "output_tokens": 0,
+                "reasoning_tokens": 0,
+            }
+
+        usage = response.usage
+        prompt_tokens_details = getattr(usage, "prompt_tokens_details", {}) or {}
+        if hasattr(prompt_tokens_details, "to_dict"):
+            prompt_tokens_details = prompt_tokens_details.to_dict()
+        completion_tokens_details = (
+            getattr(usage, "completion_tokens_details", {}) or {}
+        )
+        if hasattr(completion_tokens_details, "to_dict"):
+            completion_tokens_details = completion_tokens_details.to_dict()
+
+        usage_dict = {
+            "input_tokens": getattr(usage, "prompt_tokens", 0),
+            "cached_tokens": prompt_tokens_details.get("cached_tokens", 0),
+            "output_tokens": getattr(usage, "completion_tokens", 0),
+            "reasoning_tokens": completion_tokens_details.get("reasoning_tokens", 0),
+        }
+
+        return usage_dict
+
+    def get_usage_log(self) -> str:
+        """Get cumulative usage for current agent session as formatted string"""
+        # Format: [Provider | Model] Total Input: X, Cache Input: Y, Output: Z, ...
+        provider_model = f"[{self.provider_class} | {self.model_name}]"
+        input_uncached = self.total_input_tokens - self.total_input_cached_tokens
+        output_response = self.total_output_tokens - self.total_output_reasoning_tokens
+        total_tokens = self.total_input_tokens + self.total_output_tokens
+
+        return (
+            f"Usage log: {provider_model}, "
+            f"Total Input: {self.total_input_tokens} (Cached: {self.total_input_cached_tokens}, Uncached: {input_uncached}), "
+            f"Total Output: {self.total_output_tokens} (Reasoning: {self.total_output_reasoning_tokens}, Response: {output_response}), "
+            f"Total Tokens: {total_tokens}"
+        )
+
+    def reset_usage_stats(self):
+        """Reset usage stats for new agent session"""
+        self.total_input_tokens = 0
+        self.total_input_cached_tokens = 0
+        self.total_output_tokens = 0
+        self.total_output_reasoning_tokens = 0
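
The base-class extractor targets the OpenAI Chat Completions usage shape, and the numbers in the formatted log line are simple derivations: uncached input = prompt_tokens - cached_tokens, plain response output = completion_tokens - reasoning_tokens, and the grand total is input plus output. A self-contained illustration of that arithmetic against a mocked usage object (a re-implementation for clarity only, not the class method itself; the token counts are made up):

from types import SimpleNamespace

def summarize_usage(usage) -> str:
    # Mirror the fields the base extractor reads from a Chat Completions response
    prompt_details = getattr(usage, "prompt_tokens_details", {}) or {}
    completion_details = getattr(usage, "completion_tokens_details", {}) or {}
    input_tokens = getattr(usage, "prompt_tokens", 0)
    cached = prompt_details.get("cached_tokens", 0)
    output_tokens = getattr(usage, "completion_tokens", 0)
    reasoning = completion_details.get("reasoning_tokens", 0)
    return (
        f"Total Input: {input_tokens} (Cached: {cached}, Uncached: {input_tokens - cached}), "
        f"Total Output: {output_tokens} (Reasoning: {reasoning}, Response: {output_tokens - reasoning}), "
        f"Total Tokens: {input_tokens + output_tokens}"
    )

mock_usage = SimpleNamespace(
    prompt_tokens=1200,
    prompt_tokens_details={"cached_tokens": 800},
    completion_tokens=300,
    completion_tokens_details={"reasoning_tokens": 120},
)
print(summarize_usage(mock_usage))
# Total Input: 1200 (Cached: 800, Uncached: 400), Total Output: 300 (Reasoning: 120, Response: 180), Total Tokens: 1500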

src/llm/providers/archived/gpt_openai_response_client.py

Lines changed: 28 additions & 1 deletion
@@ -93,7 +93,7 @@ async def _create_message(
         response = self._convert_response_to_serializable(response)

         # Update token count
-        self._update_token_usage(response.get("usage", None))
+        # self._update_token_usage(response.get("usage", None))
         logger.debug(
             f"LLM Response API call status: {response.get('error', 'N/A')}"
         )

@@ -269,3 +269,30 @@ def _convert_response_to_serializable(self, response):
         }

         return serializable_response
+
+    def _extract_usage_from_response(self, response):
+        """Extract usage - OpenAI Responses API format"""
+        if not response or not response.get("usage"):
+            return {
+                "input_tokens": 0,
+                "cached_tokens": 0,
+                "output_tokens": 0,
+                "reasoning_tokens": 0,
+            }
+
+        usage = response.get("usage", {}) or {}
+        input_tokens_details = usage.get("input_tokens_details", {}) or {}
+        if hasattr(input_tokens_details, "to_dict"):
+            input_tokens_details = input_tokens_details.to_dict()
+        output_tokens_details = usage.get("output_tokens_details", {}) or {}
+        if hasattr(output_tokens_details, "to_dict"):
+            output_tokens_details = output_tokens_details.to_dict()
+
+        usage_dict = {
+            "input_tokens": usage.get("input_tokens", 0),
+            "cached_tokens": input_tokens_details.get("cached_tokens", 0),
+            "output_tokens": usage.get("output_tokens", 0),
+            "reasoning_tokens": output_tokens_details.get("reasoning_tokens", 0),
+        }
+
+        return usage_dict
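
This archived client works with responses already converted to plain dictionaries, and the OpenAI Responses API names its usage fields differently from Chat Completions (input_tokens/output_tokens with input_tokens_details and output_tokens_details), so the override normalizes them into the same four keys the accumulator expects. A small illustration with a hypothetical usage payload:

usage = {
    "input_tokens": 950,
    "input_tokens_details": {"cached_tokens": 600},
    "output_tokens": 410,
    "output_tokens_details": {"reasoning_tokens": 256},
}
normalized = {
    "input_tokens": usage.get("input_tokens", 0),
    "cached_tokens": usage.get("input_tokens_details", {}).get("cached_tokens", 0),
    "output_tokens": usage.get("output_tokens", 0),
    "reasoning_tokens": usage.get("output_tokens_details", {}).get("reasoning_tokens", 0),
}
print(normalized)
# {'input_tokens': 950, 'cached_tokens': 600, 'output_tokens': 410, 'reasoning_tokens': 256}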

src/llm/providers/claude_anthropic_client.py

Lines changed: 27 additions & 0 deletions
@@ -185,6 +185,33 @@ def handle_max_turns_reached_summary_prompt(self, message_history, summary_promp
         else:
             return summary_prompt

+    def _extract_usage_from_response(self, response):
+        """Extract usage - Anthropic format"""
+        if not hasattr(response, "usage"):
+            return {
+                "input_tokens": 0,
+                "cached_tokens": 0,
+                "output_tokens": 0,
+                "reasoning_tokens": 0,
+            }
+
+        usage = response.usage
+        cache_creation_input_tokens = getattr(usage, "cache_creation_input_tokens", 0)
+        cache_read_input_tokens = getattr(usage, "cache_read_input_tokens", 0)
+        input_tokens = getattr(usage, "input_tokens", 0)
+        output_tokens = getattr(usage, "output_tokens", 0)
+
+        usage_dict = {
+            "input_tokens": cache_creation_input_tokens
+            + cache_read_input_tokens
+            + input_tokens,
+            "cached_tokens": cache_read_input_tokens,
+            "output_tokens": output_tokens,
+            "reasoning_tokens": 0,
+        }
+
+        return usage_dict
+
     def _apply_cache_control(self, messages):
         """Apply cache control to the last user message and system message (if applicable)"""
         cached_messages = []
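
Anthropic reports prompt-cache writes and reads separately from input_tokens (which covers only the uncached portion of the prompt in Anthropic's usage schema), so the override sums cache_creation_input_tokens, cache_read_input_tokens, and input_tokens to keep the normalized input figure comparable with the other providers, maps cache reads to cached_tokens, and leaves reasoning_tokens at 0 since the API does not break those out. A worked example with made-up numbers:

from types import SimpleNamespace

anthropic_usage = SimpleNamespace(
    input_tokens=150,                  # uncached prompt tokens
    cache_creation_input_tokens=2000,  # tokens written to the prompt cache
    cache_read_input_tokens=1800,      # tokens served from the prompt cache
    output_tokens=420,
)
total_input = (
    anthropic_usage.cache_creation_input_tokens
    + anthropic_usage.cache_read_input_tokens
    + anthropic_usage.input_tokens
)
print(total_input, anthropic_usage.cache_read_input_tokens, anthropic_usage.output_tokens)
# 3950 1800 420  ->  input_tokens / cached_tokens / output_tokens in the normalized dict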

src/llm/providers/claude_openrouter_client.py

Lines changed: 2 additions & 0 deletions
@@ -133,6 +133,8 @@ async def _create_message(
         if self.repetition_penalty != 1.0:
             extra_body["repetition_penalty"] = self.repetition_penalty

+        extra_body["usage"] = {"include": True}
+
         params = {
             "model": self.model_name,
             "temperature": temperature,
