 import os.path
 import re
 import time
+import uuid
 import warnings
 from contextlib import asynccontextmanager
 from typing import Any, List, Optional
@@ -207,10 +208,53 @@ class RequestBody(BaseModel):
         default=None,
         description="A state object that should be used to continue the interaction.",
     )
+    # Standard OpenAI completion parameters
+    model: Optional[str] = Field(
+        default=None,
+        description="The model to use for chat completion. Maps to config_id for backward compatibility.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate.",
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="Sampling temperature to use.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="Top-p sampling parameter.",
+    )
+    stop: Optional[str] = Field(
+        default=None,
+        description="Stop sequences.",
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="Presence penalty parameter.",
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="Frequency penalty parameter.",
+    )
+    function_call: Optional[dict] = Field(
+        default=None,
+        description="Function call parameter.",
+    )
+    logit_bias: Optional[dict] = Field(
+        default=None,
+        description="Logit bias parameter.",
+    )
+    log_probs: Optional[bool] = Field(
+        default=None,
+        description="Log probabilities parameter.",
+    )
 
     @root_validator(pre=True)
     def ensure_config_id(cls, data: Any) -> Any:
         if isinstance(data, dict):
+            if data.get("model") is not None and data.get("config_id") is None:
+                data["config_id"] = data["model"]
             if data.get("config_id") is not None and data.get("config_ids") is not None:
                 raise ValueError(
                     "Only one of config_id or config_ids should be specified"
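
With this aliasing in place, an OpenAI-style request that sets `model` behaves exactly like a legacy request that sets `config_id`. A minimal sketch of the resulting behavior, using the `RequestBody` model defined above (the configuration name is hypothetical):

```python
# Both payloads resolve to the same guardrails configuration after validation.
openai_style = {
    "model": "my_config",  # hypothetical configuration name
    "messages": [{"role": "user", "content": "Hello!"}],
}
legacy_style = {
    "config_id": "my_config",
    "messages": [{"role": "user", "content": "Hello!"}],
}

assert RequestBody(**openai_style).config_id == RequestBody(**legacy_style).config_id
```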
@@ -231,25 +275,113 @@ def ensure_config_ids(cls, v, values):
         return v
 
 
+class Choice(BaseModel):
+    index: Optional[int] = Field(
+        default=None, description="The index of the choice in the list of choices."
+    )
+    messages: Optional[dict] = Field(
+        default=None, description="The message of the choice"
+    )
+    logprobs: Optional[dict] = Field(
+        default=None, description="The log probabilities of the choice"
+    )
+    finish_reason: Optional[str] = Field(
+        default=None, description="The reason the model stopped generating tokens."
+    )
+
+
 class ResponseBody(BaseModel):
-    messages: List[dict] = Field(
-        default=None, description="The new messages in the conversation"
+    # OpenAI-compatible fields
+    id: Optional[str] = Field(
+        default=None, description="A unique identifier for the chat completion."
     )
-    llm_output: Optional[dict] = Field(
-        default=None,
-        description="Contains any additional output coming from the LLM.",
+    object: str = Field(
+        default="chat.completion",
+        description="The object type, which is always chat.completion.",
     )
-    output_data: Optional[dict] = Field(
+    created: Optional[int] = Field(
         default=None,
-        description="The output data, i.e. a dict with the values corresponding to the `output_vars`.",
+        description="The Unix timestamp (in seconds) of when the chat completion was created.",
+    )
+    model: Optional[str] = Field(
+        default=None, description="The model used for the chat completion."
     )
-    log: Optional[GenerationLog] = Field(
-        default=None, description="Additional logging information."
+    choices: Optional[List[Choice]] = Field(
+        default=None, description="A list of chat completion choices."
     )
+    # NeMo-Guardrails specific fields for backward compatibility
     state: Optional[dict] = Field(
-        default=None,
-        description="A state object that should be used to continue the interaction in the future.",
+        default=None, description="State object for continuing the conversation."
     )
+    llm_output: Optional[dict] = Field(
+        default=None, description="Additional LLM output data."
+    )
+    output_data: Optional[dict] = Field(
+        default=None, description="Additional output data."
+    )
+    log: Optional[dict] = Field(default=None, description="Generation log data.")
+
+
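Serialized, a successful response now follows the standard chat-completion shape while carrying the guardrails fields alongside; note the choice key is `messages` (not OpenAI's `message`), matching the `Choice` model above. An illustrative payload, with every value made up:

```python
# Illustrative ResponseBody payload; all values here are hypothetical.
example_response = {
    "id": "chatcmpl-0fcb7d9a-5a11-4b6e-9c0e-3c1a2b4d5e6f",
    "object": "chat.completion",
    "created": 1714000000,
    "model": "my_config",
    "choices": [
        {
            "index": 0,
            "messages": {"role": "assistant", "content": "Hello!"},
            "logprobs": None,
            "finish_reason": "stop",
        }
    ],
    # NeMo Guardrails extensions (all optional):
    "state": None,
    "llm_output": None,
    "output_data": None,
    "log": None,
}
```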
+class Model(BaseModel):
+    id: str = Field(
+        description="The model identifier, which can be referenced in the API endpoints."
+    )
+    object: str = Field(
+        default="model", description="The object type, which is always 'model'."
+    )
+    created: int = Field(
+        description="The Unix timestamp (in seconds) of when the model was created."
+    )
+    owned_by: str = Field(
+        default="nemo-guardrails", description="The organization that owns the model."
+    )
+
+
+class ModelsResponse(BaseModel):
+    object: str = Field(
+        default="list", description="The object type, which is always 'list'."
+    )
+    data: List[Model] = Field(description="The list of models.")
+
+
+@app.get(
+    "/v1/models",
+    response_model=ModelsResponse,
+    summary="List available models",
+    description="Lists the currently available models, mapping guardrails configurations to OpenAI-compatible model format.",
+)
+async def get_models():
+    """Returns the list of available models (guardrails configurations) in OpenAI-compatible format."""
+
+    # Use the same logic as get_rails_configs to find available configurations.
+    if app.single_config_mode:
+        config_ids = [app.single_config_id] if app.single_config_id else []
+    else:
+        config_ids = [
+            f
+            for f in os.listdir(app.rails_config_path)
+            if os.path.isdir(os.path.join(app.rails_config_path, f))
+            and f[0] != "."
+            and f[0] != "_"
+            # Filter out all the configs for which there is no `config.yml` file.
+            and (
+                os.path.exists(os.path.join(app.rails_config_path, f, "config.yml"))
+                or os.path.exists(os.path.join(app.rails_config_path, f, "config.yaml"))
+            )
+        ]
+
+    # Convert configurations to OpenAI model format.
+    models = []
+    for config_id in config_ids:
+        model = Model(
+            id=config_id,
+            object="model",
+            created=int(time.time()),  # Use the current time as the created timestamp.
+            owned_by="nemo-guardrails",
+        )
+        models.append(model)
+
+    return ModelsResponse(data=models)
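
Any OpenAI-compatible client can now discover the configurations through this endpoint. A minimal sketch using the `openai` Python package; the base URL and port are assumptions about a local deployment, and the API key is a placeholder on the assumption that the server does not enforce one:

```python
from openai import OpenAI

# Assumes the guardrails server is running locally on port 8000.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder")

for model in client.models.list():
    print(model.id, model.owned_by)
```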
 
 
 @app.get(
@@ -372,15 +504,24 @@ async def chat_completion(body: RequestBody, request: Request):
         llm_rails = _get_rails(config_ids)
     except ValueError as ex:
         log.exception(ex)
-        return {
-            "messages": [
-                {
-                    "role": "assistant",
-                    "content": f"Could not load the {config_ids} guardrails configuration. "
-                    f"An internal error has occurred.",
-                }
-            ]
-        }
+        return ResponseBody(
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=config_ids[0] if config_ids else None,
+            choices=[
+                Choice(
+                    index=0,
+                    messages={
+                        "content": f"Could not load the {config_ids} guardrails configuration. "
+                        f"An internal error has occurred.",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
+        )
 
     try:
         messages = body.messages
@@ -396,14 +537,23 @@
 
             # We make sure the `thread_id` meets the minimum complexity requirement.
             if len(body.thread_id) < 16:
-                return {
-                    "messages": [
-                        {
-                            "role": "assistant",
-                            "content": "The `thread_id` must have a minimum length of 16 characters.",
-                        }
-                    ]
-                }
+                return ResponseBody(
+                    id=f"chatcmpl-{uuid.uuid4()}",
+                    object="chat.completion",
+                    created=int(time.time()),
+                    model=None,
+                    choices=[
+                        Choice(
+                            index=0,
+                            messages={
+                                "content": "The `thread_id` must have a minimum length of 16 characters.",
+                                "role": "assistant",
+                            },
+                            finish_reason="error",
+                            logprobs=None,
+                        )
+                    ],
+                )
 
             # Fetch the existing thread messages. For easier management, we prepend
             # the string `thread-` to all thread keys.
@@ -413,6 +563,20 @@
             # And prepend them.
             messages = thread_messages + messages
 
+        generation_options = body.options
+        if body.max_tokens:
+            generation_options.max_tokens = body.max_tokens
+        if body.temperature is not None:
+            generation_options.temperature = body.temperature
+        if body.top_p is not None:
+            generation_options.top_p = body.top_p
+        if body.stop:
+            generation_options.stop = body.stop
+        if body.presence_penalty is not None:
+            generation_options.presence_penalty = body.presence_penalty
+        if body.frequency_penalty is not None:
+            generation_options.frequency_penalty = body.frequency_penalty
+
         if (
             body.stream
             and llm_rails.config.streaming_supported
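
The block above only copies a sampling parameter onto the generation options when the client actually supplied it, so configuration defaults are preserved otherwise. A sketch of a request that exercises these parameters via `requests`; the `/v1/chat/completions` path, server address, and configuration name are assumptions about the deployment:

```python
import requests

# Hypothetical local deployment and configuration name.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my_config",  # aliased to config_id by the validator above
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.2,
        "top_p": 0.9,
        "max_tokens": 128,
    },
)
print(response.json()["choices"][0]["messages"]["content"])
```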
@@ -431,8 +595,6 @@ async def chat_completion(body: RequestBody, request: Request):
                 )
             )
 
-            # TODO: Add support for thread_ids in streaming mode
-
             return StreamingResponse(streaming_handler)
         else:
             res = await llm_rails.generate_async(
@@ -450,22 +612,50 @@ async def chat_completion(body: RequestBody, request: Request):
         if body.thread_id:
             await datastore.set(datastore_key, json.dumps(messages + [bot_message]))
 
-        result = {"messages": [bot_message]}
+        # Build the response with OpenAI-compatible format plus NeMo-Guardrails extensions
+        response_kwargs = {
+            "id": f"chatcmpl-{uuid.uuid4()}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": config_ids[0] if config_ids else None,
+            "choices": [
+                Choice(
+                    index=0,
+                    messages=bot_message,
+                    finish_reason="stop",
+                    logprobs=None,
+                )
+            ],
+        }
 
-        # If we have additional GenerationResponse fields, we return as well
+        # If we have additional GenerationResponse fields, include them for backward compatibility
         if isinstance(res, GenerationResponse):
-            result["llm_output"] = res.llm_output
-            result["output_data"] = res.output_data
-            result["log"] = res.log
-            result["state"] = res.state
+            response_kwargs["llm_output"] = res.llm_output
+            response_kwargs["output_data"] = res.output_data
+            response_kwargs["log"] = res.log
+            response_kwargs["state"] = res.state
 
-        return result
+        return ResponseBody(**response_kwargs)
 
     except Exception as ex:
         log.exception(ex)
-        return {
-            "messages": [{"role": "assistant", "content": "Internal server error."}]
-        }
+        return ResponseBody(
+            id=f"chatcmpl-{uuid.uuid4()}",
+            object="chat.completion",
+            created=int(time.time()),
+            model=None,
+            choices=[
+                Choice(
+                    index=0,
+                    messages={
+                        "content": "Internal server error",
+                        "role": "assistant",
+                    },
+                    finish_reason="error",
+                    logprobs=None,
+                )
+            ],
+        )
 
 
 # By default, there are no challenges
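
For completeness, thread continuation still works through the new response shape; the only server-side requirement shown above is that `thread_id` be at least 16 characters. A sketch under the same local-deployment assumptions (endpoint path and configuration name hypothetical):

```python
import requests

# The thread_id must be at least 16 characters, per the check above.
payload = {
    "model": "my_config",  # hypothetical configuration name
    "thread_id": "local-demo-thread-001",
    "messages": [{"role": "user", "content": "What did I just ask you?"}],
}
data = requests.post("http://localhost:8000/v1/chat/completions", json=payload).json()
print(data["choices"][0]["messages"]["content"])
```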