@@ -34,6 +34,11 @@ class LLMProviderClientBase(ABC):
3434
3535 # post_init
3636 client : Any = dataclasses .field (init = False )
# Usage tracking - cumulative token counters for the current agent session.
# All are post-init fields (excluded from the generated __init__) and start
# at zero; reset_usage_stats() returns them to this state.
total_input_tokens: int = dataclasses.field(init=False, default=0)
total_input_cached_tokens: int = dataclasses.field(init=False, default=0)
total_output_tokens: int = dataclasses.field(init=False, default=0)
total_output_reasoning_tokens: int = dataclasses.field(init=False, default=0)
3742
3843 def __post_init__ (self ):
3944 # Explicitly assign from cfg object
@@ -196,6 +201,19 @@ async def create_message(
196201 tool_definitions ,
197202 keep_tool_result = keep_tool_result ,
198203 )
204+
205+ # Accumulate usage for agent session
206+ if response :
207+ try :
208+ usage = self ._extract_usage_from_response (response )
209+ if usage :
210+ self .total_input_tokens += usage .get ("input_tokens" , 0 )
211+ self .total_input_cached_tokens += usage .get ("cached_tokens" , 0 )
212+ self .total_output_tokens += usage .get ("output_tokens" , 0 )
213+ self .total_output_reasoning_tokens += usage .get ("reasoning_tokens" , 0 )
214+ except Exception as e :
215+ logger .warning (f"Failed to accumulate usage: { e } " )
216+
199217 return response
200218
201219 @staticmethod
@@ -315,3 +333,50 @@ def handle_max_turns_reached_summary_prompt(
315333 self , message_history : list [dict [str , Any ]], summary_prompt : str
316334 ):
317335 raise NotImplementedError ("must implement in subclass" )
336+
337+ def _extract_usage_from_response (self , response ):
338+ """Default Extract usage - OpenAI Chat Completions format"""
339+ if not hasattr (response , 'usage' ):
340+ return {
341+ "input_tokens" : 0 ,
342+ "cached_tokens" : 0 ,
343+ "output_tokens" : 0 ,
344+ "reasoning_tokens" : 0
345+ }
346+
347+ usage = response .usage
348+ prompt_tokens_details = getattr (usage , 'prompt_tokens_details' , {}) or {}
349+ if hasattr (prompt_tokens_details , "to_dict" ):
350+ prompt_tokens_details = prompt_tokens_details .to_dict ()
351+ completion_tokens_details = getattr (usage , 'completion_tokens_details' , {}) or {}
352+ if hasattr (completion_tokens_details , "to_dict" ):
353+ completion_tokens_details = completion_tokens_details .to_dict ()
354+
355+ usage_dict = {
356+ "input_tokens" : getattr (usage , 'prompt_tokens' , 0 ),
357+ "cached_tokens" : prompt_tokens_details .get ('cached_tokens' , 0 ),
358+ "output_tokens" : getattr (usage , 'completion_tokens' , 0 ),
359+ "reasoning_tokens" : completion_tokens_details .get ('reasoning_tokens' , 0 )
360+ }
361+
362+ return usage_dict
363+
def get_usage_log(self) -> str:
    """Return a one-line summary of cumulative token usage for this session.

    Format: ``Usage log: [Provider | Model], Total Input: X (Cached: ...,
    Uncached: ...), Total Output: Y (Reasoning: ..., Response: ...),
    Total Tokens: Z``.
    """
    inp = self.total_input_tokens
    cached = self.total_input_cached_tokens
    out = self.total_output_tokens
    reasoning = self.total_output_reasoning_tokens
    return (
        f"Usage log: [{self.provider_class} | {self.model_name}], "
        f"Total Input: {inp} (Cached: {cached}, Uncached: {inp - cached}), "
        f"Total Output: {out} (Reasoning: {reasoning}, Response: {out - reasoning}), "
        f"Total Tokens: {inp + out}"
    )
376+
def reset_usage_stats(self):
    """Zero all cumulative usage counters ahead of a new agent session."""
    for counter in (
        "total_input_tokens",
        "total_input_cached_tokens",
        "total_output_tokens",
        "total_output_reasoning_tokens",
    ):
        setattr(self, counter, 0)
0 commit comments