@@ -88,6 +88,7 @@ def __init__(
8888 self .provider = provider
8989 self ._turns : list [Turn ] = list (turns or [])
9090 self ._tools : dict [str , Tool ] = {}
91+ self .token_limits : Optional [tuple [int , int ]] = None
9192 self ._echo_options : EchoOptions = {
9293 "rich_markdown" : {},
9394 "rich_console" : {},
@@ -176,17 +177,87 @@ def system_prompt(self, value: str | None):
176177 if value is not None :
177178 self ._turns .insert (0 , Turn ("system" , value ))
178179
def tokens(self) -> list[int]:
    """
    Get the tokens for each turn in the chat.

    Returns
    -------
    list[int]
        A list of token counts for each turn in the chat. Note that the
        1st turn includes the token count for the system prompt (if any).

    Raises
    ------
    ValueError
        If the chat's turns (i.e., `.get_turns()`) are not in an expected
        format. This may happen if the chat history is manually set (i.e.,
        `.set_turns()`). In this case, you can inspect the "raw" token
        values via the `.get_turns()` method (each turn has a `.tokens`
        attribute).
    """

    turns = self.get_turns(include_system_prompt=False)

    if len(turns) == 0:
        return []

    err_info = (
        "This can happen if the chat history is manually set (i.e., `.set_turns()`). "
        "Consider getting the 'raw' token values via the `.get_turns()` method "
        "(each turn has a `.tokens` attribute)."
    )

    # Sanity checks for the assumptions made to figure out user token counts
    if len(turns) == 1:
        raise ValueError(
            "Expected at least two turns in the chat history. " + err_info
        )

    if len(turns) % 2 != 0:
        raise ValueError(
            "Expected an even number of turns in the chat history. " + err_info
        )

    if turns[0].role != "user":
        raise ValueError("Expected the first turn to have role='user'. " + err_info)

    if turns[1].role != "assistant":
        raise ValueError(
            "Expected the 2nd turn to have role='assistant'. " + err_info
        )

    if turns[1].tokens is None:
        raise ValueError(
            "Expected the 1st assistant turn to contain token counts. " + err_info
        )

    # An assistant turn's `.tokens` is an (input, output) pair of counts for
    # the request that produced it.
    tokens: list[int] = [
        # The 1st user turn's count is the input size of the 1st request
        # (which also includes the system prompt, if any).
        turns[1].tokens[0],
        # The 1st assistant turn's count is the *output* size of the 1st
        # request. (Using input + output here would double-count the user
        # turn and break the invariant -- relied on below -- that the
        # running sum of this list equals the cumulative context size.)
        turns[1].tokens[1],
    ]
    for i in range(1, len(turns) - 1, 2):
        ti = turns[i]
        tj = turns[i + 2]
        if ti.role != "assistant" or tj.role != "assistant":
            raise ValueError(
                "Expected even turns to have role='assistant'. " + err_info
            )
        if ti.tokens is None or tj.tokens is None:
            raise ValueError(
                "Expected role='assistant' turns to contain token counts. "
                + err_info
            )
        tokens.extend(
            [
                # Implied token count for the user input: the input size of
                # request j minus everything already in context after the
                # previous assistant turn (i.e., sum(ti.tokens)).
                tj.tokens[0] - sum(ti.tokens),
                # The token count for the assistant response
                tj.tokens[1],
            ]
        )

    return tokens
190261
191262 def token_count (
192263 self ,
@@ -269,6 +340,121 @@ async def token_count_async(
269340 data_model = data_model ,
270341 )
271342
def set_token_limits(
    self,
    context_window: int,
    max_tokens: int,
):
    """
    Set a limit on the number of tokens that can be sent to the model.

    By default, the size of the chat history is unbounded -- it keeps
    growing as you submit more input. This can be wasteful if you don't
    need to keep the entire chat history around, and can also lead to
    errors if the chat history gets too large for the model to handle.

    This method allows you to set a limit to the number of tokens that can
    be sent to the model. If the limit is exceeded, the chat history will be
    truncated to fit within the limit (i.e., the oldest turns will be
    dropped).

    Note that many models publish a context window as well as a maximum
    output token limit. For example,

    <https://platform.openai.com/docs/models>
    <https://docs.anthropic.com/en/docs/about-claude/models#model-comparison-table>

    Also, since the context window is the maximum number of input + output
    tokens, the maximum number of tokens that can be sent to the model in a
    single request is `context_window - max_tokens`.

    Parameters
    ----------
    context_window
        The maximum number of tokens that can be sent to the model.
    max_tokens
        The maximum number of tokens that the model is allowed to generate
        in a single response.

    Raises
    ------
    ValueError
        If `max_tokens` is not strictly less than `context_window`.

    Note
    ----
    This method uses `.token_count()` to estimate the token count for new input
    before truncating the chat history. This is an estimate, so it may not be
    perfect. Moreover, any chat models based on `ChatOpenAI()` currently do not
    take the tool loop into account when estimating token counts. This means, if
    your input will trigger many tool calls, and/or the tool results are large,
    it's recommended to set a conservative limit on the `context_window`.

    Examples
    --------
    ```python
    from chatlas import ChatOpenAI

    chat = ChatOpenAI(model="gpt-4o")
    chat.set_token_limits(128000, 16384)
    ```
    """
    if max_tokens >= context_window:
        raise ValueError("`max_tokens` must be less than the `context_window`.")
    self.token_limits = (context_window, max_tokens)
400+
401+ def _maybe_drop_turns (
402+ self ,
403+ * args : Content | str ,
404+ data_model : Optional [type [BaseModel ]] = None ,
405+ ):
406+ """
407+ Drop turns from the chat history if they exceed the token limits.
408+ """
409+
410+ # Do nothing if token limits are not set
411+ if self .token_limits is None :
412+ return None
413+
414+ turns = self .get_turns (include_system_prompt = False )
415+
416+ # Do nothing if this is the first turn
417+ if len (turns ) == 0 :
418+ return None
419+
420+ last_turn = turns [- 1 ]
421+
422+ # Sanity checks (i.e., when about to submit new input, the last turn should
423+ # be from the assistant and should contain token counts)
424+ if last_turn .role != "assistant" :
425+ raise ValueError (
426+ "Expected the last turn must be from the assistant. Please report this issue."
427+ )
428+
429+ if last_turn .tokens is None :
430+ raise ValueError (
431+ "Can't impose token limits since assistant turns contain token counts. "
432+ "Please report this issue and consider setting `.token_limits` to `None`."
433+ )
434+
435+ context_window , max_tokens = self .token_limits
436+ max_input_size = context_window - max_tokens
437+
438+ # Estimate the token count for the (new) user turn
439+ input_tokens = self .token_count (* args , data_model = data_model )
440+
441+ # Do nothing if current history size plus input size is within the limit
442+ remaining_tokens = max_input_size - input_tokens
443+ if sum (last_turn .tokens ) < remaining_tokens :
444+ return self
445+
446+ tokens = self .tokens ()
447+
448+ # Drop turns until they (plus the new input) fit within the token limits
449+ # TODO: we also need to account for the fact that dropping part of a tool loop is problematic
450+ while sum (tokens ) >= remaining_tokens :
451+ del turns [2 :]
452+ del tokens [2 :]
453+
454+ self .set_turns (turns )
455+
456+ return None
457+
272458 def app (
273459 self ,
274460 * ,
@@ -419,6 +605,8 @@ def chat(
419605 A (consumed) response from the chat. Apply `str()` to this object to
420606 get the text content of the response.
421607 """
608+ self ._maybe_drop_turns (* args )
609+
422610 turn = user_turn (* args )
423611
424612 display = self ._markdown_display (echo = echo )
@@ -469,6 +657,9 @@ async def chat_async(
469657 A (consumed) response from the chat. Apply `str()` to this object to
470658 get the text content of the response.
471659 """
660+ # TODO: async version?
661+ self ._maybe_drop_turns (* args )
662+
472663 turn = user_turn (* args )
473664
474665 display = self ._markdown_display (echo = echo )
@@ -515,6 +706,8 @@ def stream(
515706 An (unconsumed) response from the chat. Iterate over this object to
516707 consume the response.
517708 """
709+ self ._maybe_drop_turns (* args )
710+
518711 turn = user_turn (* args )
519712
520713 display = self ._markdown_display (echo = echo )
@@ -560,6 +753,9 @@ async def stream_async(
560753 An (unconsumed) response from the chat. Iterate over this object to
561754 consume the response.
562755 """
756+ # TODO: async version?
757+ self ._maybe_drop_turns (* args )
758+
563759 turn = user_turn (* args )
564760
565761 display = self ._markdown_display (echo = echo )
@@ -603,6 +799,7 @@ def extract_data(
603799 dict[str, Any]
604800 The extracted data.
605801 """
802+ self ._maybe_drop_turns (* args , data_model = data_model )
606803
607804 display = self ._markdown_display (echo = echo )
608805
@@ -663,6 +860,8 @@ async def extract_data_async(
663860 dict[str, Any]
664861 The extracted data.
665862 """
863+ # TODO: async version?
864+ self ._maybe_drop_turns (* args , data_model = data_model )
666865
667866 display = self ._markdown_display (echo = echo )
668867
0 commit comments