@@ -34,6 +34,11 @@ class LLMProviderClientBase(ABC):
3434
3535 # post_init
3636 client : Any = dataclasses .field (init = False )
37+ # Usage tracking - cumulative for each agent session
38+ total_input_tokens : int = dataclasses .field (init = False , default = 0 )
39+ total_input_cached_tokens : int = dataclasses .field (init = False , default = 0 )
40+ total_output_tokens : int = dataclasses .field (init = False , default = 0 )
41+ total_output_reasoning_tokens : int = dataclasses .field (init = False , default = 0 )
3742
3843 def __post_init__ (self ):
3944 # Explicitly assign from cfg object
@@ -195,6 +200,19 @@ async def create_message(
195200 tool_definitions ,
196201 keep_tool_result = keep_tool_result ,
197202 )
203+
204+ # Accumulate usage for agent session
205+ if response :
206+ try :
207+ usage = self ._extract_usage_from_response (response )
208+ if usage :
209+ self .total_input_tokens += usage .get ("input_tokens" , 0 )
210+ self .total_input_cached_tokens += usage .get ("cached_tokens" , 0 )
211+ self .total_output_tokens += usage .get ("output_tokens" , 0 )
212+ self .total_output_reasoning_tokens += usage .get ("reasoning_tokens" , 0 )
213+ except Exception as e :
214+ logger .warning (f"Failed to accumulate usage: { e } " )
215+
198216 return response
199217
200218 @staticmethod
@@ -314,3 +332,50 @@ def handle_max_turns_reached_summary_prompt(
314332 self , message_history : list [dict [str , Any ]], summary_prompt : str
315333 ):
316334 raise NotImplementedError ("must implement in subclass" )
335+
336+ def _extract_usage_from_response (self , response ):
337+ """Default Extract usage - OpenAI Chat Completions format"""
338+ if not hasattr (response , 'usage' ):
339+ return {
340+ "input_tokens" : 0 ,
341+ "cached_tokens" : 0 ,
342+ "output_tokens" : 0 ,
343+ "reasoning_tokens" : 0
344+ }
345+
346+ usage = response .usage
347+ prompt_tokens_details = getattr (usage , 'prompt_tokens_details' , {}) or {}
348+ if hasattr (prompt_tokens_details , "to_dict" ):
349+ prompt_tokens_details = prompt_tokens_details .to_dict ()
350+ completion_tokens_details = getattr (usage , 'completion_tokens_details' , {}) or {}
351+ if hasattr (completion_tokens_details , "to_dict" ):
352+ completion_tokens_details = completion_tokens_details .to_dict ()
353+
354+ usage_dict = {
355+ "input_tokens" : getattr (usage , 'prompt_tokens' , 0 ),
356+ "cached_tokens" : prompt_tokens_details .get ('cached_tokens' , 0 ),
357+ "output_tokens" : getattr (usage , 'completion_tokens' , 0 ),
358+ "reasoning_tokens" : completion_tokens_details .get ('reasoning_tokens' , 0 )
359+ }
360+
361+ return usage_dict
362+
363+ def get_usage_log (self ) -> str :
364+ """Get cumulative usage for current agent session as formatted string"""
365+ # Format: [Provider | Model] Total Input: X, Cache Input: Y, Output: Z, ...
366+ provider_model = f"[{ self .provider_class } | { self .model_name } ]"
367+ input_uncached = self .total_input_tokens - self .total_input_cached_tokens
368+ output_response = self .total_output_tokens - self .total_output_reasoning_tokens
369+ total_tokens = self .total_input_tokens + self .total_output_tokens
370+
371+ return (f"Usage log: { provider_model } , "
372+ f"Total Input: { self .total_input_tokens } (Cached: { self .total_input_cached_tokens } , Uncached: { input_uncached } ), "
373+ f"Total Output: { self .total_output_tokens } (Reasoning: { self .total_output_reasoning_tokens } , Response: { output_response } ), "
374+ f"Total Tokens: { total_tokens } " )
375+
376+ def reset_usage_stats (self ):
377+ """Reset usage stats for new agent session"""
378+ self .total_input_tokens = 0
379+ self .total_input_cached_tokens = 0
380+ self .total_output_tokens = 0
381+ self .total_output_reasoning_tokens = 0
0 commit comments