Skip to content

Commit 7526e4d

Browse files
kiberguscopybara-github
authored andcommitted
feat: Make genai.Part constructible from PartUnionDict.
PiperOrigin-RevId: 827985584
1 parent 856789a commit 7526e4d

File tree

3 files changed

+131
-26
lines changed

3 files changed

+131
-26
lines changed

google/genai/_transformers.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -390,31 +390,7 @@ def t_audio_blob(blob: types.BlobOrDict) -> types.Blob:
390390
def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
391391
if part is None:
392392
raise ValueError('content part is required.')
393-
if isinstance(part, str):
394-
return types.Part(text=part)
395-
if _is_duck_type_of(part, types.File):
396-
if not part.uri or not part.mime_type: # type: ignore[union-attr]
397-
raise ValueError('file uri and mime_type are required.')
398-
return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type) # type: ignore[union-attr]
399-
if isinstance(part, dict):
400-
try:
401-
return types.Part.model_validate(part)
402-
except pydantic.ValidationError:
403-
return types.Part(file_data=types.FileData.model_validate(part))
404-
if _is_duck_type_of(part, types.Part):
405-
return part # type: ignore[return-value]
406-
407-
if 'image' in part.__class__.__name__.lower():
408-
try:
409-
import PIL.Image
410-
411-
PIL_Image = PIL.Image.Image
412-
except ImportError:
413-
PIL_Image = None
414-
415-
if PIL_Image is not None and isinstance(part, PIL_Image):
416-
return types.Part(inline_data=pil_to_blob(part))
417-
raise ValueError(f'Unsupported content part type: {type(part)}')
393+
return types.Part(part)
418394

419395

420396
def t_parts(

google/genai/tests/types/test_types.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import sys
1919
import typing
2020
from typing import Optional, assert_never
21+
import PIL.Image
2122
import pydantic
2223
import pytest
2324
from ... import types
@@ -301,6 +302,58 @@ def test_factory_method_from_mcp_call_tool_function_response_embedded_resource()
301302
assert isinstance(my_function_response, types.FunctionResponse)
302303

303304

305+
def test_part_constructor_with_string_value():
306+
part = types.Part('hello')
307+
assert part.text == 'hello'
308+
assert part.file_data is None
309+
assert part.inline_data is None
310+
311+
312+
def test_part_constructor_with_part_value():
313+
other_part = types.Part(text='hello from other part')
314+
part = types.Part(other_part)
315+
assert part.text == 'hello from other part'
316+
317+
318+
def test_part_constructor_with_part_dict_value():
319+
part = types.Part({'text': 'hello from dict'})
320+
assert part.text == 'hello from dict'
321+
322+
323+
def test_part_constructor_with_file_data_dict_value():
324+
part = types.Part(
325+
{'file_uri': 'gs://my-bucket/file-data', 'mime_type': 'text/plain'}
326+
)
327+
assert part.file_data.file_uri == 'gs://my-bucket/file-data'
328+
assert part.file_data.mime_type == 'text/plain'
329+
330+
331+
def test_part_constructor_with_kwargs_and_value_fails():
332+
with pytest.raises(
333+
ValueError, match='Positional and keyword arguments can not be combined'
334+
):
335+
types.Part('hello', text='world')
336+
337+
338+
def test_part_constructor_with_file_value():
339+
f = types.File(
340+
uri='gs://my-bucket/my-file',
341+
mime_type='text/plain',
342+
display_name='test file',
343+
)
344+
part = types.Part(f)
345+
assert part.file_data.file_uri == 'gs://my-bucket/my-file'
346+
assert part.file_data.mime_type == 'text/plain'
347+
assert part.file_data.display_name == 'test file'
348+
349+
350+
def test_part_constructor_with_pil_image():
351+
img = PIL.Image.new('RGB', (1, 1), color='red')
352+
part = types.Part(img)
353+
assert part.inline_data.mime_type == 'image/jpeg'
354+
assert isinstance(part.inline_data.data, bytes)
355+
356+
304357
class FakeClient:
305358

306359
def __init__(self, vertexai=False) -> None:

google/genai/types.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
import datetime
2020
from enum import Enum, EnumMeta
2121
import inspect
22+
import io
2223
import json
2324
import logging
2425
import sys
2526
import types as builtin_types
2627
import typing
27-
from typing import Any, Callable, Literal, Optional, Sequence, Union, _UnionGenericAlias # type: ignore
28+
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union, _UnionGenericAlias # type: ignore
2829
import pydantic
2930
from pydantic import ConfigDict, Field, PrivateAttr, model_validator
3031
from typing_extensions import Self, TypedDict
@@ -1376,6 +1377,81 @@ class Part(_common.BaseModel):
13761377
description="""Optional. Video metadata. The metadata should only be specified while the video data is presented in inline_data or file_data.""",
13771378
)
13781379

1380+
def __init__(
1381+
self,
1382+
value: Optional['PartUnionDict'] = None,
1383+
/,
1384+
*,
1385+
video_metadata: Optional[VideoMetadata] = None,
1386+
thought: Optional[bool] = None,
1387+
inline_data: Optional[Blob] = None,
1388+
file_data: Optional[FileData] = None,
1389+
thought_signature: Optional[bytes] = None,
1390+
function_call: Optional[FunctionCall] = None,
1391+
code_execution_result: Optional[CodeExecutionResult] = None,
1392+
executable_code: Optional[ExecutableCode] = None,
1393+
function_response: Optional[FunctionResponse] = None,
1394+
text: Optional[str] = None,
1395+
# Pydantic allows CamelCase in addition to snake_case attribute
1396+
# names. kwargs here catch these aliases.
1397+
**kwargs: Any,
1398+
):
1399+
part_dict = dict(
1400+
video_metadata=video_metadata,
1401+
thought=thought,
1402+
inline_data=inline_data,
1403+
file_data=file_data,
1404+
thought_signature=thought_signature,
1405+
function_call=function_call,
1406+
code_execution_result=code_execution_result,
1407+
executable_code=executable_code,
1408+
function_response=function_response,
1409+
text=text,
1410+
)
1411+
part_dict = {k: v for k, v in part_dict.items() if v is not None}
1412+
1413+
if part_dict and value is not None:
1414+
raise ValueError(
1415+
'Positional and keyword arguments can not be combined when '
1416+
'initializing a Part.'
1417+
)
1418+
1419+
if value is None:
1420+
pass
1421+
elif isinstance(value, str):
1422+
part_dict['text'] = value
1423+
elif isinstance(value, File):
1424+
if not value.uri or not value.mime_type:
1425+
raise ValueError('file uri and mime_type are required.')
1426+
part_dict['file_data'] = FileData(
1427+
file_uri=value.uri,
1428+
mime_type=value.mime_type,
1429+
display_name=value.display_name,
1430+
)
1431+
elif isinstance(value, dict):
1432+
try:
1433+
Part.model_validate(value)
1434+
part_dict.update(value) # type: ignore[arg-type]
1435+
except pydantic.ValidationError:
1436+
part_dict['file_data'] = FileData.model_validate(value)
1437+
elif isinstance(value, Part):
1438+
part_dict.update(value.dict())
1439+
elif 'image' in value.__class__.__name__.lower():
1440+
# PIL.Image case.
1441+
1442+
suffix = value.format.lower() if value.format else 'jpeg'
1443+
mimetype = f'image/{suffix}'
1444+
bytes_io = io.BytesIO()
1445+
value.save(bytes_io, suffix.upper())
1446+
1447+
part_dict['inline_data'] = Blob(
1448+
data=bytes_io.getvalue(), mime_type=mimetype
1449+
)
1450+
else:
1451+
raise ValueError(f'Unsupported content part type: {type(value)}')
1452+
1453+
super().__init__(**part_dict, **kwargs)
1454+
13791455
def as_image(self) -> Optional['Image']:
13801456
"""Returns the part as a PIL Image, or None if the part is not an image."""
13811457
if not self.inline_data:

0 commit comments

Comments
 (0)