Skip to content

Commit 30d0125

Browse files
committed
Add basic image generation support; introduce new ToolBuiltIn class
1 parent 7411ce8 commit 30d0125

File tree

8 files changed

+194
-49
lines changed

8 files changed

+194
-49
lines changed

chatlas/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from ._provider_portkey import ChatPortkey
3636
from ._provider_snowflake import ChatSnowflake
3737
from ._tokens import token_usage
38-
from ._tools import Tool, ToolRejectError
38+
from ._tools import Tool, ToolBuiltIn, ToolRejectError
3939
from ._turn import Turn
4040

4141
try:
@@ -84,6 +84,7 @@
8484
"Provider",
8585
"token_usage",
8686
"Tool",
87+
"ToolBuiltIn",
8788
"ToolRejectError",
8889
"Turn",
8990
"types",

chatlas/_chat.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from ._mcp_manager import MCPSessionManager
4949
from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT
5050
from ._tokens import compute_cost, get_token_pricing, tokens_log
51-
from ._tools import Tool, ToolRejectError
51+
from ._tools import Tool, ToolBuiltIn, ToolRejectError
5252
from ._turn import Turn, user_turn
5353
from ._typing_extensions import TypedDict, TypeGuard
5454
from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async
@@ -131,7 +131,7 @@ def __init__(
131131
self.system_prompt = system_prompt
132132
self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {}
133133

134-
self._tools: dict[str, Tool] = {}
134+
self._tools: dict[str, Tool | ToolBuiltIn] = {}
135135
self._on_tool_request_callbacks = CallbackManager()
136136
self._on_tool_result_callbacks = CallbackManager()
137137
self._current_display: Optional[MarkdownDisplay] = None
@@ -1850,7 +1850,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):
18501850

18511851
def register_tool(
18521852
self,
1853-
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool,
1853+
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | "ToolBuiltIn",
18541854
*,
18551855
force: bool = False,
18561856
name: Optional[str] = None,
@@ -1944,31 +1944,39 @@ def add(a: int, b: int) -> int:
19441944
ValueError
19451945
If a tool with the same name already exists and `force` is `False`.
19461946
"""
1947-
if isinstance(func, Tool):
1947+
if isinstance(func, ToolBuiltIn):
1948+
# ToolBuiltIn objects are stored directly without conversion
1949+
tool = func
1950+
tool_name = tool.name
1951+
elif isinstance(func, Tool):
19481952
name = name or func.name
19491953
annotations = annotations or func.annotations
19501954
if model is not None:
19511955
func = Tool.from_func(
19521956
func.func, name=name, model=model, annotations=annotations
19531957
)
19541958
func = func.func
1959+
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1960+
tool_name = tool.name
1961+
else:
1962+
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1963+
tool_name = tool.name
19551964

1956-
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1957-
if tool.name in self._tools and not force:
1965+
if tool_name in self._tools and not force:
19581966
raise ValueError(
1959-
f"Tool with name '{tool.name}' is already registered. "
1967+
f"Tool with name '{tool_name}' is already registered. "
19601968
"Set `force=True` to overwrite it."
19611969
)
1962-
self._tools[tool.name] = tool
1970+
self._tools[tool_name] = tool
19631971

1964-
def get_tools(self) -> list[Tool]:
1972+
def get_tools(self) -> list[Tool | ToolBuiltIn]:
19651973
"""
19661974
Get the list of registered tools.
19671975
19681976
Returns
19691977
-------
1970-
list[Tool]
1971-
A list of `Tool` instances that are currently registered with the chat.
1978+
list[Tool | ToolBuiltIn]
1979+
A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat.
19721980
"""
19731981
return list(self._tools.values())
19741982

@@ -2492,7 +2500,7 @@ def _submit_turns(
24922500
data_model: type[BaseModel] | None = None,
24932501
kwargs: Optional[SubmitInputArgsT] = None,
24942502
) -> Generator[str, None, None]:
2495-
if any(x._is_async for x in self._tools.values()):
2503+
if any(hasattr(x, "_is_async") and x._is_async for x in self._tools.values()):
24962504
raise ValueError("Cannot use async tools in a synchronous chat")
24972505

24982506
def emit(text: str | Content):
@@ -2645,15 +2653,27 @@ def _collect_all_kwargs(
26452653

26462654
def _invoke_tool(self, request: ContentToolRequest):
26472655
tool = self._tools.get(request.name)
2648-
func = tool.func if tool is not None else None
26492656

2650-
if func is None:
2657+
if tool is None:
26512658
yield self._handle_tool_error_result(
26522659
request,
26532660
error=RuntimeError("Unknown tool."),
26542661
)
26552662
return
26562663

2664+
if isinstance(tool, ToolBuiltIn):
2665+
# Built-in tools are handled by the provider, not invoked directly
2666+
yield self._handle_tool_error_result(
2667+
request,
2668+
error=RuntimeError(
2669+
f"Built-in tool '{request.name}' cannot be invoked directly. "
2670+
"It should be handled by the provider."
2671+
),
2672+
)
2673+
return
2674+
2675+
func = tool.func
2676+
26572677
# First, invoke the request callbacks. If a ToolRejectError is raised,
26582678
# treat it like a tool failure (i.e., gracefully handle it).
26592679
result: ContentToolResult | None = None
@@ -2701,6 +2721,17 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
27012721
)
27022722
return
27032723

2724+
if isinstance(tool, ToolBuiltIn):
2725+
# Built-in tools are handled by the provider, not invoked directly
2726+
yield self._handle_tool_error_result(
2727+
request,
2728+
error=RuntimeError(
2729+
f"Built-in tool '{request.name}' cannot be invoked directly. "
2730+
"It should be handled by the provider."
2731+
),
2732+
)
2733+
return
2734+
27042735
if tool._is_async:
27052736
func = tool.func
27062737
else:

chatlas/_content.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ._typing_extensions import TypedDict
1212

1313
if TYPE_CHECKING:
14-
from ._tools import Tool
14+
from ._tools import Tool, ToolBuiltIn
1515

1616

1717
class ToolAnnotations(TypedDict, total=False):
@@ -104,15 +104,28 @@ class ToolInfo(BaseModel):
104104
annotations: Optional[ToolAnnotations] = None
105105

106106
@classmethod
107-
def from_tool(cls, tool: "Tool") -> "ToolInfo":
108-
"""Create a ToolInfo from a Tool instance."""
109-
func_schema = tool.schema["function"]
110-
return cls(
111-
name=tool.name,
112-
description=func_schema.get("description", ""),
113-
parameters=func_schema.get("parameters", {}),
114-
annotations=tool.annotations,
115-
)
107+
def from_tool(cls, tool: "Tool | ToolBuiltIn") -> "ToolInfo":
108+
"""Create a ToolInfo from a Tool or ToolBuiltIn instance."""
109+
from ._tools import ToolBuiltIn
110+
111+
if isinstance(tool, ToolBuiltIn):
112+
# For built-in tools, extract info from the definition
113+
defn = tool.definition
114+
return cls(
115+
name=tool.name,
116+
description=defn.get("description", ""),
117+
parameters=defn.get("parameters", {}),
118+
annotations=None,
119+
)
120+
else:
121+
# For regular tools, extract from schema
122+
func_schema = tool.schema["function"]
123+
return cls(
124+
name=tool.name,
125+
description=func_schema.get("description", ""),
126+
parameters=func_schema.get("parameters", {}),
127+
annotations=tool.annotations,
128+
)
116129

117130

118131
ContentTypeEnum = Literal[
@@ -247,6 +260,22 @@ def __str__(self):
247260
def _repr_markdown_(self):
248261
return self.__str__()
249262

263+
def _repr_png_(self):
264+
"""Display PNG images directly in Jupyter notebooks."""
265+
if self.image_content_type == "image/png" and self.data:
266+
import base64
267+
268+
return base64.b64decode(self.data)
269+
return None
270+
271+
def _repr_jpeg_(self):
272+
"""Display JPEG images directly in Jupyter notebooks."""
273+
if self.image_content_type == "image/jpeg" and self.data:
274+
import base64
275+
276+
return base64.b64decode(self.data)
277+
return None
278+
250279
def __repr__(self, indent: int = 0):
251280
n_bytes = len(self.data) if self.data else 0
252281
return (

chatlas/_provider.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pydantic import BaseModel
1717

1818
from ._content import Content
19-
from ._tools import Tool
19+
from ._tools import Tool, ToolBuiltIn
2020
from ._turn import Turn
2121
from ._typing_extensions import NotRequired, TypedDict
2222

@@ -162,7 +162,7 @@ def chat_perform(
162162
*,
163163
stream: Literal[False],
164164
turns: list[Turn],
165-
tools: dict[str, Tool],
165+
tools: dict[str, Tool | ToolBuiltIn],
166166
data_model: Optional[type[BaseModel]],
167167
kwargs: SubmitInputArgsT,
168168
) -> ChatCompletionT: ...
@@ -174,7 +174,7 @@ def chat_perform(
174174
*,
175175
stream: Literal[True],
176176
turns: list[Turn],
177-
tools: dict[str, Tool],
177+
tools: dict[str, Tool | ToolBuiltIn],
178178
data_model: Optional[type[BaseModel]],
179179
kwargs: SubmitInputArgsT,
180180
) -> Iterable[ChatCompletionChunkT]: ...
@@ -185,7 +185,7 @@ def chat_perform(
185185
*,
186186
stream: bool,
187187
turns: list[Turn],
188-
tools: dict[str, Tool],
188+
tools: dict[str, Tool | ToolBuiltIn],
189189
data_model: Optional[type[BaseModel]],
190190
kwargs: SubmitInputArgsT,
191191
) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ...
@@ -197,7 +197,7 @@ async def chat_perform_async(
197197
*,
198198
stream: Literal[False],
199199
turns: list[Turn],
200-
tools: dict[str, Tool],
200+
tools: dict[str, Tool | ToolBuiltIn],
201201
data_model: Optional[type[BaseModel]],
202202
kwargs: SubmitInputArgsT,
203203
) -> ChatCompletionT: ...
@@ -209,7 +209,7 @@ async def chat_perform_async(
209209
*,
210210
stream: Literal[True],
211211
turns: list[Turn],
212-
tools: dict[str, Tool],
212+
tools: dict[str, Tool | ToolBuiltIn],
213213
data_model: Optional[type[BaseModel]],
214214
kwargs: SubmitInputArgsT,
215215
) -> AsyncIterable[ChatCompletionChunkT]: ...
@@ -220,7 +220,7 @@ async def chat_perform_async(
220220
*,
221221
stream: bool,
222222
turns: list[Turn],
223-
tools: dict[str, Tool],
223+
tools: dict[str, Tool | ToolBuiltIn],
224224
data_model: Optional[type[BaseModel]],
225225
kwargs: SubmitInputArgsT,
226226
) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...
@@ -259,15 +259,15 @@ def value_tokens(
259259
def token_count(
260260
self,
261261
*args: Content | str,
262-
tools: dict[str, Tool],
262+
tools: dict[str, Tool | ToolBuiltIn],
263263
data_model: Optional[type[BaseModel]],
264264
) -> int: ...
265265

266266
@abstractmethod
267267
async def token_count_async(
268268
self,
269269
*args: Content | str,
270-
tools: dict[str, Tool],
270+
tools: dict[str, Tool | ToolBuiltIn],
271271
data_model: Optional[type[BaseModel]],
272272
) -> int: ...
273273

chatlas/_provider_google.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -309,17 +309,25 @@ def _chat_perform_args(
309309
config.response_mime_type = "application/json"
310310

311311
if tools:
312-
config.tools = [
313-
GoogleTool(
314-
function_declarations=[
312+
from ._tools import ToolBuiltIn
313+
314+
function_declarations = []
315+
for tool in tools.values():
316+
if isinstance(tool, ToolBuiltIn):
317+
# For built-in tools, pass the raw definition through
318+
# This allows provider-specific tools like image generation
319+
# Note: Google's API expects these in a specific format
320+
continue # Built-in tools are not yet fully supported for Google
321+
else:
322+
function_declarations.append(
315323
FunctionDeclaration.from_callable(
316324
client=self._client._api_client,
317325
callable=tool.func,
318326
)
319-
for tool in tools.values()
320-
]
321-
)
322-
]
327+
)
328+
329+
if function_declarations:
330+
config.tools = [GoogleTool(function_declarations=function_declarations)]
323331

324332
kwargs_full["config"] = config
325333

@@ -552,6 +560,20 @@ def _as_turn(
552560
),
553561
)
554562
)
563+
inline_data = part.get("inlineData") or part.get("inline_data")
564+
if inline_data:
565+
# Handle image generation responses
566+
mime_type = inline_data.get("mimeType") or inline_data.get("mime_type")
567+
data = inline_data.get("data")
568+
if mime_type and data:
569+
# Ensure data is a string (should be base64 encoded)
570+
data_str = data if isinstance(data, str) else str(data)
571+
contents.append(
572+
ContentImageInline(
573+
image_content_type=mime_type, # type: ignore
574+
data=data_str,
575+
)
576+
)
555577

556578
if isinstance(finish_reason, FinishReason):
557579
finish_reason = finish_reason.name

0 commit comments

Comments
 (0)