Skip to content

Commit c006481

Browse files
authored
feat: Add new kwargs_chat attribute for persistent chat input parameters (#212)
* feat: Add new Chat(kwargs={}) parameter for persistant chat input parameters * Improve naming, docs, and tests * Fix test
1 parent 97b0b48 commit c006481

File tree

3 files changed

+33
-69
lines changed

3 files changed

+33
-69
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222
### Changes
2323

2424
* `ChatOpenAI()` (and `ChatAzureOpenAI()`) move from OpenAI's Completions API to [Responses API](https://platform.openai.com/docs/api-reference/responses). If this happens to break behavior, change `ChatOpenAI()` -> `ChatOpenAICompletions()` (or `ChatAzureOpenAI()` -> `ChatAzureOpenAICompletions()`). (#192)
25+
* The `.set_model_params()` method no longer accepts `kwargs`. Instead, use the new `chat.kwargs_chat` attribute to set chat input parameters that persist across the chat session. (#212)
2526
* `Provider` implementations now require an additional `.value_tokens()` method. Previously, it was assumed that token info was logged and attached to the `Turn` as part of the `.value_turn()` method. The logging and attaching is now handled automatically. (#194)
2627

2728
### Improvements

chatlas/_chat.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(
108108
self,
109109
provider: Provider,
110110
system_prompt: Optional[str] = None,
111+
kwargs_chat: Optional[SubmitInputArgsT] = None,
111112
):
112113
"""
113114
Create a new chat object.
@@ -118,10 +119,17 @@ def __init__(
118119
A [](`~chatlas.Provider`) object.
119120
system_prompt
120121
A system prompt to set the behavior of the assistant.
122+
kwargs_chat
123+
Additional arguments to pass to the provider when submitting input.
124+
These arguments persist across all chat interactions and will be
125+
merged with any kwargs passed to individual methods like `chat()` or
126+
`stream()`. They also take precedence over any parameters set via
127+
`set_model_params()`.
121128
"""
122129
self.provider = provider
123130
self._turns: list[Turn] = []
124131
self.system_prompt = system_prompt
132+
self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {}
125133

126134
self._tools: dict[str, Tool] = {}
127135
self._on_tool_request_callbacks = CallbackManager()
@@ -136,7 +144,6 @@ def __init__(
136144

137145
# Chat input parameters from `set_model_params()`
138146
self._standard_model_params: StandardModelParams = {}
139-
self._submit_input_kwargs: Optional[SubmitInputArgsT] = None
140147

141148
def list_models(self) -> list[ModelInfo]:
142149
"""
@@ -1454,7 +1461,6 @@ def set_model_params(
14541461
max_tokens: int | None | MISSING_TYPE = MISSING,
14551462
log_probs: bool | None | MISSING_TYPE = MISSING,
14561463
stop_sequences: list[str] | None | MISSING_TYPE = MISSING,
1457-
kwargs: SubmitInputArgsT | None | MISSING_TYPE = MISSING,
14581464
):
14591465
"""
14601466
Set common model parameters for the chat.
@@ -1488,10 +1494,6 @@ def set_model_params(
14881494
Include the log probabilities in the output?
14891495
stop_sequences
14901496
A character vector of tokens to stop generation on.
1491-
kwargs
1492-
Additional keyword arguments to use when submitting input to the
1493-
model. When calling this method repeatedly with different parameters,
1494-
only the parameters from the last call will be used.
14951497
"""
14961498

14971499
params: StandardModelParams = {}
@@ -1557,13 +1559,6 @@ def set_model_params(
15571559
# Update the standard model parameters
15581560
self._standard_model_params.update(params)
15591561

1560-
# Update the submit input kwargs
1561-
if kwargs is None:
1562-
self._submit_input_kwargs = None
1563-
1564-
if is_present(kwargs):
1565-
self._submit_input_kwargs = kwargs
1566-
15671562
async def register_mcp_tools_http_stream_async(
15681563
self,
15691564
*,
@@ -2640,8 +2635,8 @@ def _collect_all_kwargs(
26402635
)
26412636

26422637
# Add any additional kwargs provided by the user
2643-
if self._submit_input_kwargs:
2644-
all_kwargs.update(self._submit_input_kwargs)
2638+
if self.kwargs_chat:
2639+
all_kwargs.update(self.kwargs_chat)
26452640

26462641
if kwargs:
26472642
all_kwargs.update(kwargs)

tests/test_set_model_params.py

Lines changed: 22 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Tests for the set_model_params() feature."""
22

33
import pytest
4-
from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAICompletions as ChatOpenAI
4+
5+
from chatlas import ChatAnthropic, ChatGoogle
6+
from chatlas import ChatOpenAICompletions as ChatOpenAI
57
from chatlas._provider import StandardModelParams
68
from chatlas._utils import MISSING
79

@@ -53,20 +55,6 @@ def test_set_model_params_none_reset():
5355
assert "top_p" in params # Should still be there
5456

5557

56-
def test_set_model_params_kwargs():
57-
"""Test setting provider-specific kwargs."""
58-
chat = ChatOpenAI()
59-
60-
chat.set_model_params(kwargs={"frequency_penalty": 0.1, "presence_penalty": 0.2})
61-
kwargs = getattr(chat, "_submit_input_kwargs", {})
62-
assert kwargs == {"frequency_penalty": 0.1, "presence_penalty": 0.2}
63-
64-
# Reset kwargs to None
65-
chat.set_model_params(kwargs=None)
66-
kwargs = getattr(chat, "_submit_input_kwargs", {})
67-
assert kwargs is None
68-
69-
7058
def test_set_model_params_all_parameters():
7159
"""Test setting all supported standard parameters."""
7260
chat = ChatOpenAI()
@@ -260,23 +248,6 @@ def test_set_model_params_updates_existing():
260248
assert params["top_p"] == 0.9 # New
261249

262250

263-
def test_set_model_params_kwargs_replacement():
264-
"""Test that kwargs are completely replaced, not merged."""
265-
chat = ChatOpenAI()
266-
267-
# Set initial kwargs
268-
chat.set_model_params(kwargs={"frequency_penalty": 0.1, "presence_penalty": 0.2})
269-
kwargs = getattr(chat, "_submit_input_kwargs", {})
270-
assert kwargs == {"frequency_penalty": 0.1, "presence_penalty": 0.2}
271-
272-
# Set new kwargs - should completely replace
273-
chat.set_model_params(kwargs={"seed": 42})
274-
kwargs = getattr(chat, "_submit_input_kwargs", {})
275-
assert kwargs == {"seed": 42}
276-
assert "frequency_penalty" not in kwargs
277-
assert "presence_penalty" not in kwargs
278-
279-
280251
def test_set_model_params_invalid_temperature():
281252
"""Test validation of temperature parameter ranges."""
282253
chat = ChatOpenAI()
@@ -320,7 +291,6 @@ def test_set_model_params_empty_call():
320291

321292
# Should not change anything
322293
assert chat._standard_model_params == {}
323-
assert chat._submit_input_kwargs is None
324294

325295

326296
def test_set_model_params_type_validation():
@@ -388,27 +358,6 @@ def test_set_model_params_reset_specific_param():
388358
assert params["top_p"] == 0.95
389359

390360

391-
def test_model_params_kwargs_priority():
392-
"""Test that submit_input_kwargs override model_params."""
393-
chat = ChatOpenAI()
394-
395-
# Set model parameters
396-
chat.set_model_params(temperature=0.1, max_tokens=100)
397-
398-
# Set submit input kwargs that override some model params
399-
chat.set_model_params(kwargs={"temperature": 0.5, "seed": 42})
400-
401-
# Verify that kwargs are set correctly
402-
kwargs = getattr(chat, "_submit_input_kwargs", {})
403-
params = getattr(chat, "_standard_model_params", {})
404-
assert kwargs["temperature"] == 0.5
405-
assert kwargs["seed"] == 42
406-
407-
# Model params should still be stored
408-
assert params["temperature"] == 0.1
409-
assert params["max_tokens"] == 100
410-
411-
412361
def test_is_present_function():
413362
"""Test the is_present helper function used in set_model_params."""
414363
from chatlas._chat import is_present
@@ -576,3 +525,22 @@ def test_parameter_validation_edge_cases():
576525

577526
assert params["max_tokens"] == 1000000
578527
assert params["seed"] == 999999999
528+
529+
530+
def test_chat_kwargs_with_model_params():
531+
"""Test that Chat kwargs work alongside set_model_params."""
532+
chat = ChatOpenAI()
533+
534+
# Set model parameters
535+
chat.set_model_params(temperature=0.7, max_tokens=100)
536+
537+
# Set persistent chat kwargs
538+
chat.kwargs_chat = {"frequency_penalty": 0.1, "presence_penalty": 0.2}
539+
540+
# Collect all kwargs for a chat call
541+
kwargs = chat._collect_all_kwargs({"presence_penalty": 0.3})
542+
543+
assert kwargs.get("temperature") == 0.7
544+
assert kwargs.get("max_tokens") == 100
545+
assert kwargs.get("frequency_penalty") == 0.1
546+
assert kwargs.get("presence_penalty") == 0.3 # Should override chat.chat_kwargs

0 commit comments

Comments
 (0)