Skip to content

Commit f26069b

Browse files
committed
Merge commit 'fd90189c820a12d0d6a4c5266d20d36d6d158379'
2 parents 119b2d9 + fd90189 commit f26069b

File tree

8 files changed

+158
-15
lines changed

8 files changed

+158
-15
lines changed

src/core/orchestrator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,9 @@ async def run_sub_agent(
390390

391391
# Start new sub-agent session
392392
self.task_log.start_sub_agent_session(sub_agent_name, task_description)
393+
394+
# Reset sub-agent usage stats for independent tracking
395+
self.sub_agent_llm_client.reset_usage_stats()
393396

394397
# Simplified initial user content (no file attachments)
395398
initial_user_content = [{"type": "text", "text": task_description}]
@@ -667,6 +670,14 @@ async def run_sub_agent(
667670
] = {"system_prompt": system_prompt, "message_history": message_history} # type: ignore
668671
self.task_log.save()
669672

673+
# Record sub-agent cumulative usage
674+
usage_log = self.sub_agent_llm_client.get_usage_log()
675+
self.task_log.log_step(
676+
"usage_calculation",
677+
usage_log,
678+
metadata={"session_id": self.task_log.current_sub_agent_session_id},
679+
)
680+
670681
self.task_log.end_sub_agent_session(sub_agent_name)
671682
self.task_log.log_step(
672683
"sub_agent_completed", f"Sub agent {sub_agent_name} completed", "info"
@@ -688,6 +699,9 @@ async def run_main_agent(
688699
if task_file_name:
689700
logger.debug(f"Associated File: {task_file_name}")
690701

702+
# Reset main agent usage stats for independent tracking
703+
self.llm_client.reset_usage_stats()
704+
691705
# 1. Process input
692706
initial_user_content, task_description = process_input(
693707
task_description, task_file_name
@@ -1103,6 +1117,14 @@ async def run_main_agent(
11031117
"task_completed", f"Main agent task {task_id} completed successfully"
11041118
)
11051119

1120+
# Record main agent cumulative usage
1121+
usage_log = self.llm_client.get_usage_log()
1122+
self.task_log.log_step(
1123+
"usage_calculation",
1124+
usage_log,
1125+
metadata={"session_id": "main_agent"},
1126+
)
1127+
11061128
if "browsecomp-zh" in self.cfg.benchmark.name:
11071129
return final_summary, final_summary
11081130
else:

src/llm/provider_client_base.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ class LLMProviderClientBase(ABC):
3434

3535
# post_init
3636
client: Any = dataclasses.field(init=False)
37+
# Usage tracking - cumulative for each agent session
38+
total_input_tokens: int = dataclasses.field(init=False, default=0)
39+
total_input_cached_tokens: int = dataclasses.field(init=False, default=0)
40+
total_output_tokens: int = dataclasses.field(init=False, default=0)
41+
total_output_reasoning_tokens: int = dataclasses.field(init=False, default=0)
3742

3843
def __post_init__(self):
3944
# Explicitly assign from cfg object
@@ -196,6 +201,19 @@ async def create_message(
196201
tool_definitions,
197202
keep_tool_result=keep_tool_result,
198203
)
204+
205+
# Accumulate usage for agent session
206+
if response:
207+
try:
208+
usage = self._extract_usage_from_response(response)
209+
if usage:
210+
self.total_input_tokens += usage.get("input_tokens", 0)
211+
self.total_input_cached_tokens += usage.get("cached_tokens", 0)
212+
self.total_output_tokens += usage.get("output_tokens", 0)
213+
self.total_output_reasoning_tokens += usage.get("reasoning_tokens", 0)
214+
except Exception as e:
215+
logger.warning(f"Failed to accumulate usage: {e}")
216+
199217
return response
200218

201219
@staticmethod
@@ -315,3 +333,50 @@ def handle_max_turns_reached_summary_prompt(
315333
self, message_history: list[dict[str, Any]], summary_prompt: str
316334
):
317335
raise NotImplementedError("must implement in subclass")
336+
337+
def _extract_usage_from_response(self, response):
338+
"""Default usage extraction - OpenAI Chat Completions format"""
339+
if not hasattr(response, 'usage'):
340+
return {
341+
"input_tokens": 0,
342+
"cached_tokens": 0,
343+
"output_tokens": 0,
344+
"reasoning_tokens": 0
345+
}
346+
347+
usage = response.usage
348+
prompt_tokens_details = getattr(usage, 'prompt_tokens_details', {}) or {}
349+
if hasattr(prompt_tokens_details, "to_dict"):
350+
prompt_tokens_details = prompt_tokens_details.to_dict()
351+
completion_tokens_details = getattr(usage, 'completion_tokens_details', {}) or {}
352+
if hasattr(completion_tokens_details, "to_dict"):
353+
completion_tokens_details = completion_tokens_details.to_dict()
354+
355+
usage_dict = {
356+
"input_tokens": getattr(usage, 'prompt_tokens', 0),
357+
"cached_tokens": prompt_tokens_details.get('cached_tokens', 0),
358+
"output_tokens": getattr(usage, 'completion_tokens', 0),
359+
"reasoning_tokens": completion_tokens_details.get('reasoning_tokens', 0)
360+
}
361+
362+
return usage_dict
363+
364+
def get_usage_log(self) -> str:
365+
"""Get cumulative usage for current agent session as formatted string"""
366+
# Format: [Provider | Model] Total Input: X, Cache Input: Y, Output: Z, ...
367+
provider_model = f"[{self.provider_class} | {self.model_name}]"
368+
input_uncached = self.total_input_tokens - self.total_input_cached_tokens
369+
output_response = self.total_output_tokens - self.total_output_reasoning_tokens
370+
total_tokens = self.total_input_tokens + self.total_output_tokens
371+
372+
return (f"Usage log: {provider_model}, "
373+
f"Total Input: {self.total_input_tokens} (Cached: {self.total_input_cached_tokens}, Uncached: {input_uncached}), "
374+
f"Total Output: {self.total_output_tokens} (Reasoning: {self.total_output_reasoning_tokens}, Response: {output_response}), "
375+
f"Total Tokens: {total_tokens}")
376+
377+
def reset_usage_stats(self):
378+
"""Reset usage stats for new agent session"""
379+
self.total_input_tokens = 0
380+
self.total_input_cached_tokens = 0
381+
self.total_output_tokens = 0
382+
self.total_output_reasoning_tokens = 0

src/llm/providers/archived/claude_newapi_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ class ClaudeNewAPIClient(LLMProviderClientBase):
3535
def _create_client(self, config: DictConfig):
3636
if self.async_client:
3737
return AsyncOpenAI(
38-
api_key=config.env.newapi_api_key,
39-
base_url=config.env.newapi_base_url,
38+
api_key=self.cfg.llm.newapi_api_key,
39+
base_url=self.cfg.llm.newapi_base_url,
4040
)
4141
else:
4242
return OpenAI(
43-
api_key=config.env.newapi_api_key,
44-
base_url=config.env.newapi_base_url,
43+
api_key=self.cfg.llm.newapi_api_key,
44+
base_url=self.cfg.llm.newapi_base_url,
4545
)
4646

4747
# @retry(wait=wait_fixed(10), stop=stop_after_attempt(5))

src/llm/providers/archived/deepseek_newapi_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def _create_client(self, config: DictConfig):
3232
"""Create configured OpenAI client"""
3333
if self.async_client:
3434
return AsyncOpenAI(
35-
api_key=config.env.newapi_api_key,
36-
base_url=config.env.newapi_base_url,
35+
api_key=self.cfg.llm.newapi_api_key,
36+
base_url=self.cfg.llm.newapi_base_url,
3737
)
3838
else:
3939
return OpenAI(
40-
api_key=config.env.newapi_api_key,
41-
base_url=config.env.newapi_base_url,
40+
api_key=self.cfg.llm.newapi_api_key,
41+
base_url=self.cfg.llm.newapi_base_url,
4242
)
4343

4444
# @retry(wait=wait_fixed(10), stop=stop_after_attempt(5))

src/llm/providers/archived/gpt_openai_response_client.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ def _create_client(self, config: DictConfig):
2929
"""Create configured OpenAI client"""
3030
if self.async_client:
3131
return AsyncOpenAI(
32-
api_key=os.environ.get("OPENAI_API_KEY"),
32+
api_key=self.cfg.llm.openai_api_key,
3333
base_url=self.cfg.llm.openai_base_url,
3434
)
3535
else:
3636
return OpenAI(
37-
api_key=os.environ.get("OPENAI_API_KEY"),
37+
api_key=self.cfg.llm.openai_api_key,
3838
base_url=self.cfg.llm.openai_base_url,
3939
)
4040

@@ -93,7 +93,7 @@ async def _create_message(
9393
response = self._convert_response_to_serializable(response)
9494

9595
# Update token count
96-
self._update_token_usage(response.get("usage", None))
96+
# self._update_token_usage(response.get("usage", None))
9797
logger.debug(
9898
f"LLM Response API call status: {response.get('error', 'N/A')}"
9999
)
@@ -269,3 +269,30 @@ def _convert_response_to_serializable(self, response):
269269
}
270270

271271
return serializable_response
272+
273+
def _extract_usage_from_response(self, response):
274+
"""Extract usage - OpenAI Responses API format"""
275+
if not response or not response.get('usage'):
276+
return {
277+
"input_tokens": 0,
278+
"cached_tokens": 0,
279+
"output_tokens": 0,
280+
"reasoning_tokens": 0
281+
}
282+
283+
usage = response.get('usage', {}) or {}
284+
input_tokens_details = usage.get('input_tokens_details', {}) or {}
285+
if hasattr(input_tokens_details, "to_dict"):
286+
input_tokens_details = input_tokens_details.to_dict()
287+
output_tokens_details = usage.get('output_tokens_details', {}) or {}
288+
if hasattr(output_tokens_details, "to_dict"):
289+
output_tokens_details = output_tokens_details.to_dict()
290+
291+
usage_dict = {
292+
"input_tokens": usage.get('input_tokens', 0),
293+
"cached_tokens": input_tokens_details.get('cached_tokens', 0),
294+
"output_tokens": usage.get('output_tokens', 0),
295+
"reasoning_tokens": output_tokens_details.get('reasoning_tokens', 0)
296+
}
297+
298+
return usage_dict

src/llm/providers/archived/qwen_sglang_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ def _create_client(self, config: DictConfig):
2525
"""Create configured OpenAI client"""
2626
if self.async_client:
2727
return AsyncOpenAI(
28-
api_key=config.env.qwen_api_key,
29-
base_url=config.env.qwen_base_url,
28+
api_key=self.cfg.llm.qwen_api_key,
29+
base_url=self.cfg.llm.qwen_base_url,
3030
)
3131
else:
3232
return OpenAI(
33-
api_key=config.env.qwen_api_key,
34-
base_url=config.env.qwen_base_url,
33+
api_key=self.cfg.llm.qwen_api_key,
34+
base_url=self.cfg.llm.qwen_base_url,
3535
)
3636

3737
@retry(wait=wait_fixed(10), stop=stop_after_attempt(5))

src/llm/providers/claude_anthropic_client.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,31 @@ def handle_max_turns_reached_summary_prompt(self, message_history, summary_promp
185185
else:
186186
return summary_prompt
187187

188+
def _extract_usage_from_response(self, response):
189+
"""Extract usage - Anthropic format"""
190+
if not hasattr(response, 'usage'):
191+
return {
192+
"input_tokens": 0,
193+
"cached_tokens": 0,
194+
"output_tokens": 0,
195+
"reasoning_tokens": 0
196+
}
197+
198+
usage = response.usage
199+
cache_creation_input_tokens = getattr(usage, 'cache_creation_input_tokens', 0)
200+
cache_read_input_tokens = getattr(usage, 'cache_read_input_tokens', 0)
201+
input_tokens = getattr(usage, 'input_tokens', 0)
202+
output_tokens = getattr(usage, 'output_tokens', 0)
203+
204+
usage_dict = {
205+
"input_tokens": cache_creation_input_tokens + cache_read_input_tokens + input_tokens,
206+
"cached_tokens": cache_read_input_tokens,
207+
"output_tokens": output_tokens,
208+
"reasoning_tokens": 0
209+
}
210+
211+
return usage_dict
212+
188213
def _apply_cache_control(self, messages):
189214
"""Apply cache control to the last user message and system message (if applicable)"""
190215
cached_messages = []

src/llm/providers/claude_openrouter_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ async def _create_message(
133133
if self.repetition_penalty != 1.0:
134134
extra_body["repetition_penalty"] = self.repetition_penalty
135135

136+
extra_body["usage"] = {
137+
"include": True
138+
}
139+
136140
params = {
137141
"model": self.model_name,
138142
"temperature": temperature,

0 commit comments

Comments
 (0)