Commit cd7241f

Add matching support for responses API

Also send both completion and response outputs as GalileoMessages to fix validation error message

1 parent 18aef69

File tree

4 files changed: +639 −15 lines changed

src/galileo/openai.py

Lines changed: 207 additions & 10 deletions
@@ -50,7 +50,7 @@
 from pydantic import BaseModel
 from wrapt import wrap_function_wrapper  # type: ignore[import-untyped]

-from galileo import GalileoLogger
+from galileo import GalileoLogger, Message, MessageRole, ToolCall, ToolCallFunction
 from galileo.decorator import galileo_context
 from galileo.utils import _get_timestamp
 from galileo.utils.serialization import serialize_to_str
@@ -103,6 +103,9 @@ class OpenAiInputData:
 OPENAI_CLIENT_METHODS = [
     OpenAiModuleDefinition(
         module="openai.resources.chat.completions", object="Completions", method="create", type="chat", sync=True
+    ),
+    OpenAiModuleDefinition(
+        module="openai.resources.responses", object="Responses", method="create", type="response", sync=True
     )
     # Eventually add more OpenAI client library methods here
 ]
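Registering the Responses resource means the existing monkey-patching path picks it up with no further wiring. A minimal sketch of how the patched client would be exercised; the import follows the SDK's documented wrapper pattern, and the model name and prompt are illustrative:

from galileo.openai import openai  # Galileo-patched OpenAI module

client = openai.OpenAI()

# Dispatched through the "response"-type OpenAiModuleDefinition above,
# so the call is captured as an LLM span on the active Galileo trace.
result = client.responses.create(model="gpt-4o", input="Say hello")
print(result.output_text)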
@@ -153,6 +156,67 @@ def wrapper(wrapped: Callable, instance: Any, args: dict, kwargs: dict) -> Any:
     return _with_galileo


+
+def _convert_to_galileo_message(data: Any, default_role: str = "user") -> Message:
+    """Convert OpenAI response data to a Galileo Message object."""
+    if hasattr(data, 'type') and getattr(data, 'type') == 'function_call':
+        tool_call = ToolCall(
+            id=getattr(data, 'call_id', ''),
+            function=ToolCallFunction(
+                name=getattr(data, 'name', ''),
+                arguments=getattr(data, 'arguments', '')
+            ),
+        )
+        return Message(
+            content="",
+            role=MessageRole.assistant,
+            tool_calls=[tool_call]
+        )
+
+    if isinstance(data, dict) and data.get("type") == "function_call_output":
+        output = data.get("output", "")
+        if isinstance(output, dict):
+            import json
+            content = json.dumps(output)
+        else:
+            content = str(output)
+
+        return Message(
+            content=content,
+            role=MessageRole.tool,
+            tool_call_id=data.get("call_id", "")
+        )
+
+    # Handle standard dictionary messages (Chat Completions format)
+    if isinstance(data, dict):
+        role = data.get("role", default_role)
+        content = data.get("content", "")
+
+        # Handle tool calls if present
+        tool_calls = data.get("tool_calls")
+        galileo_tool_calls = None
+        if tool_calls:
+            galileo_tool_calls = []
+            for tc in tool_calls:
+                if isinstance(tc, dict) and "function" in tc:
+                    galileo_tool_calls.append(ToolCall(
+                        id=tc.get("id", ""),
+                        function=ToolCallFunction(
+                            name=tc["function"].get("name", ""),
+                            arguments=tc["function"].get("arguments", "")
+                        ),
+                    ))
+
+        return Message(
+            content=str(content) if content is not None else "",
+            role=MessageRole(role),
+            tool_calls=galileo_tool_calls,
+            tool_call_id=data.get("tool_call_id")
+        )
+    else:
+        return Message(content=str(data), role=MessageRole(default_role))
+
+
 def _extract_chat_response(kwargs: dict) -> dict:
     """Extracts the llm output from the response."""
     response = {"role": kwargs.get("role")}
@@ -213,6 +277,8 @@ def _extract_input_data_from_kwargs(
         prompt = kwargs.get("prompt")
     elif resource.type == "chat":
         prompt = kwargs.get("messages", [])
+    elif resource.type == "response":
+        prompt = kwargs.get("input", "")

     parsed_temperature = float(
         kwargs.get("temperature", 1) if not isinstance(kwargs.get("temperature", 1), NotGiven) else 1
@@ -282,6 +348,17 @@ def _parse_usage(usage: Optional[dict] = None) -> Optional[dict]:
         return None

     usage_dict = usage.copy() if isinstance(usage, dict) else usage.__dict__
+
+    # Handle Responses API field names (input_tokens/output_tokens) vs Chat Completions (prompt_tokens/completion_tokens)
+    if "input_tokens" in usage_dict:
+        usage_dict["prompt_tokens"] = usage_dict.pop("input_tokens")
+    if "output_tokens" in usage_dict:
+        usage_dict["completion_tokens"] = usage_dict.pop("output_tokens")
+
+    if "input_tokens_details" in usage_dict:
+        usage_dict["prompt_tokens_details"] = usage_dict.pop("input_tokens_details")
+    if "output_tokens_details" in usage_dict:
+        usage_dict["completion_tokens_details"] = usage_dict.pop("output_tokens_details")

     for tokens_details in ["prompt_tokens_details", "completion_tokens_details"]:
         if tokens_details in usage_dict and usage_dict[tokens_details] is not None:
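With this normalization, everything downstream of _parse_usage only ever sees Chat Completions field names, regardless of which API produced the usage object. Illustratively (token counts made up):

usage = {"input_tokens": 12, "output_tokens": 34, "total_tokens": 46}
normalized = _parse_usage(usage)
assert normalized["prompt_tokens"] == 12
assert normalized["completion_tokens"] == 34
assert "input_tokens" not in normalized  # popped during the rename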
@@ -295,6 +372,51 @@ def _parse_usage(usage: Optional[dict] = None) -> Optional[dict]:
     return usage_dict


+def _extract_responses_output(output_items: list) -> dict:
+    """Extract the final message and tool calls from Responses API output items."""
+    final_message = None
+    tool_calls = []
+
+    for item in output_items:
+        if hasattr(item, 'type') and item.type == "message":
+            final_message = {
+                "role": getattr(item, 'role', 'assistant'),
+                "content": ""
+            }
+
+            content = getattr(item, 'content', [])
+            if isinstance(content, list):
+                text_parts = []
+                for content_item in content:
+                    if hasattr(content_item, 'text'):
+                        text_parts.append(content_item.text)
+                    elif isinstance(content_item, dict) and 'text' in content_item:
+                        text_parts.append(content_item['text'])
+                final_message["content"] = "".join(text_parts)
+            else:
+                final_message["content"] = str(content)
+
+        elif hasattr(item, 'type') and item.type == "function_call":
+            tool_call = {
+                "id": getattr(item, 'id', ''),
+                "function": {
+                    "name": getattr(item, 'name', ''),
+                    "arguments": getattr(item, 'arguments', '')
+                },
+                "type": "function"
+            }
+            tool_calls.append(tool_call)
+
+    if final_message:
+        if tool_calls:
+            final_message["tool_calls"] = tool_calls
+        return final_message
+    elif tool_calls:
+        return {"role": "assistant", "tool_calls": tool_calls}
+    else:
+        return {"role": "assistant", "content": ""}
+
+
 def _extract_data_from_default_response(resource: OpenAiModuleDefinition, response: dict[str, Any]) -> Any:
     if response is None:
         return None, "<NoneType response returned from OpenAI>", None
@@ -325,6 +447,10 @@ def _extract_data_from_default_response(resource: OpenAiModuleDefinition, respon
         completion = (
             _extract_chat_response(choice.message.__dict__) if _is_openai_v1() else choice.get("message", None)
         )
+    elif resource.type == "response":
+        # Handle Responses API structure
+        output = response.get("output", [])
+        completion = _extract_responses_output(output)

     usage = _parse_usage(response.get("usage"))

@@ -334,11 +460,28 @@
 def _extract_streamed_openai_response(resource: OpenAiModuleDefinition, chunks: Iterable) -> Any:
     completion = defaultdict(str) if resource.type == "chat" else ""
     model, usage = None, None
-
+
+    # For Responses API, we just need to find the final completed event
+    if resource.type == "response":
+        final_response = None
+
     for chunk in chunks:
         if _is_openai_v1():
             chunk = chunk.__dict__

+        if resource.type == "response":
+            chunk_type = chunk.get("type", "")
+
+            if chunk_type == "response.completed":
+                final_response = chunk.get("response")
+                if final_response:
+                    model = getattr(final_response, 'model', None)
+                    usage_obj = getattr(final_response, 'usage', None)
+                    if usage_obj:
+                        usage = _parse_usage(usage_obj.__dict__ if hasattr(usage_obj, '__dict__') else usage_obj)
+
+            continue
+
         model = model or chunk.get("model", None) or None
         usage = chunk.get("usage", None)
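For streamed Responses API calls, the wrapper therefore ignores every incremental delta and reads model, usage, and output from the final snapshot only. A sketch of the event stream it is matching against; event names follow the public Responses streaming protocol, and the model and prompt are illustrative:

from galileo.openai import openai

client = openai.OpenAI()
stream = client.responses.create(model="gpt-4o", input="Say hello", stream=True)

for event in stream:
    # Text deltas arrive as "response.output_text.delta" events; the branch
    # above skips everything except this terminal snapshot.
    if event.type == "response.completed":
        final = event.response  # full Response object: .model, .usage, .output
        print(final.model, final.usage.output_tokens)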

@@ -414,7 +557,17 @@ def get_response_for_chat() -> Any:
             or None
         )

-    return model, get_response_for_chat() if resource.type == "chat" else completion, usage
+    if resource.type == "chat":
+        return model, get_response_for_chat(), usage
+    elif resource.type == "response":
+        if final_response:
+            output_items = getattr(final_response, 'output', [])
+            response_message = _extract_responses_output(output_items)
+            return model, response_message, usage
+        else:
+            return model, {"role": "assistant", "content": ""}, usage
+    else:
+        return model, completion, usage


 def _is_openai_v1() -> bool:
@@ -442,7 +595,14 @@ def _wrap(
     else:
         # If we don't have an active trace, start a new trace
        # We will conclude it at the end
-        galileo_logger.start_trace(input=serialize_to_str(input_data.input), name=input_data.name)
+        if isinstance(input_data.input, list):
+            trace_input_messages = [_convert_to_galileo_message(msg) for msg in input_data.input]
+        else:
+            trace_input_messages = [_convert_to_galileo_message(input_data.input)]
+
+        # Serialize with "messages" wrapper for UI compatibility
+        trace_input = {"messages": [msg.model_dump(exclude_none=True) for msg in trace_input_messages]}
+        galileo_logger.start_trace(input=serialize_to_str(trace_input), name=input_data.name)
         should_complete_trace = True

     try:
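Wrapping the converted messages in a top-level "messages" key gives string prompts, chat message lists, and Responses API inputs one uniform serialized shape. Roughly, assuming Message dumps only its populated fields:

trace_input_messages = [_convert_to_galileo_message("Say hello")]
trace_input = {"messages": [m.model_dump(exclude_none=True) for m in trace_input_messages]}
# serialize_to_str(trace_input) then yields something like:
# {"messages": [{"content": "Say hello", "role": "user"}]}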
@@ -476,10 +636,17 @@ def _wrap(

     duration_ns = round((end_time - start_time).total_seconds() * 1e9)

+    if isinstance(input_data.input, list):
+        span_input = [_convert_to_galileo_message(msg) for msg in input_data.input]
+    else:
+        span_input = [_convert_to_galileo_message(input_data.input)]
+
+    span_output = _convert_to_galileo_message(completion, "assistant")
+
     # Add a span to the current trace or span (if this is a nested trace)
     galileo_logger.add_llm_span(
-        input=input_data.input,
-        output=completion,
+        input=span_input,
+        output=span_output,
         tools=input_data.tools,
         name=input_data.name,
         model=model,
@@ -496,8 +663,19 @@ def _wrap(

     # Conclude the trace if this is the top-level call
     if should_complete_trace:
+        full_conversation = []
+
+        if isinstance(input_data.input, list):
+            full_conversation.extend([_convert_to_galileo_message(msg) for msg in input_data.input])
+        else:
+            full_conversation.append(_convert_to_galileo_message(input_data.input))
+
+        full_conversation.append(span_output)
+
+        # Serialize with "messages" wrapper for UI compatibility
+        trace_output = {"messages": [msg.model_dump(exclude_none=True) for msg in full_conversation]}
         galileo_logger.conclude(
-            output=serialize_to_str(completion), duration_ns=duration_ns, status_code=status_code
+            output=serialize_to_str(trace_output), duration_ns=duration_ns, status_code=status_code
         )

     # we want to re-raise exception after we process openai_response
@@ -593,10 +771,17 @@ def _finalize(self) -> None:
         # TODO: make sure completion_start_time what we want
         duration_ns = round((end_time - self.completion_start_time).total_seconds() * 1e9)

+        if isinstance(self.input_data.input, list):
+            span_input = [_convert_to_galileo_message(msg) for msg in self.input_data.input]
+        else:
+            span_input = [_convert_to_galileo_message(self.input_data.input)]
+
+        span_output = _convert_to_galileo_message(completion, "assistant")
+
         # Add a span to the current trace or span (if this is a nested trace)
         self.logger.add_llm_span(
-            input=self.input_data.input,
-            output=completion,
+            input=span_input,
+            output=span_output,
             tools=self.input_data.tools,
             name=self.input_data.name,
             model=model,
@@ -611,7 +796,19 @@ def _finalize(self) -> None:

         # Conclude the trace if this is the top-level call
         if self.should_complete_trace:
-            self.logger.conclude(output=completion, duration_ns=duration_ns, status_code=self.status_code)
+            full_conversation = []
+
+            if isinstance(self.input_data.input, list):
+                full_conversation.extend([_convert_to_galileo_message(msg) for msg in self.input_data.input])
+            else:
+                full_conversation.append(_convert_to_galileo_message(self.input_data.input))
+
+            full_conversation.append(span_output)
+
+            trace_output = {"messages": [msg.model_dump(exclude_none=True) for msg in full_conversation]}
+            self.logger.conclude(
+                output=serialize_to_str(trace_output), duration_ns=duration_ns, status_code=self.status_code
+            )


 class OpenAIGalileo:
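End to end, a tool-calling round trip through the patched client now yields a trace whose input and output are both GalileoMessage lists. A hypothetical multi-turn exchange against the Responses API; the function schema, model, prompts, and the assumption that the first output item is the function call are all illustrative:

import json
from galileo.openai import openai

client = openai.OpenAI()

tools = [{"type": "function", "name": "get_weather",
          "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}]

# First call: the model may answer with a function_call output item,
# which _convert_to_galileo_message logs as an assistant Message with a ToolCall.
first = client.responses.create(model="gpt-4o", input="Weather in Paris?", tools=tools)

# Second call: feed the tool result back as a function_call_output dict,
# logged as a MessageRole.tool message tied to the original call_id.
follow_up = [
    {"type": "function_call_output", "call_id": first.output[0].call_id,
     "output": json.dumps({"temp_c": 18})},
]
second = client.responses.create(model="gpt-4o", input=follow_up, tools=tools,
                                 previous_response_id=first.id)
print(second.output_text)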
