Skip to content

Commit fd90189

Browse files
support token usage stats
1 parent 3d49bd8 commit fd90189

9 files changed

+163
-20
lines changed

src/core/orchestrator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,9 @@ async def run_sub_agent(
386386

387387
# Start new sub-agent session
388388
self.task_log.start_sub_agent_session(sub_agent_name, task_description)
389+
390+
# Reset sub-agent usage stats for independent tracking
391+
self.sub_agent_llm_client.reset_usage_stats()
389392

390393
# Simplified initial user content (no file attachments)
391394
initial_user_content = [{"type": "text", "text": task_description}]
@@ -661,6 +664,14 @@ async def run_sub_agent(
661664
] = {"system_prompt": system_prompt, "message_history": message_history} # type: ignore
662665
self.task_log.save()
663666

667+
# Record sub-agent cumulative usage
668+
usage_log = self.sub_agent_llm_client.get_usage_log()
669+
self.task_log.log_step(
670+
"usage_calculation",
671+
usage_log,
672+
metadata={"session_id": self.task_log.current_sub_agent_session_id},
673+
)
674+
664675
self.task_log.end_sub_agent_session(sub_agent_name)
665676
self.task_log.log_step(
666677
"sub_agent_completed", f"Sub agent {sub_agent_name} completed", "info"
@@ -682,6 +693,9 @@ async def run_main_agent(
682693
if task_file_name:
683694
logger.debug(f"Associated File: {task_file_name}")
684695

696+
# Reset main agent usage stats for independent tracking
697+
self.llm_client.reset_usage_stats()
698+
685699
# 1. Process input
686700
initial_user_content, task_description = process_input(
687701
task_description, task_file_name
@@ -1089,6 +1103,14 @@ async def run_main_agent(
10891103
"task_completed", f"Main agent task {task_id} completed successfully"
10901104
)
10911105

1106+
# Record main agent cumulative usage
1107+
usage_log = self.llm_client.get_usage_log()
1108+
self.task_log.log_step(
1109+
"usage_calculation",
1110+
usage_log,
1111+
metadata={"session_id": "main_agent"},
1112+
)
1113+
10921114
if "browsecomp-zh" in self.cfg.benchmark.name:
10931115
return final_summary, final_summary
10941116
else:

src/llm/provider_client_base.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ class LLMProviderClientBase(ABC):
3434

3535
# post_init
3636
client: Any = dataclasses.field(init=False)
37+
# Usage tracking - cumulative for each agent session
38+
total_input_tokens: int = dataclasses.field(init=False, default=0)
39+
total_input_cached_tokens: int = dataclasses.field(init=False, default=0)
40+
total_output_tokens: int = dataclasses.field(init=False, default=0)
41+
total_output_reasoning_tokens: int = dataclasses.field(init=False, default=0)
3742

3843
def __post_init__(self):
3944
# Explicitly assign from cfg object
@@ -195,6 +200,19 @@ async def create_message(
195200
tool_definitions,
196201
keep_tool_result=keep_tool_result,
197202
)
203+
204+
# Accumulate usage for agent session
205+
if response:
206+
try:
207+
usage = self._extract_usage_from_response(response)
208+
if usage:
209+
self.total_input_tokens += usage.get("input_tokens", 0)
210+
self.total_input_cached_tokens += usage.get("cached_tokens", 0)
211+
self.total_output_tokens += usage.get("output_tokens", 0)
212+
self.total_output_reasoning_tokens += usage.get("reasoning_tokens", 0)
213+
except Exception as e:
214+
logger.warning(f"Failed to accumulate usage: {e}")
215+
198216
return response
199217

200218
@staticmethod
@@ -314,3 +332,50 @@ def handle_max_turns_reached_summary_prompt(
314332
self, message_history: list[dict[str, Any]], summary_prompt: str
315333
):
316334
raise NotImplementedError("must implement in subclass")
335+
336+
def _extract_usage_from_response(self, response):
337+
"""Default Extract usage - OpenAI Chat Completions format"""
338+
if not hasattr(response, 'usage'):
339+
return {
340+
"input_tokens": 0,
341+
"cached_tokens": 0,
342+
"output_tokens": 0,
343+
"reasoning_tokens": 0
344+
}
345+
346+
usage = response.usage
347+
prompt_tokens_details = getattr(usage, 'prompt_tokens_details', {}) or {}
348+
if hasattr(prompt_tokens_details, "to_dict"):
349+
prompt_tokens_details = prompt_tokens_details.to_dict()
350+
completion_tokens_details = getattr(usage, 'completion_tokens_details', {}) or {}
351+
if hasattr(completion_tokens_details, "to_dict"):
352+
completion_tokens_details = completion_tokens_details.to_dict()
353+
354+
usage_dict = {
355+
"input_tokens": getattr(usage, 'prompt_tokens', 0),
356+
"cached_tokens": prompt_tokens_details.get('cached_tokens', 0),
357+
"output_tokens": getattr(usage, 'completion_tokens', 0),
358+
"reasoning_tokens": completion_tokens_details.get('reasoning_tokens', 0)
359+
}
360+
361+
return usage_dict
362+
363+
def get_usage_log(self) -> str:
    """Return a human-readable summary of the cumulative token usage
    accumulated for the current agent session."""
    uncached_input = self.total_input_tokens - self.total_input_cached_tokens
    response_output = self.total_output_tokens - self.total_output_reasoning_tokens
    grand_total = self.total_input_tokens + self.total_output_tokens

    segments = [
        f"Usage log: [{self.provider_class} | {self.model_name}]",
        (f"Total Input: {self.total_input_tokens} "
         f"(Cached: {self.total_input_cached_tokens}, Uncached: {uncached_input})"),
        (f"Total Output: {self.total_output_tokens} "
         f"(Reasoning: {self.total_output_reasoning_tokens}, Response: {response_output})"),
        f"Total Tokens: {grand_total}",
    ]
    return ", ".join(segments)
375+
376+
def reset_usage_stats(self):
    """Zero every cumulative usage counter so a new agent session
    starts tracking from a clean slate."""
    self.total_input_tokens = self.total_input_cached_tokens = 0
    self.total_output_tokens = self.total_output_reasoning_tokens = 0

src/llm/providers/claude_anthropic_client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __post_init__(self):
2929

3030
def _create_client(self, config: DictConfig):
3131
"""Create Anthropic client"""
32-
api_key = config.env.anthropic_api_key
32+
api_key = self.cfg.llm.anthropic_api_key
3333

3434
if self.async_client:
3535
return AsyncAnthropic(
@@ -183,6 +183,31 @@ def handle_max_turns_reached_summary_prompt(self, message_history, summary_promp
183183
else:
184184
return summary_prompt
185185

186+
def _extract_usage_from_response(self, response):
187+
"""Extract usage - Anthropic format"""
188+
if not hasattr(response, 'usage'):
189+
return {
190+
"input_tokens": 0,
191+
"cached_tokens": 0,
192+
"output_tokens": 0,
193+
"reasoning_tokens": 0
194+
}
195+
196+
usage = response.usage
197+
cache_creation_input_tokens = getattr(usage, 'cache_creation_input_tokens', 0)
198+
cache_read_input_tokens = getattr(usage, 'cache_read_input_tokens', 0)
199+
input_tokens = getattr(usage, 'input_tokens', 0)
200+
output_tokens = getattr(usage, 'output_tokens', 0)
201+
202+
usage_dict = {
203+
"input_tokens": cache_creation_input_tokens + cache_read_input_tokens + input_tokens,
204+
"cached_tokens": cache_read_input_tokens,
205+
"output_tokens": output_tokens,
206+
"reasoning_tokens": 0
207+
}
208+
209+
return usage_dict
210+
186211
def _apply_cache_control(self, messages):
187212
"""Apply cache control to the last user message and system message (if applicable)"""
188213
cached_messages = []

src/llm/providers/claude_newapi_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ class ClaudeNewAPIClient(LLMProviderClientBase):
3535
def _create_client(self, config: DictConfig):
3636
if self.async_client:
3737
return AsyncOpenAI(
38-
api_key=config.env.newapi_api_key,
39-
base_url=config.env.newapi_base_url,
38+
api_key=self.cfg.llm.newapi_api_key,
39+
base_url=self.cfg.llm.newapi_base_url,
4040
)
4141
else:
4242
return OpenAI(
43-
api_key=config.env.newapi_api_key,
44-
base_url=config.env.newapi_base_url,
43+
api_key=self.cfg.llm.newapi_api_key,
44+
base_url=self.cfg.llm.newapi_base_url,
4545
)
4646

4747
# @retry(wait=wait_fixed(10), stop=stop_after_attempt(5))

src/llm/providers/claude_openrouter_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ async def _create_message(
133133
if self.repetition_penalty != 1.0:
134134
extra_body["repetition_penalty"] = self.repetition_penalty
135135

136+
extra_body["usage"] = {
137+
"include": True
138+
}
139+
136140
params = {
137141
"model": self.model_name,
138142
"temperature": temperature,

src/llm/providers/deepseek_newapi_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def _create_client(self, config: DictConfig):
3232
"""Create configured OpenAI client"""
3333
if self.async_client:
3434
return AsyncOpenAI(
35-
api_key=config.env.newapi_api_key,
36-
base_url=config.env.newapi_base_url,
35+
api_key=self.cfg.llm.newapi_api_key,
36+
base_url=self.cfg.llm.newapi_base_url,
3737
)
3838
else:
3939
return OpenAI(
40-
api_key=config.env.newapi_api_key,
41-
base_url=config.env.newapi_base_url,
40+
api_key=self.cfg.llm.newapi_api_key,
41+
base_url=self.cfg.llm.newapi_base_url,
4242
)
4343

4444
# @retry(wait=wait_fixed(10), stop=stop_after_attempt(5))

src/llm/providers/gpt_openai_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ def _create_client(self, config: DictConfig):
2929
"""Create configured OpenAI client"""
3030
if self.async_client:
3131
return AsyncOpenAI(
32-
api_key=config.env.openai_api_key,
33-
base_url=config.env.openai_base_url,
32+
api_key=self.cfg.llm.openai_api_key,
33+
base_url=self.cfg.llm.openai_base_url,
3434
)
3535
else:
3636
return OpenAI(
37-
api_key=config.env.openai_api_key,
38-
base_url=config.env.openai_base_url,
37+
api_key=self.cfg.llm.openai_api_key,
38+
base_url=self.cfg.llm.openai_base_url,
3939
)
4040

4141
@retry(wait=wait_fixed(10), stop=stop_after_attempt(5))

src/llm/providers/gpt_openai_response_client.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ def _create_client(self, config: DictConfig):
2929
"""Create configured OpenAI client"""
3030
if self.async_client:
3131
return AsyncOpenAI(
32-
api_key=os.environ.get("OPENAI_API_KEY"),
32+
api_key=self.cfg.llm.openai_api_key,
3333
base_url=self.cfg.llm.openai_base_url,
3434
)
3535
else:
3636
return OpenAI(
37-
api_key=os.environ.get("OPENAI_API_KEY"),
37+
api_key=self.cfg.llm.openai_api_key,
3838
base_url=self.cfg.llm.openai_base_url,
3939
)
4040

@@ -93,7 +93,7 @@ async def _create_message(
9393
response = self._convert_response_to_serializable(response)
9494

9595
# Update token count
96-
self._update_token_usage(response.get("usage", None))
96+
# self._update_token_usage(response.get("usage", None))
9797
logger.debug(
9898
f"LLM Response API call status: {response.get('error', 'N/A')}"
9999
)
@@ -269,3 +269,30 @@ def _convert_response_to_serializable(self, response):
269269
}
270270

271271
return serializable_response
272+
273+
def _extract_usage_from_response(self, response):
274+
"""Extract usage - OpenAI Responses API format"""
275+
if not response or not response.get('usage'):
276+
return {
277+
"input_tokens": 0,
278+
"cached_tokens": 0,
279+
"output_tokens": 0,
280+
"reasoning_tokens": 0
281+
}
282+
283+
usage = response.get('usage', {}) or {}
284+
input_tokens_details = usage.get('input_tokens_details', {}) or {}
285+
if hasattr(input_tokens_details, "to_dict"):
286+
input_tokens_details = input_tokens_details.to_dict()
287+
output_tokens_details = usage.get('output_tokens_details', {}) or {}
288+
if hasattr(output_tokens_details, "to_dict"):
289+
output_tokens_details = output_tokens_details.to_dict()
290+
291+
usage_dict = {
292+
"input_tokens": usage.get('input_tokens', 0),
293+
"cached_tokens": input_tokens_details.get('cached_tokens', 0),
294+
"output_tokens": usage.get('output_tokens', 0),
295+
"reasoning_tokens": output_tokens_details.get('reasoning_tokens', 0)
296+
}
297+
298+
return usage_dict

src/llm/providers/qwen_sglang_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ def _create_client(self, config: DictConfig):
2525
"""Create configured OpenAI client"""
2626
if self.async_client:
2727
return AsyncOpenAI(
28-
api_key=config.env.qwen_api_key,
29-
base_url=config.env.qwen_base_url,
28+
api_key=self.cfg.llm.qwen_api_key,
29+
base_url=self.cfg.llm.qwen_base_url,
3030
)
3131
else:
3232
return OpenAI(
33-
api_key=config.env.qwen_api_key,
34-
base_url=config.env.qwen_base_url,
33+
api_key=self.cfg.llm.qwen_api_key,
34+
base_url=self.cfg.llm.qwen_base_url,
3535
)
3636

3737
@retry(wait=wait_fixed(10), stop=stop_after_attempt(5))

0 commit comments

Comments
 (0)