|
1 | 1 | """Tests for the set_model_params() feature.""" |
2 | 2 |
|
3 | 3 | import pytest |
4 | | -from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAICompletions as ChatOpenAI |
| 4 | + |
| 5 | +from chatlas import ChatAnthropic, ChatGoogle |
| 6 | +from chatlas import ChatOpenAICompletions as ChatOpenAI |
5 | 7 | from chatlas._provider import StandardModelParams |
6 | 8 | from chatlas._utils import MISSING |
7 | 9 |
|
@@ -53,20 +55,6 @@ def test_set_model_params_none_reset(): |
53 | 55 | assert "top_p" in params # Should still be there |
54 | 56 |
|
55 | 57 |
|
56 | | -def test_set_model_params_kwargs(): |
57 | | - """Test setting provider-specific kwargs.""" |
58 | | - chat = ChatOpenAI() |
59 | | - |
60 | | - chat.set_model_params(kwargs={"frequency_penalty": 0.1, "presence_penalty": 0.2}) |
61 | | - kwargs = getattr(chat, "_submit_input_kwargs", {}) |
62 | | - assert kwargs == {"frequency_penalty": 0.1, "presence_penalty": 0.2} |
63 | | - |
64 | | - # Reset kwargs to None |
65 | | - chat.set_model_params(kwargs=None) |
66 | | - kwargs = getattr(chat, "_submit_input_kwargs", {}) |
67 | | - assert kwargs is None |
68 | | - |
69 | | - |
70 | 58 | def test_set_model_params_all_parameters(): |
71 | 59 | """Test setting all supported standard parameters.""" |
72 | 60 | chat = ChatOpenAI() |
@@ -260,23 +248,6 @@ def test_set_model_params_updates_existing(): |
260 | 248 | assert params["top_p"] == 0.9 # New |
261 | 249 |
|
262 | 250 |
|
263 | | -def test_set_model_params_kwargs_replacement(): |
264 | | - """Test that kwargs are completely replaced, not merged.""" |
265 | | - chat = ChatOpenAI() |
266 | | - |
267 | | - # Set initial kwargs |
268 | | - chat.set_model_params(kwargs={"frequency_penalty": 0.1, "presence_penalty": 0.2}) |
269 | | - kwargs = getattr(chat, "_submit_input_kwargs", {}) |
270 | | - assert kwargs == {"frequency_penalty": 0.1, "presence_penalty": 0.2} |
271 | | - |
272 | | - # Set new kwargs - should completely replace |
273 | | - chat.set_model_params(kwargs={"seed": 42}) |
274 | | - kwargs = getattr(chat, "_submit_input_kwargs", {}) |
275 | | - assert kwargs == {"seed": 42} |
276 | | - assert "frequency_penalty" not in kwargs |
277 | | - assert "presence_penalty" not in kwargs |
278 | | - |
279 | | - |
280 | 251 | def test_set_model_params_invalid_temperature(): |
281 | 252 | """Test validation of temperature parameter ranges.""" |
282 | 253 | chat = ChatOpenAI() |
@@ -320,7 +291,6 @@ def test_set_model_params_empty_call(): |
320 | 291 |
|
321 | 292 | # Should not change anything |
322 | 293 | assert chat._standard_model_params == {} |
323 | | - assert chat._submit_input_kwargs is None |
324 | 294 |
|
325 | 295 |
|
326 | 296 | def test_set_model_params_type_validation(): |
@@ -388,27 +358,6 @@ def test_set_model_params_reset_specific_param(): |
388 | 358 | assert params["top_p"] == 0.95 |
389 | 359 |
|
390 | 360 |
|
391 | | -def test_model_params_kwargs_priority(): |
392 | | - """Test that submit_input_kwargs override model_params.""" |
393 | | - chat = ChatOpenAI() |
394 | | - |
395 | | - # Set model parameters |
396 | | - chat.set_model_params(temperature=0.1, max_tokens=100) |
397 | | - |
398 | | - # Set submit input kwargs that override some model params |
399 | | - chat.set_model_params(kwargs={"temperature": 0.5, "seed": 42}) |
400 | | - |
401 | | - # Verify that kwargs are set correctly |
402 | | - kwargs = getattr(chat, "_submit_input_kwargs", {}) |
403 | | - params = getattr(chat, "_standard_model_params", {}) |
404 | | - assert kwargs["temperature"] == 0.5 |
405 | | - assert kwargs["seed"] == 42 |
406 | | - |
407 | | - # Model params should still be stored |
408 | | - assert params["temperature"] == 0.1 |
409 | | - assert params["max_tokens"] == 100 |
410 | | - |
411 | | - |
412 | 361 | def test_is_present_function(): |
413 | 362 | """Test the is_present helper function used in set_model_params.""" |
414 | 363 | from chatlas._chat import is_present |
@@ -576,3 +525,22 @@ def test_parameter_validation_edge_cases(): |
576 | 525 |
|
577 | 526 | assert params["max_tokens"] == 1000000 |
578 | 527 | assert params["seed"] == 999999999 |
| 528 | + |
| 529 | + |
| 530 | +def test_chat_kwargs_with_model_params(): |
| 531 | + """Test that Chat kwargs work alongside set_model_params.""" |
| 532 | + chat = ChatOpenAI() |
| 533 | + |
| 534 | + # Set model parameters |
| 535 | + chat.set_model_params(temperature=0.7, max_tokens=100) |
| 536 | + |
| 537 | + # Set persistent chat kwargs |
| 538 | + chat.kwargs_chat = {"frequency_penalty": 0.1, "presence_penalty": 0.2} |
| 539 | + |
| 540 | + # Collect all kwargs for a chat call |
| 541 | + kwargs = chat._collect_all_kwargs({"presence_penalty": 0.3}) |
| 542 | + |
| 543 | + assert kwargs.get("temperature") == 0.7 |
| 544 | + assert kwargs.get("max_tokens") == 100 |
| 545 | + assert kwargs.get("frequency_penalty") == 0.1 |
| 546 | + assert kwargs.get("presence_penalty") == 0.3 # Should override chat.chat_kwargs |
0 commit comments