diff --git a/dspy/__init__.py b/dspy/__init__.py
index ea4c75a862..01aad36757 100644
--- a/dspy/__init__.py
+++ b/dspy/__init__.py
@@ -6,7 +6,7 @@
 
 from dspy.evaluate import Evaluate  # isort: skip
 from dspy.clients import *  # isort: skip
-from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls, Code  # isort: skip
+from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, File, History, Type, Tool, ToolCalls, Code  # isort: skip
 from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
 from dspy.utils.asyncify import asyncify
 from dspy.utils.syncify import syncify
diff --git a/dspy/adapters/__init__.py b/dspy/adapters/__init__.py
index 1dea6da47a..fded25398f 100644
--- a/dspy/adapters/__init__.py
+++ b/dspy/adapters/__init__.py
@@ -2,7 +2,7 @@
 from dspy.adapters.chat_adapter import ChatAdapter
 from dspy.adapters.json_adapter import JSONAdapter
 from dspy.adapters.two_step_adapter import TwoStepAdapter
-from dspy.adapters.types import Audio, Code, History, Image, Tool, ToolCalls, Type
+from dspy.adapters.types import Audio, Code, File, History, Image, Tool, ToolCalls, Type
 from dspy.adapters.xml_adapter import XMLAdapter
 
 __all__ = [
@@ -12,6 +12,7 @@
     "History",
     "Image",
     "Audio",
+    "File",
     "Code",
     "JSONAdapter",
     "XMLAdapter",
diff --git a/dspy/adapters/types/__init__.py b/dspy/adapters/types/__init__.py
index 11b9faee1b..eb1481c862 100644
--- a/dspy/adapters/types/__init__.py
+++ b/dspy/adapters/types/__init__.py
@@ -1,8 +1,9 @@
 from dspy.adapters.types.audio import Audio
 from dspy.adapters.types.base_type import Type
 from dspy.adapters.types.code import Code
+from dspy.adapters.types.file import File
 from dspy.adapters.types.history import History
 from dspy.adapters.types.image import Image
 from dspy.adapters.types.tool import Tool, ToolCalls
 
-__all__ = ["History", "Image", "Audio", "Type", "Tool", "ToolCalls", "Code"]
+__all__ = ["History", "Image", "Audio", "File", "Type", "Tool", "ToolCalls", "Code"]
diff --git a/dspy/adapters/types/file.py b/dspy/adapters/types/file.py
new file mode 100644
index 0000000000..34bc2a0110
--- /dev/null
+++ b/dspy/adapters/types/file.py
@@ -0,0 +1,178 @@
+import base64
+import mimetypes
+import os
+from typing import Any
+
+import pydantic
+
+from dspy.adapters.types.base_type import Type
+
+
+class File(Type):
+    """A file input type for DSPy.
+    See https://platform.openai.com/docs/api-reference/chat/create#chat_create-messages-user_message-content-array_of_content_parts-file_content_part-file for specification.
+
+    The file_data field should be a data URI with the format:
+        data:<mime_type>;base64,<base64_encoded_data>
+
+    Example:
+        ```python
+        import dspy
+
+        class QA(dspy.Signature):
+            file: dspy.File = dspy.InputField()
+            summary = dspy.OutputField()
+        program = dspy.Predict(QA)
+        result = program(file=dspy.File.from_path("./research.pdf"))
+        print(result.summary)
+        ```
+    """
+
+    file_data: str | None = None
+    file_id: str | None = None
+    filename: str | None = None
+
+    model_config = pydantic.ConfigDict(
+        frozen=True,
+        str_strip_whitespace=True,
+        validate_assignment=True,
+        extra="forbid",
+    )
+
+    @pydantic.model_validator(mode="before")
+    @classmethod
+    def validate_input(cls, values: Any) -> Any:
+        if isinstance(values, cls):
+            return {
+                "file_data": values.file_data,
+                "file_id": values.file_id,
+                "filename": values.filename,
+            }
+
+        if isinstance(values, dict):
+            if "file_data" in values or "file_id" in values or "filename" in values:
+                return values
+            raise ValueError("Value of `dspy.File` must contain at least one of: file_data, file_id, or filename")
+
+        return encode_file_to_dict(values)
+
+    def format(self) -> list[dict[str, Any]]:
+        try:
+            file_dict = {}
+            if self.file_data:
+                file_dict["file_data"] = self.file_data
+            if self.file_id:
+                file_dict["file_id"] = self.file_id
+            if self.filename:
+                file_dict["filename"] = self.filename
+
+            return [{"type": "file", "file": file_dict}]
+        except Exception as e:
+            raise ValueError(f"Failed to format file for DSPy: {e}")
+
+    def __str__(self):
+        return self.serialize_model()
+
+    def __repr__(self):
+        parts = []
+        if self.file_data is not None:
+            if self.file_data.startswith("data:"):
+                # file data has "data:text/plain;base64,..." format
+                mime_type = self.file_data.split(";")[0].split(":")[1]
+                len_data = len(self.file_data.split("base64,")[1]) if "base64," in self.file_data else len(self.file_data)
+                parts.append(f"file_data=<DATA_URI({mime_type}, {len_data} chars)>")
+            else:
+                len_data = len(self.file_data)
+                parts.append(f"file_data=<DATA({len_data} chars)>")
+        if self.file_id is not None:
+            parts.append(f"file_id='{self.file_id}'")
+        if self.filename is not None:
+            parts.append(f"filename='{self.filename}'")
+        return f"File({', '.join(parts)})"
+
+    @classmethod
+    def from_path(cls, file_path: str, filename: str | None = None, mime_type: str | None = None) -> "File":
+        """Create a File from a local file path.
+
+        Args:
+            file_path: Path to the file to read
+            filename: Optional filename to use (defaults to basename of path)
+            mime_type: Optional MIME type (defaults to auto-detection from file extension)
+        """
+        if not os.path.isfile(file_path):
+            raise ValueError(f"File not found: {file_path}")
+
+        with open(file_path, "rb") as f:
+            file_bytes = f.read()
+
+        if filename is None:
+            filename = os.path.basename(file_path)
+
+        if mime_type is None:
+            mime_type, _ = mimetypes.guess_type(file_path)
+            if mime_type is None:
+                mime_type = "application/octet-stream"
+
+        encoded_data = base64.b64encode(file_bytes).decode("utf-8")
+        file_data = f"data:{mime_type};base64,{encoded_data}"
+
+        return cls(file_data=file_data, filename=filename)
+
+    @classmethod
+    def from_bytes(
+        cls, file_bytes: bytes, filename: str | None = None, mime_type: str = "application/octet-stream"
+    ) -> "File":
+        """Create a File from raw bytes.
+
+        Args:
+            file_bytes: Raw bytes of the file
+            filename: Optional filename
+            mime_type: MIME type (defaults to 'application/octet-stream')
+        """
+        encoded_data = base64.b64encode(file_bytes).decode("utf-8")
+        file_data = f"data:{mime_type};base64,{encoded_data}"
+        return cls(file_data=file_data, filename=filename)
+
+    @classmethod
+    def from_file_id(cls, file_id: str, filename: str | None = None) -> "File":
+        """Create a File from an uploaded file ID."""
+        return cls(file_id=file_id, filename=filename)
+
+
+def encode_file_to_dict(file_input: Any) -> dict:
+    """
+    Encode various file inputs to a dict with file_data, file_id, and/or filename.
+
+    Args:
+        file_input: Can be a file path (str), bytes, or File instance.
+
+    Returns:
+        dict: A dictionary with file_data, file_id, and/or filename keys.
+    """
+    if isinstance(file_input, File):
+        result = {}
+        if file_input.file_data is not None:
+            result["file_data"] = file_input.file_data
+        if file_input.file_id is not None:
+            result["file_id"] = file_input.file_id
+        if file_input.filename is not None:
+            result["filename"] = file_input.filename
+        return result
+
+    elif isinstance(file_input, str):
+        if os.path.isfile(file_input):
+            file_obj = File.from_path(file_input)
+        else:
+            raise ValueError(f"Unrecognized file string: {file_input}; must be a valid file path")
+
+        return {
+            "file_data": file_obj.file_data,
+            "filename": file_obj.filename,
+        }
+
+    elif isinstance(file_input, bytes):
+        file_obj = File.from_bytes(file_input)
+        return {"file_data": file_obj.file_data}
+
+    else:
+        raise ValueError(f"Unsupported file input type: {type(file_input)}")
diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index f9fc648cad..81d968703e 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -476,6 +476,11 @@ async def alitellm_responses_completion(request: dict[str, Any], num_retries: in
 
 
 def _convert_chat_request_to_responses_request(request: dict[str, Any]):
+    """
+    Convert a chat request to a responses request
+    See https://platform.openai.com/docs/api-reference/responses/create for the responses API specification.
+    Also see https://platform.openai.com/docs/api-reference/chat/create for the chat API specification.
+    """
     request = dict(request)
     if "messages" in request:
         content_blocks = []
@@ -525,6 +530,14 @@ def _convert_content_item_to_responses_format(item: dict[str, Any]) -> dict[str,
             "type": "input_text",
             "text": item.get("text", ""),
         }
+    elif item.get("type") == "file":
+        file = item.get("file", {})
+        return {
+            "type": "input_file",
+            "file_data": file.get("file_data"),
+            "filename": file.get("filename"),
+            "file_id": file.get("file_id"),
+        }
 
     # For other items, return as-is
     return item
diff --git a/dspy/utils/inspect_history.py b/dspy/utils/inspect_history.py
index 556913a74b..07934157fd 100644
--- a/dspy/utils/inspect_history.py
+++ b/dspy/utils/inspect_history.py
@@ -46,6 +46,13 @@ def pretty_print_history(history, n: int = 1):
                             len_audio = len(c["input_audio"]["data"])
                             audio_str = f"