1212# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313# See the License for the specific language governing permissions and
1414# limitations under the License.
15+
1516import asyncio
1617import contextvars
1718import importlib .util
2728
2829from fastapi import FastAPI , Request
2930from fastapi .middleware .cors import CORSMiddleware
30- from pydantic import BaseModel , Field , root_validator , validator
31+ from pydantic import Field , root_validator , validator
3132from starlette .responses import StreamingResponse
3233from starlette .staticfiles import StaticFiles
3334
3435from nemoguardrails import LLMRails , RailsConfig , utils
35- from nemoguardrails .rails .llm .options import (
36- GenerationLog ,
37- GenerationOptions ,
38- GenerationResponse ,
39- )
36+ from nemoguardrails .rails .llm .options import GenerationOptions , GenerationResponse
4037from nemoguardrails .server .datastore .datastore import DataStore
38+ from nemoguardrails .server .schemas .openai import (
39+ Choice ,
40+ Model ,
41+ ModelsResponse ,
42+ OpenAIRequestFields ,
43+ ResponseBody ,
44+ )
4145from nemoguardrails .streaming import StreamingHandler
4246
4347logging .basicConfig (level = logging .INFO )
@@ -169,7 +173,7 @@ async def root_handler():
169173app .single_config_id = None
170174
171175
172- class RequestBody (BaseModel ):
176+ class RequestBody (OpenAIRequestFields ):
173177 config_id : Optional [str ] = Field (
174178 default = os .getenv ("DEFAULT_CONFIG_ID" , None ),
175179 description = "The id of the configuration to be used. If not set, the default configuration will be used." ,
@@ -208,47 +212,6 @@ class RequestBody(BaseModel):
208212 default = None ,
209213 description = "A state object that should be used to continue the interaction." ,
210214 )
211- # Standard OpenAI completion parameters
212- model : Optional [str ] = Field (
213- default = None ,
214- description = "The model to use for chat completion. Maps to config_id for backward compatibility." ,
215- )
216- max_tokens : Optional [int ] = Field (
217- default = None ,
218- description = "The maximum number of tokens to generate." ,
219- )
220- temperature : Optional [float ] = Field (
221- default = None ,
222- description = "Sampling temperature to use." ,
223- )
224- top_p : Optional [float ] = Field (
225- default = None ,
226- description = "Top-p sampling parameter." ,
227- )
228- stop : Optional [str ] = Field (
229- default = None ,
230- description = "Stop sequences." ,
231- )
232- presence_penalty : Optional [float ] = Field (
233- default = None ,
234- description = "Presence penalty parameter." ,
235- )
236- frequency_penalty : Optional [float ] = Field (
237- default = None ,
238- description = "Frequency penalty parameter." ,
239- )
240- function_call : Optional [dict ] = Field (
241- default = None ,
242- description = "Function call parameter." ,
243- )
244- logit_bias : Optional [dict ] = Field (
245- default = None ,
246- description = "Logit bias parameter." ,
247- )
248- log_probs : Optional [bool ] = Field (
249- default = None ,
250- description = "Log probabilities parameter." ,
251- )
252215
253216 @root_validator (pre = True )
254217 def ensure_config_id (cls , data : Any ) -> Any :
@@ -275,75 +238,6 @@ def ensure_config_ids(cls, v, values):
275238 return v
276239
277240
278- class Choice (BaseModel ):
279- index : Optional [int ] = Field (
280- default = None , description = "The index of the choice in the list of choices."
281- )
282- messages : Optional [dict ] = Field (
283- default = None , description = "The message of the choice"
284- )
285- logprobs : Optional [dict ] = Field (
286- default = None , description = "The log probabilities of the choice"
287- )
288- finish_reason : Optional [str ] = Field (
289- default = None , description = "The reason the model stopped generating tokens."
290- )
291-
292-
293- class ResponseBody (BaseModel ):
294- # OpenAI-compatible fields
295- id : Optional [str ] = Field (
296- default = None , description = "A unique identifier for the chat completion."
297- )
298- object : str = Field (
299- default = "chat.completion" ,
300- description = "The object type, which is always chat.completion" ,
301- )
302- created : Optional [int ] = Field (
303- default = None ,
304- description = "The Unix timestamp (in seconds) of when the chat completion was created." ,
305- )
306- model : Optional [str ] = Field (
307- default = None , description = "The model used for the chat completion."
308- )
309- choices : Optional [List [Choice ]] = Field (
310- default = None , description = "A list of chat completion choices."
311- )
312- # NeMo-Guardrails specific fields for backward compatibility
313- state : Optional [dict ] = Field (
314- default = None , description = "State object for continuing the conversation."
315- )
316- llm_output : Optional [dict ] = Field (
317- default = None , description = "Additional LLM output data."
318- )
319- output_data : Optional [dict ] = Field (
320- default = None , description = "Additional output data."
321- )
322- log : Optional [dict ] = Field (default = None , description = "Generation log data." )
323-
324-
325- class Model (BaseModel ):
326- id : str = Field (
327- description = "The model identifier, which can be referenced in the API endpoints."
328- )
329- object : str = Field (
330- default = "model" , description = "The object type, which is always 'model'."
331- )
332- created : int = Field (
333- description = "The Unix timestamp (in seconds) of when the model was created."
334- )
335- owned_by : str = Field (
336- default = "nemo-guardrails" , description = "The organization that owns the model."
337- )
338-
339-
340- class ModelsResponse (BaseModel ):
341- object : str = Field (
342- default = "list" , description = "The object type, which is always 'list'."
343- )
344- data : List [Model ] = Field (description = "The list of models." )
345-
346-
347241@app .get (
348242 "/v1/models" ,
349243 response_model = ModelsResponse ,
@@ -512,7 +406,7 @@ async def chat_completion(body: RequestBody, request: Request):
512406 choices = [
513407 Choice (
514408 index = 0 ,
515- messages = {
409+ message = {
516410 "content" : f"Could not load the { config_ids } guardrails configuration. "
517411 f"An internal error has occurred." ,
518412 "role" : "assistant" ,
@@ -545,7 +439,7 @@ async def chat_completion(body: RequestBody, request: Request):
545439 choices = [
546440 Choice (
547441 index = 0 ,
548- messages = {
442+ message = {
549443 "content" : "The `thread_id` must have a minimum length of 16 characters." ,
550444 "role" : "assistant" ,
551445 },
@@ -621,7 +515,7 @@ async def chat_completion(body: RequestBody, request: Request):
621515 "choices" : [
622516 Choice (
623517 index = 0 ,
624- messages = bot_message ,
518+ message = bot_message ,
625519 finish_reason = "stop" ,
626520 logprobs = None ,
627521 )
@@ -647,7 +541,7 @@ async def chat_completion(body: RequestBody, request: Request):
647541 choices = [
648542 Choice (
649543 index = 0 ,
650- messages = {
544+ message = {
651545 "content" : "Internal server error" ,
652546 "role" : "assistant" ,
653547 },
0 commit comments