
Commit e5ac825

feat: Modify endpoints for OpenAPI compatibility
1 parent 5d33e2b commit e5ac825


4 files changed: +304 −52 lines changed

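The changes below make the guardrails server's chat endpoint accept OpenAI-style requests: a `model` field that is mapped onto `config_id`, the common sampling parameters (`max_tokens`, `temperature`, `top_p`, `stop`, penalties), and a `chat.completion`-shaped response. As a rough illustration of the client side, this sketch posts such a request with the `requests` library; the host/port, the `/v1/chat/completions` path (the route decorator is not part of this diff), and the `my_config` configuration name are assumptions, not taken from the commit.

# Illustrative client call; endpoint path, host/port, and the "my_config"
# configuration name are assumptions, not taken from this diff.
import requests

payload = {
    # `model` is copied into `config_id` by the new root validator on RequestBody.
    "model": "my_config",
    "messages": [{"role": "user", "content": "Hello!"}],
    # New optional OpenAI-style sampling parameters accepted by RequestBody.
    "temperature": 0.2,
    "max_tokens": 256,
}

resp = requests.post(
    "http://localhost:8000/v1/chat/completions", json=payload, timeout=60
)
resp.raise_for_status()

data = resp.json()
# The body now carries OpenAI-compatible fields (id, object, created, model, choices)
# next to the NeMo-Guardrails extensions (state, llm_output, output_data, log).
# Note: per the new Choice model, the per-choice message is under the plural key "messages".
print(data["choices"][0]["messages"]["content"])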

nemoguardrails/server/api.py

Lines changed: 230 additions & 40 deletions
@@ -20,6 +20,7 @@
 import os.path
 import re
 import time
+import uuid
 import warnings
 from contextlib import asynccontextmanager
 from typing import Any, List, Optional
@@ -207,10 +208,53 @@ class RequestBody(BaseModel):
         default=None,
         description="A state object that should be used to continue the interaction.",
     )
+    # Standard OpenAI completion parameters
+    model: Optional[str] = Field(
+        default=None,
+        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="Sampling temperature to use.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="Top-p sampling parameter.",
+    )
+    stop: Optional[str] = Field(
+        default=None,
+        description="Stop sequences.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="Presence penalty parameter.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="Frequency penalty parameter.",
+    )
+    function_call: Optional[dict] = Field(
+        default=None,
+        description="Function call parameter.",
+    )
+    logit_bias: Optional[dict] = Field(
+        default=None,
+        description="Logit bias parameter.",
+    )
+    log_probs: Optional[bool] = Field(
+        default=None,
+        description="Log probabilities parameter.",
+    )

     @root_validator(pre=True)
     def ensure_config_id(cls, data: Any) -> Any:
         if isinstance(data, dict):
+            if data.get("model") is not None and data.get("config_id") is None:
+                data["config_id"] = data["model"]
             if data.get("config_id") is not None and data.get("config_ids") is not None:
                 raise ValueError(
                     "Only one of config_id or config_ids should be specified"
@@ -231,25 +275,113 @@ def ensure_config_ids(cls, v, values):
         return v


+class Choice(BaseModel):
+    index: Optional[int] = Field(
+        default=None, description="The index of the choice in the list of choices."
+    )
+    messages: Optional[dict] = Field(
+        default=None, description="The message of the choice"
+    )
+    logprobs: Optional[dict] = Field(
+        default=None, description="The log probabilities of the choice"
+    )
+    finish_reason: Optional[str] = Field(
+        default=None, description="The reason the model stopped generating tokens."
+    )
+
+
 class ResponseBody(BaseModel):
-    messages: List[dict] = Field(
-        default=None, description="The new messages in the conversation"
+    # OpenAI-compatible fields
+    id: Optional[str] = Field(
+        default=None, description="A unique identifier for the chat completion."
     )
-    llm_output: Optional[dict] = Field(
-        default=None,
-        description="Contains any additional output coming from the LLM.",
+    object: str = Field(
+        default="chat.completion",
+        description="The object type, which is always chat.completion",
     )
-    output_data: Optional[dict] = Field(
+    created: Optional[int] = Field(
         default=None,
-        description="The output data, i.e. a dict with the values corresponding to the `output_vars`.",
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    model: Optional[str] = Field(
+        default=None, description="The model used for the chat completion."
     )
-    log: Optional[GenerationLog] = Field(
-        default=None, description="Additional logging information."
+    choices: Optional[List[Choice]] = Field(
+        default=None, description="A list of chat completion choices."
     )
+    # NeMo-Guardrails specific fields for backward compatibility
     state: Optional[dict] = Field(
-        default=None,
-        description="A state object that should be used to continue the interaction in the future.",
+        default=None, description="State object for continuing the conversation."
    )
+    llm_output: Optional[dict] = Field(
+        default=None, description="Additional LLM output data."
+    )
+    output_data: Optional[dict] = Field(
+        default=None, description="Additional output data."
+    )
+    log: Optional[dict] = Field(default=None, description="Generation log data.")
+
+
+class Model(BaseModel):
+    id: str = Field(
+        description="The model identifier, which can be referenced in the API endpoints."
+    )
+    object: str = Field(
+        default="model", description="The object type, which is always 'model'."
+    )
+    created: int = Field(
+        description="The Unix timestamp (in seconds) of when the model was created."
+    )
+    owned_by: str = Field(
+        default="nemo-guardrails", description="The organization that owns the model."
+    )
+
+
+class ModelsResponse(BaseModel):
+    object: str = Field(
+        default="list", description="The object type, which is always 'list'."
+    )
+    data: List[Model] = Field(description="The list of models.")
+
+
+@app.get(
+    "/v1/models",
+    response_model=ModelsResponse,
+    summary="List available models",
+    description="Lists the currently available models, mapping guardrails configurations to OpenAI-compatible model format.",
+)
+async def get_models():
+    """Returns the list of available models (guardrails configurations) in OpenAI-compatible format."""
+
+    # Use the same logic as get_rails_configs to find available configurations
+    if app.single_config_mode:
+        config_ids = [app.single_config_id] if app.single_config_id else []
+    else:
+        config_ids = [
+            f
+            for f in os.listdir(app.rails_config_path)
+            if os.path.isdir(os.path.join(app.rails_config_path, f))
+            and f[0] != "."
+            and f[0] != "_"
+            # Filter out all the configs for which there is no `config.yml` file.
+            and (
+                os.path.exists(os.path.join(app.rails_config_path, f, "config.yml"))
+                or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml"))
+            )
+        ]
+
+    # Convert configurations to OpenAI model format
+    models = []
+    for config_id in config_ids:
+        model = Model(
+            id=config_id,
+            object="model",
+            created=int(time.time()),  # Use current time as created timestamp
+            owned_by="nemo-guardrails",
+        )
+        models.append(model)
+
+    return ModelsResponse(data=models)


 @app.get(
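The new `GET /v1/models` endpoint above exposes each guardrails configuration as an OpenAI-style model object. A minimal listing sketch, assuming the server runs locally on port 8000:

# Minimal sketch; the host/port is an assumption, and the response shape follows
# the ModelsResponse/Model schemas introduced in this commit.
import requests

resp = requests.get("http://localhost:8000/v1/models", timeout=30)
resp.raise_for_status()

for model in resp.json()["data"]:
    # Each entry corresponds to a guardrails configuration directory.
    print(model["id"], model["created"], model["owned_by"])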
@@ -372,15 +504,24 @@ async def chat_completion(body: RequestBody, request: Request):
         llm_rails = _get_rails(config_ids)
     except ValueError as ex:
         log.exception(ex)
-        return {
-            "messages": [
-                {
-                    "role": "assistant",
-                    "content": f"Could not load the {config_ids} guardrails configuration. "
-                    f"An internal error has occurred.",
-                }
-            ]
-        }
+        return ResponseBody(
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=config_ids[0] if config_ids else None,
+            choices=[
+                Choice(
+                    index=0,
+                    messages={
+                        "content": f"Could not load the {config_ids} guardrails configuration. "
+                        f"An internal error has occurred.",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
+        )

     try:
         messages = body.messages
@@ -396,14 +537,23 @@ async def chat_completion(body: RequestBody, request: Request):

             # We make sure the `thread_id` meets the minimum complexity requirement.
             if len(body.thread_id) < 16:
-                return {
-                    "messages": [
-                        {
-                            "role": "assistant",
-                            "content": "The `thread_id` must have a minimum length of 16 characters.",
-                        }
-                    ]
-                }
+                return ResponseBody(
+                    id=f"chatcmpl-{uuid.uuid4()}",
+                    object="chat.completion",
+                    created=int(time.time()),
+                    model=None,
+                    choices=[
+                        Choice(
+                            index=0,
+                            messages={
+                                "content": "The `thread_id` must have a minimum length of 16 characters.",
+                                "role": "assistant",
+                            },
+                            finish_reason="error",
+                            logprobs=None,
+                        )
+                    ],
+                )

             # Fetch the existing thread messages. For easier management, we prepend
             # the string `thread-` to all thread keys.
@@ -413,6 +563,20 @@ async def chat_completion(body: RequestBody, request: Request):
             # And prepend them.
             messages = thread_messages + messages

+        generation_options = body.options
+        if body.max_tokens:
+            generation_options.max_tokens = body.max_tokens
+        if body.temperature is not None:
+            generation_options.temperature = body.temperature
+        if body.top_p is not None:
+            generation_options.top_p = body.top_p
+        if body.stop:
+            generation_options.stop = body.stop
+        if body.presence_penalty is not None:
+            generation_options.presence_penalty = body.presence_penalty
+        if body.frequency_penalty is not None:
+            generation_options.frequency_penalty = body.frequency_penalty
+
         if (
             body.stream
             and llm_rails.config.streaming_supported
@@ -431,8 +595,6 @@ async def chat_completion(body: RequestBody, request: Request):
             )
         )

-        # TODO: Add support for thread_ids in streaming mode
-
         return StreamingResponse(streaming_handler)
     else:
         res = await llm_rails.generate_async(
@@ -450,22 +612,50 @@ async def chat_completion(body: RequestBody, request: Request):
         if body.thread_id:
             await datastore.set(datastore_key, json.dumps(messages + [bot_message]))

-        result = {"messages": [bot_message]}
+        # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions
+        response_kwargs = {
+            "id": f"chatcmpl-{uuid.uuid4()}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": config_ids[0] if config_ids else None,
+            "choices": [
+                Choice(
+                    index=0,
+                    messages=bot_message,
+                    finish_reason="stop",
+                    logprobs=None,
+                )
+            ],
+        }

-        # If we have additional GenerationResponse fields, we return as well
+        # If we have additional GenerationResponse fields, include them for backward compatibility
         if isinstance(res, GenerationResponse):
-            result["llm_output"] = res.llm_output
-            result["output_data"] = res.output_data
-            result["log"] = res.log
-            result["state"] = res.state
+            response_kwargs["llm_output"] = res.llm_output
+            response_kwargs["output_data"] = res.output_data
+            response_kwargs["log"] = res.log
+            response_kwargs["state"] = res.state

-        return result
+        return ResponseBody(**response_kwargs)

     except Exception as ex:
         log.exception(ex)
-        return {
-            "messages": [{"role": "assistant", "content": "Internal server error."}]
-        }
+        return ResponseBody(
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=None,
+            choices=[
+                Choice(
+                    index=0,
+                    messages={
+                        "content": "Internal server error",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
+        )


 # By default, there are no challenges
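Taken together, a successful non-streaming call to the chat endpoint now returns a body shaped like the sketch below. The id, timestamp, and message content are made-up illustrative values; the keys follow the `ResponseBody` and `Choice` models above, where the per-choice message sits under the plural key `messages` and the NeMo-Guardrails fields (`state`, `llm_output`, `output_data`, `log`) are only populated when the rails return a `GenerationResponse`.

# Illustrative payload only; the values are invented, the keys mirror ResponseBody/Choice.
example_response = {
    "id": "chatcmpl-3f1c9c1e-8a2b-4f6d-9d7e-0123456789ab",
    "object": "chat.completion",
    "created": 1718000000,
    "model": "my_config",  # first config_id, when one was resolved
    "choices": [
        {
            "index": 0,
            "messages": {"role": "assistant", "content": "Hello! How can I help?"},
            "logprobs": None,
            "finish_reason": "stop",
        }
    ],
    # NeMo-Guardrails extensions kept for backward compatibility.
    "state": None,
    "llm_output": None,
    "output_data": None,
    "log": None,
}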
