
Commit e9f456d

Add matching support for responses API
Also send both completion and response outputs as GalileoMessages to fix the validation error message.

Parent: 18aef69

4 files changed: +622, −15 lines

src/galileo/openai.py (187 additions, 10 deletions)
@@ -50,7 +50,7 @@
 from pydantic import BaseModel
 from wrapt import wrap_function_wrapper  # type: ignore[import-untyped]

-from galileo import GalileoLogger
+from galileo import GalileoLogger, Message, MessageRole, ToolCall, ToolCallFunction
 from galileo.decorator import galileo_context
 from galileo.utils import _get_timestamp
 from galileo.utils.serialization import serialize_to_str
@@ -103,7 +103,10 @@ class OpenAiInputData:
 OPENAI_CLIENT_METHODS = [
     OpenAiModuleDefinition(
         module="openai.resources.chat.completions", object="Completions", method="create", type="chat", sync=True
-    )
+    ),
+    OpenAiModuleDefinition(
+        module="openai.resources.responses", object="Responses", method="create", type="response", sync=True
+    ),
     # Eventually add more OpenAI client library methods here
 ]
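Registering a second OpenAiModuleDefinition is what routes Responses API calls through the same wrapper as chat completions. A minimal usage sketch, assuming the SDK's drop-in import pattern (model and prompt are illustrative):

from galileo.openai import openai

# Calls made through the wrapped client are logged to Galileo automatically.
client = openai.OpenAI()
response = client.responses.create(
    model="gpt-4o",
    input="Say hello in one sentence.",
)
print(response.output)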

@@ -153,6 +156,56 @@ def wrapper(wrapped: Callable, instance: Any, args: dict, kwargs: dict) -> Any:
     return _with_galileo


+def _convert_to_galileo_message(data: Any, default_role: str = "user") -> Message:
+    """Convert OpenAI response data to a Galileo Message object."""
+    if hasattr(data, "type") and data.type == "function_call":
+        tool_call = ToolCall(
+            id=getattr(data, "call_id", ""),
+            function=ToolCallFunction(name=getattr(data, "name", ""), arguments=getattr(data, "arguments", "")),
+        )
+        return Message(content="", role=MessageRole.assistant, tool_calls=[tool_call])
+
+    if isinstance(data, dict) and data.get("type") == "function_call_output":
+        output = data.get("output", "")
+        if isinstance(output, dict):
+            import json
+
+            content = json.dumps(output)
+        else:
+            content = str(output)
+
+        return Message(content=content, role=MessageRole.tool, tool_call_id=data.get("call_id", ""))
+
+    # Handle standard dictionary messages (Chat Completions format)
+    if isinstance(data, dict):
+        role = data.get("role", default_role)
+        content = data.get("content", "")
+
+        # Handle tool calls if present
+        tool_calls = data.get("tool_calls")
+        galileo_tool_calls = None
+        if tool_calls:
+            galileo_tool_calls = []
+            for tc in tool_calls:
+                if isinstance(tc, dict) and "function" in tc:
+                    galileo_tool_calls.append(
+                        ToolCall(
+                            id=tc.get("id", ""),
+                            function=ToolCallFunction(
+                                name=tc["function"].get("name", ""), arguments=tc["function"].get("arguments", "")
+                            ),
+                        )
+                    )
+
+        return Message(
+            content=str(content) if content is not None else "",
+            role=MessageRole(role),
+            tool_calls=galileo_tool_calls,
+            tool_call_id=data.get("tool_call_id"),
+        )
+
+    return Message(content=str(data), role=MessageRole(default_role))
+
+
 def _extract_chat_response(kwargs: dict) -> dict:
     """Extracts the llm output from the response."""
     response = {"role": kwargs.get("role")}
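The converter handles three shapes: Responses-style function-call items (attribute access), function_call_output dicts, and plain Chat Completions message dicts, with a string fallback. A sketch of the mapping (inputs illustrative):

# Chat Completions dict -> user message
user_msg = _convert_to_galileo_message({"role": "user", "content": "What's the weather in SF?"})
# Message(role=MessageRole.user, content="What's the weather in SF?")

# Responses tool output dict -> tool message with JSON-encoded content
tool_result = _convert_to_galileo_message(
    {"type": "function_call_output", "call_id": "call_1", "output": {"temp_c": 18}}
)
# Message(role=MessageRole.tool, content='{"temp_c": 18}', tool_call_id="call_1")

# Anything else -> stringified content with the default role
fallback = _convert_to_galileo_message("plain string input")
# Message(role=MessageRole.user, content="plain string input")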
@@ -213,6 +266,8 @@ def _extract_input_data_from_kwargs(
         prompt = kwargs.get("prompt")
     elif resource.type == "chat":
         prompt = kwargs.get("messages", [])
+    elif resource.type == "response":
+        prompt = kwargs.get("input", "")

     parsed_temperature = float(
         kwargs.get("temperature", 1) if not isinstance(kwargs.get("temperature", 1), NotGiven) else 1
@@ -283,6 +338,17 @@ def _parse_usage(usage: Optional[dict] = None) -> Optional[dict]:

     usage_dict = usage.copy() if isinstance(usage, dict) else usage.__dict__

+    # Handle Responses API field names (input_tokens/output_tokens) vs Chat Completions (prompt_tokens/completion_tokens)
+    if "input_tokens" in usage_dict:
+        usage_dict["prompt_tokens"] = usage_dict.pop("input_tokens")
+    if "output_tokens" in usage_dict:
+        usage_dict["completion_tokens"] = usage_dict.pop("output_tokens")
+
+    if "input_tokens_details" in usage_dict:
+        usage_dict["prompt_tokens_details"] = usage_dict.pop("input_tokens_details")
+    if "output_tokens_details" in usage_dict:
+        usage_dict["completion_tokens_details"] = usage_dict.pop("output_tokens_details")
+
     for tokens_details in ["prompt_tokens_details", "completion_tokens_details"]:
         if tokens_details in usage_dict and usage_dict[tokens_details] is not None:
             tokens_details_dict = (
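After the renaming, downstream token accounting sees only the Chat Completions field names regardless of which API produced the usage. A minimal sketch (values illustrative):

usage = _parse_usage({"input_tokens": 10, "output_tokens": 5, "total_tokens": 15})
assert usage["prompt_tokens"] == 10
assert usage["completion_tokens"] == 5
assert usage["total_tokens"] == 15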
@@ -295,6 +361,44 @@ def _parse_usage(usage: Optional[dict] = None) -> Optional[dict]:
     return usage_dict


+def _extract_responses_output(output_items: list) -> dict:
+    """Extract the final message and tool calls from Responses API output items."""
+    final_message = None
+    tool_calls = []
+
+    for item in output_items:
+        if hasattr(item, "type") and item.type == "message":
+            final_message = {"role": getattr(item, "role", "assistant"), "content": ""}
+
+            content = getattr(item, "content", [])
+            if isinstance(content, list):
+                text_parts = []
+                for content_item in content:
+                    if hasattr(content_item, "text"):
+                        text_parts.append(content_item.text)
+                    elif isinstance(content_item, dict) and "text" in content_item:
+                        text_parts.append(content_item["text"])
+                final_message["content"] = "".join(text_parts)
+            else:
+                final_message["content"] = str(content)
+
+        elif hasattr(item, "type") and item.type == "function_call":
+            tool_call = {
+                "id": getattr(item, "id", ""),
+                "function": {"name": getattr(item, "name", ""), "arguments": getattr(item, "arguments", "")},
+                "type": "function",
+            }
+            tool_calls.append(tool_call)
+
+    if final_message:
+        if tool_calls:
+            final_message["tool_calls"] = tool_calls
+        return final_message
+    if tool_calls:
+        return {"role": "assistant", "tool_calls": tool_calls}
+    return {"role": "assistant", "content": ""}
+
+
 def _extract_data_from_default_response(resource: OpenAiModuleDefinition, response: dict[str, Any]) -> Any:
     if response is None:
         return None, "<NoneType response returned from OpenAI>", None
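The extractor collapses a Responses output list into a single Chat-Completions-style message dict. A sketch, with SimpleNamespace standing in for the SDK's typed output items:

from types import SimpleNamespace

item = SimpleNamespace(
    type="message",
    role="assistant",
    content=[SimpleNamespace(text="Hello "), SimpleNamespace(text="world")],
)
assert _extract_responses_output([item]) == {"role": "assistant", "content": "Hello world"}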
@@ -325,6 +429,10 @@ def _extract_data_from_default_response(resource: OpenAiModuleDefinition, response:
         completion = (
             _extract_chat_response(choice.message.__dict__) if _is_openai_v1() else choice.get("message", None)
         )
+    elif resource.type == "response":
+        # Handle Responses API structure
+        output = response.get("output", [])
+        completion = _extract_responses_output(output)

     usage = _parse_usage(response.get("usage"))

@@ -335,10 +443,27 @@ def _extract_streamed_openai_response(resource: OpenAiModuleDefinition, chunks:
     completion = defaultdict(str) if resource.type == "chat" else ""
     model, usage = None, None

+    # For Responses API, we just need to find the final completed event
+    if resource.type == "response":
+        final_response = None
+
     for chunk in chunks:
         if _is_openai_v1():
             chunk = chunk.__dict__

+        if resource.type == "response":
+            chunk_type = chunk.get("type", "")
+
+            if chunk_type == "response.completed":
+                final_response = chunk.get("response")
+                if final_response:
+                    model = getattr(final_response, "model", None)
+                    usage_obj = getattr(final_response, "usage", None)
+                    if usage_obj:
+                        usage = _parse_usage(usage_obj.__dict__ if hasattr(usage_obj, "__dict__") else usage_obj)
+
+            continue
+
         model = model or chunk.get("model", None) or None
         usage = chunk.get("usage", None)
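For streamed Responses calls the loop skips the incremental delta events and reads everything from the terminal response.completed event. A rough sketch of the only chunk shape consumed here (SimpleNamespace standing in for the SDK event object; values illustrative):

from types import SimpleNamespace

# The wrapper reads only .model, .usage, and .output from the attached response.
final = SimpleNamespace(
    model="gpt-4o",
    usage=SimpleNamespace(input_tokens=10, output_tokens=5, total_tokens=15),
    output=[],
)
completed_chunk = {"type": "response.completed", "response": final}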

@@ -414,7 +539,15 @@ def get_response_for_chat() -> Any:
             or None
         )

-    return model, get_response_for_chat() if resource.type == "chat" else completion, usage
+    if resource.type == "chat":
+        return model, get_response_for_chat(), usage
+    if resource.type == "response":
+        if final_response:
+            output_items = getattr(final_response, "output", [])
+            response_message = _extract_responses_output(output_items)
+            return model, response_message, usage
+        return model, {"role": "assistant", "content": ""}, usage
+    return model, completion, usage


 def _is_openai_v1() -> bool:
@@ -442,7 +575,14 @@ def _wrap(
     else:
         # If we don't have an active trace, start a new trace
         # We will conclude it at the end
-        galileo_logger.start_trace(input=serialize_to_str(input_data.input), name=input_data.name)
+        if isinstance(input_data.input, list):
+            trace_input_messages = [_convert_to_galileo_message(msg) for msg in input_data.input]
+        else:
+            trace_input_messages = [_convert_to_galileo_message(input_data.input)]
+
+        # Serialize with "messages" wrapper for UI compatibility
+        trace_input = {"messages": [msg.model_dump(exclude_none=True) for msg in trace_input_messages]}
+        galileo_logger.start_trace(input=serialize_to_str(trace_input), name=input_data.name)
         should_complete_trace = True

     try:
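With this change the serialized trace input is the {"messages": [...]} wrapper rather than the raw kwargs. For a plain-string Responses input the payload comes out roughly as below (a sketch; assumes model_dump(exclude_none=True) drops the unset tool fields and serializes the role enum to its string value):

trace_input = {"messages": [{"content": "Say hello", "role": "user"}]}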
@@ -476,10 +616,17 @@ def _wrap(

         duration_ns = round((end_time - start_time).total_seconds() * 1e9)

+        if isinstance(input_data.input, list):
+            span_input = [_convert_to_galileo_message(msg) for msg in input_data.input]
+        else:
+            span_input = [_convert_to_galileo_message(input_data.input)]
+
+        span_output = _convert_to_galileo_message(completion, "assistant")
+
         # Add a span to the current trace or span (if this is a nested trace)
         galileo_logger.add_llm_span(
-            input=input_data.input,
-            output=completion,
+            input=span_input,
+            output=span_output,
             tools=input_data.tools,
             name=input_data.name,
             model=model,
@@ -496,8 +643,19 @@ def _wrap(

         # Conclude the trace if this is the top-level call
         if should_complete_trace:
+            full_conversation = []
+
+            if isinstance(input_data.input, list):
+                full_conversation.extend([_convert_to_galileo_message(msg) for msg in input_data.input])
+            else:
+                full_conversation.append(_convert_to_galileo_message(input_data.input))
+
+            full_conversation.append(span_output)
+
+            # Serialize with "messages" wrapper for UI compatibility
+            trace_output = {"messages": [msg.model_dump(exclude_none=True) for msg in full_conversation]}
             galileo_logger.conclude(
-                output=serialize_to_str(completion), duration_ns=duration_ns, status_code=status_code
+                output=serialize_to_str(trace_output), duration_ns=duration_ns, status_code=status_code
             )

         # we want to re-raise exception after we process openai_response
@@ -593,10 +751,17 @@ def _finalize(self) -> None:
         # TODO: make sure completion_start_time what we want
         duration_ns = round((end_time - self.completion_start_time).total_seconds() * 1e9)

+        if isinstance(self.input_data.input, list):
+            span_input = [_convert_to_galileo_message(msg) for msg in self.input_data.input]
+        else:
+            span_input = [_convert_to_galileo_message(self.input_data.input)]
+
+        span_output = _convert_to_galileo_message(completion, "assistant")
+
         # Add a span to the current trace or span (if this is a nested trace)
         self.logger.add_llm_span(
-            input=self.input_data.input,
-            output=completion,
+            input=span_input,
+            output=span_output,
             tools=self.input_data.tools,
             name=self.input_data.name,
             model=model,
@@ -611,7 +776,19 @@ def _finalize(self) -> None:

         # Conclude the trace if this is the top-level call
         if self.should_complete_trace:
-            self.logger.conclude(output=completion, duration_ns=duration_ns, status_code=self.status_code)
+            full_conversation = []
+
+            if isinstance(self.input_data.input, list):
+                full_conversation.extend([_convert_to_galileo_message(msg) for msg in self.input_data.input])
+            else:
+                full_conversation.append(_convert_to_galileo_message(self.input_data.input))
+
+            full_conversation.append(span_output)
+
+            trace_output = {"messages": [msg.model_dump(exclude_none=True) for msg in full_conversation]}
+            self.logger.conclude(
+                output=serialize_to_str(trace_output), duration_ns=duration_ns, status_code=self.status_code
+            )


 class OpenAIGalileo:

tests/conftest.py (75 additions, 0 deletions)
@@ -8,6 +8,14 @@
 from openai.types import CompletionUsage
 from openai.types.chat import ChatCompletionMessage
 from openai.types.chat.chat_completion import ChatCompletion, Choice
+from openai.types.responses import (
+    Response,
+    ResponseFunctionToolCall,
+    ResponseOutputMessage,
+    ResponseOutputText,
+    ResponseUsage,
+)
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

 from galileo.config import GalileoPythonConfig
 from galileo.resources.models import DatasetContent, DatasetRow, DatasetRowValuesDict
@@ -81,6 +89,73 @@ def create_chat_completion() -> ChatCompletion:
     )


+@pytest.fixture
+def create_responses_response():
+    """Mock Responses API response for basic text generation."""
+
+    return Response(
+        id="resp_test123",
+        created_at=1758822441.0,
+        model="gpt-4o",
+        object="response",
+        output=[
+            ResponseOutputMessage(
+                id="msg_test123",
+                content=[
+                    ResponseOutputText(text="This is a test response", type="output_text", annotations=[], logprobs=[])
+                ],
+                role="assistant",
+                status="completed",
+                type="message",
+            )
+        ],
+        parallel_tool_calls=True,
+        tool_choice="auto",
+        tools=[],
+        usage=ResponseUsage(
+            input_tokens=10,
+            input_tokens_details=InputTokensDetails(cached_tokens=0),
+            output_tokens=5,
+            output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
+            total_tokens=15,
+        ),
+        status="completed",
+    )
+
+
+@pytest.fixture
+def create_responses_response_with_tools():
+    """Mock Responses API response with tool calls."""
+
+    return Response(
+        id="resp_test456",
+        created_at=1758822441.0,
+        model="gpt-4o",
+        object="response",
+        output=[
+            ResponseFunctionToolCall(
+                id="fc_test456",
+                name="get_weather",
+                arguments='{"location": "San Francisco"}',
+                type="function_call",
+                call_id="call_test456",
+                status="completed",
+            )
+        ],
+        parallel_tool_calls=True,
+        tool_choice="auto",
+        tools=[],
+        usage=ResponseUsage(
+            input_tokens=20,
+            input_tokens_details=InputTokensDetails(cached_tokens=0),
+            output_tokens=10,
+            output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
+            total_tokens=30,
+        ),
+        status="completed",
+    )
+
+
 @pytest.fixture
 def test_dataset_row_id() -> None:
     str(uuid4())
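A hypothetical test sketch (not part of this commit) showing how the fixture might be consumed; it reaches into the private helper added in src/galileo/openai.py:

from galileo.openai import _extract_responses_output


def test_extract_responses_output(create_responses_response):
    completion = _extract_responses_output(create_responses_response.output)
    assert completion == {"role": "assistant", "content": "This is a test response"}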
