Skip to content

Commit 7eff0c0

Browse files
committed
Add .set_token_limits() method to automatically drop old turns when specified limits are reached
1 parent cb225ae commit 7eff0c0

File tree

1 file changed

+204
-5
lines changed

1 file changed

+204
-5
lines changed

chatlas/_chat.py

Lines changed: 204 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
self.provider = provider
8989
self._turns: list[Turn] = list(turns or [])
9090
self._tools: dict[str, Tool] = {}
91+
self.token_limits: Optional[tuple[int, int]] = None
9192
self._echo_options: EchoOptions = {
9293
"rich_markdown": {},
9394
"rich_console": {},
@@ -176,17 +177,87 @@ def system_prompt(self, value: str | None):
176177
if value is not None:
177178
self._turns.insert(0, Turn("system", value))
178179

179-
def tokens(self) -> list[tuple[int, int] | None]:
180+
def tokens(self) -> list[int]:
    """
    Get the tokens for each turn in the chat.

    Returns
    -------
    list[int]
        A list of token counts for each turn in the chat. Note that the
        1st turn includes the tokens count for the system prompt (if any).

    Raises
    ------
    ValueError
        If the chat's turns (i.e., `.get_turns()`) are not in an expected
        format. This may happen if the chat history is manually set (i.e.,
        `.set_turns()`). In this case, you can inspect the "raw" token
        values via the `.get_turns()` method (each turn has a `.tokens`
        attribute).
    """

    turns = self.get_turns(include_system_prompt=False)

    if len(turns) == 0:
        return []

    err_info = (
        "This can happen if the chat history is manually set (i.e., `.set_turns()`). "
        "Consider getting the 'raw' token values via the `.get_turns()` method "
        "(each turn has a `.tokens` attribute)."
    )

    # Sanity checks for the assumptions made to figure out user token counts:
    # history must be alternating user/assistant pairs, starting with "user".
    if len(turns) == 1:
        raise ValueError(
            "Expected at least two turns in the chat history. " + err_info
        )

    if len(turns) % 2 != 0:
        raise ValueError(
            "Expected an even number of turns in the chat history. " + err_info
        )

    if turns[0].role != "user":
        raise ValueError("Expected the first turn to have role='user'. " + err_info)

    if turns[1].role != "assistant":
        raise ValueError(
            "Expected the 2nd turn to have role='assistant'. " + err_info
        )

    if turns[1].tokens is None:
        raise ValueError(
            "Expected the 1st assistant turn to contain token counts. " + err_info
        )

    # Assistant turns carry a (input, output) token-count pair, where the
    # input count is cumulative (it covers everything sent so far). A user
    # turn's own size is therefore implied by the difference between
    # consecutive assistant turns' counts.
    tokens: list[int] = [
        turns[1].tokens[0],
        sum(turns[1].tokens),
    ]
    for i in range(1, len(turns) - 1, 2):
        ti = turns[i]
        tj = turns[i + 2]
        if ti.role != "assistant" or tj.role != "assistant":
            # NOTE: trailing space so the message doesn't run into err_info
            raise ValueError(
                "Expected even turns to have role='assistant'. " + err_info
            )
        if ti.tokens is None or tj.tokens is None:
            raise ValueError(
                "Expected role='assistant' turns to contain token counts. "
                + err_info
            )
        tokens.extend(
            [
                # Implied token count for the user input
                tj.tokens[0] - sum(ti.tokens),
                # The token count for the assistant response
                tj.tokens[1],
            ]
        )

    return tokens
190261

191262
def token_count(
192263
self,
@@ -269,6 +340,121 @@ async def token_count_async(
269340
data_model=data_model,
270341
)
271342

343+
def set_token_limits(
    self,
    context_window: int,
    max_tokens: int,
):
    """
    Set a limit on the number of tokens that can be sent to the model.

    By default, the size of the chat history is unbounded -- it keeps
    growing as you submit more input. This can be wasteful if you don't
    need to keep the entire chat history around, and can also lead to
    errors if the chat history gets too large for the model to handle.

    This method allows you to set a limit to the number of tokens that can
    be sent to the model. If the limit is exceeded, the chat history will be
    truncated to fit within the limit (i.e., the oldest turns will be
    dropped).

    Note that many models publish a context window as well as a maximum
    output token limit. For example,

    <https://platform.openai.com/docs/models/gp#gpt-4o-realtime>
    <https://docs.anthropic.com/en/docs/about-claude/models#model-comparison-table>

    Also, since the context window is the maximum number of input + output
    tokens, the maximum number of tokens that can be sent to the model in a
    single request is `context_window - max_tokens`.

    Parameters
    ----------
    context_window
        The maximum number of tokens that can be sent to the model.
    max_tokens
        The maximum number of tokens that the model is allowed to generate
        in a single response.

    Raises
    ------
    ValueError
        If either value is not positive, or if `max_tokens` is not smaller
        than `context_window`.

    Note
    ----
    This method uses `.token_count()` to estimate the token count for new input
    before truncating the chat history. This is an estimate, so it may not be
    perfect. Moreover, any chat models based on `ChatOpenAI()` currently do not
    take the tool loop into account when estimating token counts. This means, if
    your input will trigger many tool calls, and/or the tool results are large,
    it's recommended to set a conservative limit on the `context_window`.

    Examples
    --------
    ```python
    from chatlas import ChatAnthropic

    chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
    chat.set_token_limits(200000, 8192)
    ```
    """
    # Reject nonsense limits up-front (e.g. negative values would otherwise
    # silently disable the history, or pass the comparison below).
    if context_window <= 0 or max_tokens <= 0:
        raise ValueError("`context_window` and `max_tokens` must be positive.")
    if max_tokens >= context_window:
        raise ValueError("`max_tokens` must be less than the `context_window`.")
    self.token_limits = (context_window, max_tokens)
400+
401+
def _maybe_drop_turns(
402+
self,
403+
*args: Content | str,
404+
data_model: Optional[type[BaseModel]] = None,
405+
):
406+
"""
407+
Drop turns from the chat history if they exceed the token limits.
408+
"""
409+
410+
# Do nothing if token limits are not set
411+
if self.token_limits is None:
412+
return None
413+
414+
turns = self.get_turns(include_system_prompt=False)
415+
416+
# Do nothing if this is the first turn
417+
if len(turns) == 0:
418+
return None
419+
420+
last_turn = turns[-1]
421+
422+
# Sanity checks (i.e., when about to submit new input, the last turn should
423+
# be from the assistant and should contain token counts)
424+
if last_turn.role != "assistant":
425+
raise ValueError(
426+
"Expected the last turn must be from the assistant. Please report this issue."
427+
)
428+
429+
if last_turn.tokens is None:
430+
raise ValueError(
431+
"Can't impose token limits since assistant turns contain token counts. "
432+
"Please report this issue and consider setting `.token_limits` to `None`."
433+
)
434+
435+
context_window, max_tokens = self.token_limits
436+
max_input_size = context_window - max_tokens
437+
438+
# Estimate the token count for the (new) user turn
439+
input_tokens = self.token_count(*args, data_model=data_model)
440+
441+
# Do nothing if current history size plus input size is within the limit
442+
remaining_tokens = max_input_size - input_tokens
443+
if sum(last_turn.tokens) < remaining_tokens:
444+
return self
445+
446+
tokens = self.tokens()
447+
448+
# Drop turns until they (plus the new input) fit within the token limits
449+
# TODO: we also need to account for the fact that dropping part of a tool loop is problematic
450+
while sum(tokens) >= remaining_tokens:
451+
del turns[2:]
452+
del tokens[2:]
453+
454+
self.set_turns(turns)
455+
456+
return None
457+
272458
def app(
273459
self,
274460
*,
@@ -419,6 +605,8 @@ def chat(
419605
A (consumed) response from the chat. Apply `str()` to this object to
420606
get the text content of the response.
421607
"""
608+
self._maybe_drop_turns(*args)
609+
422610
turn = user_turn(*args)
423611

424612
display = self._markdown_display(echo=echo)
@@ -469,6 +657,9 @@ async def chat_async(
469657
A (consumed) response from the chat. Apply `str()` to this object to
470658
get the text content of the response.
471659
"""
660+
# TODO: async version?
661+
self._maybe_drop_turns(*args)
662+
472663
turn = user_turn(*args)
473664

474665
display = self._markdown_display(echo=echo)
@@ -515,6 +706,8 @@ def stream(
515706
An (unconsumed) response from the chat. Iterate over this object to
516707
consume the response.
517708
"""
709+
self._maybe_drop_turns(*args)
710+
518711
turn = user_turn(*args)
519712

520713
display = self._markdown_display(echo=echo)
@@ -560,6 +753,9 @@ async def stream_async(
560753
An (unconsumed) response from the chat. Iterate over this object to
561754
consume the response.
562755
"""
756+
# TODO: async version?
757+
self._maybe_drop_turns(*args)
758+
563759
turn = user_turn(*args)
564760

565761
display = self._markdown_display(echo=echo)
@@ -603,6 +799,7 @@ def extract_data(
603799
dict[str, Any]
604800
The extracted data.
605801
"""
802+
self._maybe_drop_turns(*args, data_model=data_model)
606803

607804
display = self._markdown_display(echo=echo)
608805

@@ -663,6 +860,8 @@ async def extract_data_async(
663860
dict[str, Any]
664861
The extracted data.
665862
"""
863+
# TODO: async version?
864+
self._maybe_drop_turns(*args, data_model=data_model)
666865

667866
display = self._markdown_display(echo=echo)
668867

0 commit comments

Comments
 (0)