Skip to content

Commit c4b4e1f

Browse files
committed
Cleanup
1 parent 2acd2cb commit c4b4e1f

File tree

9 files changed

+101
-131
lines changed

9 files changed

+101
-131
lines changed

chatlas/_chat.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,7 +1880,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):
18801880

18811881
def register_tool(
18821882
self,
1883-
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | "ToolBuiltIn",
1883+
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | ToolBuiltIn,
18841884
*,
18851885
force: bool = False,
18861886
name: Optional[str] = None,
@@ -1974,11 +1974,7 @@ def add(a: int, b: int) -> int:
19741974
ValueError
19751975
If a tool with the same name already exists and `force` is `False`.
19761976
"""
1977-
if isinstance(func, ToolBuiltIn):
1978-
# ToolBuiltIn objects are stored directly without conversion
1979-
tool = func
1980-
tool_name = tool.name
1981-
elif isinstance(func, Tool):
1977+
if isinstance(func, Tool):
19821978
name = name or func.name
19831979
annotations = annotations or func.annotations
19841980
if model is not None:
@@ -1987,17 +1983,20 @@ def add(a: int, b: int) -> int:
19871983
)
19881984
func = func.func
19891985
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1990-
tool_name = tool.name
19911986
else:
1992-
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1993-
tool_name = tool.name
1987+
if isinstance(func, ToolBuiltIn):
1988+
tool = func
1989+
else:
1990+
tool = Tool.from_func(
1991+
func, name=name, model=model, annotations=annotations
1992+
)
19941993

1995-
if tool_name in self._tools and not force:
1994+
if tool.name in self._tools and not force:
19961995
raise ValueError(
1997-
f"Tool with name '{tool_name}' is already registered. "
1996+
f"Tool with name '{tool.name}' is already registered. "
19981997
"Set `force=True` to overwrite it."
19991998
)
2000-
self._tools[tool_name] = tool
1999+
self._tools[tool.name] = tool
20012000

20022001
def get_tools(self) -> list[Tool | ToolBuiltIn]:
20032002
"""
@@ -2530,7 +2529,7 @@ def _submit_turns(
25302529
data_model: type[BaseModel] | None = None,
25312530
kwargs: Optional[SubmitInputArgsT] = None,
25322531
) -> Generator[str, None, None]:
2533-
if any(hasattr(x, "_is_async") and x._is_async for x in self._tools.values()):
2532+
if any(isinstance(x, Tool) and x._is_async for x in self._tools.values()):
25342533
raise ValueError("Cannot use async tools in a synchronous chat")
25352534

25362535
def emit(text: str | Content):
@@ -2700,7 +2699,6 @@ def _invoke_tool(self, request: ContentToolRequest):
27002699
return
27012700

27022701
if isinstance(tool, ToolBuiltIn):
2703-
# Built-in tools are handled by the provider, not invoked directly
27042702
yield self._handle_tool_error_result(
27052703
request,
27062704
error=RuntimeError(
@@ -2710,8 +2708,6 @@ def _invoke_tool(self, request: ContentToolRequest):
27102708
)
27112709
return
27122710

2713-
func = tool.func
2714-
27152711
# First, invoke the request callbacks. If a ToolRejectError is raised,
27162712
# treat it like a tool failure (i.e., gracefully handle it).
27172713
result: ContentToolResult | None = None
@@ -2723,9 +2719,9 @@ def _invoke_tool(self, request: ContentToolRequest):
27232719

27242720
try:
27252721
if isinstance(request.arguments, dict):
2726-
res = func(**request.arguments)
2722+
res = tool.func(**request.arguments)
27272723
else:
2728-
res = func(request.arguments)
2724+
res = tool.func(request.arguments)
27292725

27302726
# Normalize res as a generator of results.
27312727
if not inspect.isgenerator(res):
@@ -2760,7 +2756,6 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
27602756
return
27612757

27622758
if isinstance(tool, ToolBuiltIn):
2763-
# Built-in tools are handled by the provider, not invoked directly
27642759
yield self._handle_tool_error_result(
27652760
request,
27662761
error=RuntimeError(
@@ -2770,11 +2765,6 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
27702765
)
27712766
return
27722767

2773-
if tool._is_async:
2774-
func = tool.func
2775-
else:
2776-
func = wrap_async(tool.func)
2777-
27782768
# First, invoke the request callbacks. If a ToolRejectError is raised,
27792769
# treat it like a tool failure (i.e., gracefully handle it).
27802770
result: ContentToolResult | None = None
@@ -2784,6 +2774,11 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
27842774
yield self._handle_tool_error_result(request, e)
27852775
return
27862776

2777+
if tool._is_async:
2778+
func = tool.func
2779+
else:
2780+
func = wrap_async(tool.func)
2781+
27872782
# Invoke the tool (if it hasn't been rejected).
27882783
try:
27892784
if isinstance(request.arguments, dict):

chatlas/_content.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,7 @@ def from_tool(cls, tool: "Tool | ToolBuiltIn") -> "ToolInfo":
109109
from ._tools import ToolBuiltIn
110110

111111
if isinstance(tool, ToolBuiltIn):
112-
# For built-in tools, extract info from the definition
113-
defn = tool.definition
114-
return cls(
115-
name=tool.name,
116-
description=defn.get("description", ""),
117-
parameters=defn.get("parameters", {}),
118-
annotations=None,
119-
)
112+
return cls(name=tool.name, description=tool.name, parameters={})
120113
else:
121114
# For regular tools, extract from schema
122115
func_schema = tool.schema["function"]

chatlas/_mcp_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class SessionInfo(ABC):
2323

2424
# Primary derived attributes
2525
session: ClientSession | None = None
26-
tools: dict[str, Tool] = field(default_factory=dict)
26+
tools: dict[str, Tool | ToolBuiltIn] = field(default_factory=dict)
2727

2828
# Background task management
2929
ready_event: asyncio.Event = field(default_factory=asyncio.Event)
@@ -74,7 +74,7 @@ async def request_tools(self) -> None:
7474
tool_names = tool_names.difference(exclude)
7575

7676
# Apply namespace and convert to chatlas.Tool instances
77-
self_tools: dict[str, Tool] = {}
77+
self_tools: dict[str, Tool | ToolBuiltIn] = {}
7878
for tool in response.tools:
7979
if tool.name not in tool_names:
8080
continue

chatlas/_provider_anthropic.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
)
1616

1717
import orjson
18-
from openai.types.chat import ChatCompletionToolParam
1918
from pydantic import BaseModel
2019

2120
from ._chat import Chat
@@ -38,7 +37,7 @@
3837
StandardModelParams,
3938
)
4039
from ._tokens import get_token_pricing
41-
from ._tools import Tool, basemodel_to_param_schema
40+
from ._tools import Tool, ToolBuiltIn, basemodel_to_param_schema
4241
from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn, user_turn
4342
from ._utils import split_http_client_kwargs
4443

@@ -48,7 +47,7 @@
4847
MessageParam,
4948
RawMessageStreamEvent,
5049
TextBlock,
51-
ToolParam,
50+
ToolUnionParam,
5251
ToolUseBlock,
5352
)
5453
from anthropic.types.cache_control_ephemeral_param import CacheControlEphemeralParam
@@ -304,7 +303,7 @@ def chat_perform(
304303
*,
305304
stream: Literal[False],
306305
turns: list[Turn],
307-
tools: dict[str, Tool],
306+
tools: dict[str, Tool | ToolBuiltIn],
308307
data_model: Optional[type[BaseModel]] = None,
309308
kwargs: Optional["SubmitInputArgs"] = None,
310309
): ...
@@ -315,7 +314,7 @@ def chat_perform(
315314
*,
316315
stream: Literal[True],
317316
turns: list[Turn],
318-
tools: dict[str, Tool],
317+
tools: dict[str, Tool | ToolBuiltIn],
319318
data_model: Optional[type[BaseModel]] = None,
320319
kwargs: Optional["SubmitInputArgs"] = None,
321320
): ...
@@ -325,7 +324,7 @@ def chat_perform(
325324
*,
326325
stream: bool,
327326
turns: list[Turn],
328-
tools: dict[str, Tool],
327+
tools: dict[str, Tool | ToolBuiltIn],
329328
data_model: Optional[type[BaseModel]] = None,
330329
kwargs: Optional["SubmitInputArgs"] = None,
331330
):
@@ -338,7 +337,7 @@ async def chat_perform_async(
338337
*,
339338
stream: Literal[False],
340339
turns: list[Turn],
341-
tools: dict[str, Tool],
340+
tools: dict[str, Tool | ToolBuiltIn],
342341
data_model: Optional[type[BaseModel]] = None,
343342
kwargs: Optional["SubmitInputArgs"] = None,
344343
): ...
@@ -349,7 +348,7 @@ async def chat_perform_async(
349348
*,
350349
stream: Literal[True],
351350
turns: list[Turn],
352-
tools: dict[str, Tool],
351+
tools: dict[str, Tool | ToolBuiltIn],
353352
data_model: Optional[type[BaseModel]] = None,
354353
kwargs: Optional["SubmitInputArgs"] = None,
355354
): ...
@@ -359,7 +358,7 @@ async def chat_perform_async(
359358
*,
360359
stream: bool,
361360
turns: list[Turn],
362-
tools: dict[str, Tool],
361+
tools: dict[str, Tool | ToolBuiltIn],
363362
data_model: Optional[type[BaseModel]] = None,
364363
kwargs: Optional["SubmitInputArgs"] = None,
365364
):
@@ -370,12 +369,12 @@ def _chat_perform_args(
370369
self,
371370
stream: bool,
372371
turns: list[Turn],
373-
tools: dict[str, Tool],
372+
tools: dict[str, Tool | ToolBuiltIn],
374373
data_model: Optional[type[BaseModel]] = None,
375374
kwargs: Optional["SubmitInputArgs"] = None,
376375
) -> "SubmitInputArgs":
377376
tool_schemas = [
378-
self._anthropic_tool_schema(tool.schema) for tool in tools.values()
377+
self._anthropic_tool_schema(tool) for tool in tools.values()
379378
]
380379

381380
# If data extraction is requested, add a "mock" tool with parameters inferred from the data model
@@ -395,7 +394,7 @@ def _structured_tool_call(**kwargs: Any):
395394
},
396395
}
397396

398-
tool_schemas.append(self._anthropic_tool_schema(data_model_tool.schema))
397+
tool_schemas.append(self._anthropic_tool_schema(data_model_tool))
399398

400399
if stream:
401400
stream = False
@@ -497,7 +496,7 @@ def value_tokens(self, completion):
497496
def token_count(
498497
self,
499498
*args: Content | str,
500-
tools: dict[str, Tool],
499+
tools: dict[str, Tool | ToolBuiltIn],
501500
data_model: Optional[type[BaseModel]],
502501
) -> int:
503502
kwargs = self._token_count_args(
@@ -511,7 +510,7 @@ def token_count(
511510
async def token_count_async(
512511
self,
513512
*args: Content | str,
514-
tools: dict[str, Tool],
513+
tools: dict[str, Tool | ToolBuiltIn],
515514
data_model: Optional[type[BaseModel]],
516515
) -> int:
517516
kwargs = self._token_count_args(
@@ -525,7 +524,7 @@ async def token_count_async(
525524
def _token_count_args(
526525
self,
527526
*args: Content | str,
528-
tools: dict[str, Tool],
527+
tools: dict[str, Tool | ToolBuiltIn],
529528
data_model: Optional[type[BaseModel]],
530529
) -> dict[str, Any]:
531530
turn = user_turn(*args)
@@ -655,11 +654,14 @@ def _as_content_block(content: Content) -> "ContentBlockParam":
655654
raise ValueError(f"Unknown content type: {type(content)}")
656655

657656
@staticmethod
658-
def _anthropic_tool_schema(schema: "ChatCompletionToolParam") -> "ToolParam":
659-
fn = schema["function"]
657+
def _anthropic_tool_schema(tool: "Tool | ToolBuiltIn") -> "ToolUnionParam":
658+
if isinstance(tool, ToolBuiltIn):
659+
return tool.definition # type: ignore
660+
661+
fn = tool.schema["function"]
660662
name = fn["name"]
661663

662-
res: "ToolParam" = {
664+
res: "ToolUnionParam" = {
663665
"name": name,
664666
"input_schema": {
665667
"type": "object",

0 commit comments

Comments
 (0)