|
48 | 48 | from ._mcp_manager import MCPSessionManager |
49 | 49 | from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT |
50 | 50 | from ._tokens import compute_cost, get_token_pricing, tokens_log |
51 | | -from ._tools import Tool, ToolRejectError |
| 51 | +from ._tools import Tool, ToolBuiltIn, ToolRejectError |
52 | 52 | from ._turn import Turn, user_turn |
53 | 53 | from ._typing_extensions import TypedDict, TypeGuard |
54 | 54 | from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async |
@@ -131,7 +131,7 @@ def __init__( |
131 | 131 | self.system_prompt = system_prompt |
132 | 132 | self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {} |
133 | 133 |
|
134 | | - self._tools: dict[str, Tool] = {} |
| 134 | + self._tools: dict[str, Tool | ToolBuiltIn] = {} |
135 | 135 | self._on_tool_request_callbacks = CallbackManager() |
136 | 136 | self._on_tool_result_callbacks = CallbackManager() |
137 | 137 | self._current_display: Optional[MarkdownDisplay] = None |
@@ -1850,7 +1850,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None): |
1850 | 1850 |
|
1851 | 1851 | def register_tool( |
1852 | 1852 | self, |
1853 | | - func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool, |
| 1853 | + func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | "ToolBuiltIn", |
1854 | 1854 | *, |
1855 | 1855 | force: bool = False, |
1856 | 1856 | name: Optional[str] = None, |
@@ -1944,31 +1944,39 @@ def add(a: int, b: int) -> int: |
1944 | 1944 | ValueError |
1945 | 1945 | If a tool with the same name already exists and `force` is `False`. |
1946 | 1946 | """ |
1947 | | - if isinstance(func, Tool): |
| 1947 | + if isinstance(func, ToolBuiltIn): |
| 1948 | + # ToolBuiltIn objects are stored directly without conversion |
| 1949 | + tool = func |
| 1950 | + tool_name = tool.name |
| 1951 | + elif isinstance(func, Tool): |
1948 | 1952 | name = name or func.name |
1949 | 1953 | annotations = annotations or func.annotations |
1950 | 1954 | if model is not None: |
1951 | 1955 | func = Tool.from_func( |
1952 | 1956 | func.func, name=name, model=model, annotations=annotations |
1953 | 1957 | ) |
1954 | 1958 | func = func.func |
| 1959 | + tool = Tool.from_func(func, name=name, model=model, annotations=annotations) |
| 1960 | + tool_name = tool.name |
| 1961 | + else: |
| 1962 | + tool = Tool.from_func(func, name=name, model=model, annotations=annotations) |
| 1963 | + tool_name = tool.name |
1955 | 1964 |
|
1956 | | - tool = Tool.from_func(func, name=name, model=model, annotations=annotations) |
1957 | | - if tool.name in self._tools and not force: |
| 1965 | + if tool_name in self._tools and not force: |
1958 | 1966 | raise ValueError( |
1959 | | - f"Tool with name '{tool.name}' is already registered. " |
| 1967 | + f"Tool with name '{tool_name}' is already registered. " |
1960 | 1968 | "Set `force=True` to overwrite it." |
1961 | 1969 | ) |
1962 | | - self._tools[tool.name] = tool |
| 1970 | + self._tools[tool_name] = tool |
1963 | 1971 |
|
1964 | | - def get_tools(self) -> list[Tool]: |
| 1972 | + def get_tools(self) -> list[Tool | ToolBuiltIn]: |
1965 | 1973 | """ |
1966 | 1974 | Get the list of registered tools. |
1967 | 1975 |
|
1968 | 1976 | Returns |
1969 | 1977 | ------- |
1970 | | - list[Tool] |
1971 | | - A list of `Tool` instances that are currently registered with the chat. |
| 1978 | + list[Tool | ToolBuiltIn] |
| 1979 | + A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat. |
1972 | 1980 | """ |
1973 | 1981 | return list(self._tools.values()) |
1974 | 1982 |
|
@@ -2492,7 +2500,7 @@ def _submit_turns( |
2492 | 2500 | data_model: type[BaseModel] | None = None, |
2493 | 2501 | kwargs: Optional[SubmitInputArgsT] = None, |
2494 | 2502 | ) -> Generator[str, None, None]: |
2495 | | - if any(x._is_async for x in self._tools.values()): |
| 2503 | + if any(hasattr(x, "_is_async") and x._is_async for x in self._tools.values()): |
2496 | 2504 | raise ValueError("Cannot use async tools in a synchronous chat") |
2497 | 2505 |
|
2498 | 2506 | def emit(text: str | Content): |
@@ -2645,15 +2653,27 @@ def _collect_all_kwargs( |
2645 | 2653 |
|
2646 | 2654 | def _invoke_tool(self, request: ContentToolRequest): |
2647 | 2655 | tool = self._tools.get(request.name) |
2648 | | - func = tool.func if tool is not None else None |
2649 | 2656 |
|
2650 | | - if func is None: |
| 2657 | + if tool is None: |
2651 | 2658 | yield self._handle_tool_error_result( |
2652 | 2659 | request, |
2653 | 2660 | error=RuntimeError("Unknown tool."), |
2654 | 2661 | ) |
2655 | 2662 | return |
2656 | 2663 |
|
| 2664 | + if isinstance(tool, ToolBuiltIn): |
| 2665 | + # Built-in tools are handled by the provider, not invoked directly |
| 2666 | + yield self._handle_tool_error_result( |
| 2667 | + request, |
| 2668 | + error=RuntimeError( |
| 2669 | + f"Built-in tool '{request.name}' cannot be invoked directly. " |
| 2670 | + "It should be handled by the provider." |
| 2671 | + ), |
| 2672 | + ) |
| 2673 | + return |
| 2674 | + |
| 2675 | + func = tool.func |
| 2676 | + |
2657 | 2677 | # First, invoke the request callbacks. If a ToolRejectError is raised, |
2658 | 2678 | # treat it like a tool failure (i.e., gracefully handle it). |
2659 | 2679 | result: ContentToolResult | None = None |
@@ -2701,6 +2721,17 @@ async def _invoke_tool_async(self, request: ContentToolRequest): |
2701 | 2721 | ) |
2702 | 2722 | return |
2703 | 2723 |
|
| 2724 | + if isinstance(tool, ToolBuiltIn): |
| 2725 | + # Built-in tools are handled by the provider, not invoked directly |
| 2726 | + yield self._handle_tool_error_result( |
| 2727 | + request, |
| 2728 | + error=RuntimeError( |
| 2729 | + f"Built-in tool '{request.name}' cannot be invoked directly. " |
| 2730 | + "It should be handled by the provider." |
| 2731 | + ), |
| 2732 | + ) |
| 2733 | + return |
| 2734 | + |
2704 | 2735 | if tool._is_async: |
2705 | 2736 | func = tool.func |
2706 | 2737 | else: |
|
0 commit comments