Skip to content

Commit 2ec47ee

Browse files
chore: Move OpenAI schema and fix typos
1 parent f494cf2 commit 2ec47ee

File tree

5 files changed

+165
-127
lines changed

5 files changed

+165
-127
lines changed

nemoguardrails/server/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+

nemoguardrails/server/api.py

Lines changed: 15 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
1516
import asyncio
1617
import contextvars
1718
import importlib.util
@@ -27,17 +28,20 @@
2728

2829
from fastapi import FastAPI, Request
2930
from fastapi.middleware.cors import CORSMiddleware
30-
from pydantic import BaseModel, Field, root_validator, validator
31+
from pydantic import Field, root_validator, validator
3132
from starlette.responses import StreamingResponse
3233
from starlette.staticfiles import StaticFiles
3334

3435
from nemoguardrails import LLMRails, RailsConfig, utils
35-
from nemoguardrails.rails.llm.options import (
36-
GenerationLog,
37-
GenerationOptions,
38-
GenerationResponse,
39-
)
36+
from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
4037
from nemoguardrails.server.datastore.datastore import DataStore
38+
from nemoguardrails.server.schemas.openai import (
39+
Choice,
40+
Model,
41+
ModelsResponse,
42+
OpenAIRequestFields,
43+
ResponseBody,
44+
)
4145
from nemoguardrails.streaming import StreamingHandler
4246

4347
logging.basicConfig(level=logging.INFO)
@@ -169,7 +173,7 @@ async def root_handler():
169173
app.single_config_id = None
170174

171175

172-
class RequestBody(BaseModel):
176+
class RequestBody(OpenAIRequestFields):
173177
config_id: Optional[str] = Field(
174178
default=os.getenv("DEFAULT_CONFIG_ID", None),
175179
description="The id of the configuration to be used. If not set, the default configuration will be used.",
@@ -208,47 +212,6 @@ class RequestBody(BaseModel):
208212
default=None,
209213
description="A state object that should be used to continue the interaction.",
210214
)
211-
# Standard OpenAI completion parameters
212-
model: Optional[str] = Field(
213-
default=None,
214-
description="The model to use for chat completion. Maps to config_id for backward compatibility.",
215-
)
216-
max_tokens: Optional[int] = Field(
217-
default=None,
218-
description="The maximum number of tokens to generate.",
219-
)
220-
temperature: Optional[float] = Field(
221-
default=None,
222-
description="Sampling temperature to use.",
223-
)
224-
top_p: Optional[float] = Field(
225-
default=None,
226-
description="Top-p sampling parameter.",
227-
)
228-
stop: Optional[str] = Field(
229-
default=None,
230-
description="Stop sequences.",
231-
)
232-
presence_penalty: Optional[float] = Field(
233-
default=None,
234-
description="Presence penalty parameter.",
235-
)
236-
frequency_penalty: Optional[float] = Field(
237-
default=None,
238-
description="Frequency penalty parameter.",
239-
)
240-
function_call: Optional[dict] = Field(
241-
default=None,
242-
description="Function call parameter.",
243-
)
244-
logit_bias: Optional[dict] = Field(
245-
default=None,
246-
description="Logit bias parameter.",
247-
)
248-
log_probs: Optional[bool] = Field(
249-
default=None,
250-
description="Log probabilities parameter.",
251-
)
252215

253216
@root_validator(pre=True)
254217
def ensure_config_id(cls, data: Any) -> Any:
@@ -275,75 +238,6 @@ def ensure_config_ids(cls, v, values):
275238
return v
276239

277240

278-
class Choice(BaseModel):
279-
index: Optional[int] = Field(
280-
default=None, description="The index of the choice in the list of choices."
281-
)
282-
messages: Optional[dict] = Field(
283-
default=None, description="The message of the choice"
284-
)
285-
logprobs: Optional[dict] = Field(
286-
default=None, description="The log probabilities of the choice"
287-
)
288-
finish_reason: Optional[str] = Field(
289-
default=None, description="The reason the model stopped generating tokens."
290-
)
291-
292-
293-
class ResponseBody(BaseModel):
294-
# OpenAI-compatible fields
295-
id: Optional[str] = Field(
296-
default=None, description="A unique identifier for the chat completion."
297-
)
298-
object: str = Field(
299-
default="chat.completion",
300-
description="The object type, which is always chat.completion",
301-
)
302-
created: Optional[int] = Field(
303-
default=None,
304-
description="The Unix timestamp (in seconds) of when the chat completion was created.",
305-
)
306-
model: Optional[str] = Field(
307-
default=None, description="The model used for the chat completion."
308-
)
309-
choices: Optional[List[Choice]] = Field(
310-
default=None, description="A list of chat completion choices."
311-
)
312-
# NeMo-Guardrails specific fields for backward compatibility
313-
state: Optional[dict] = Field(
314-
default=None, description="State object for continuing the conversation."
315-
)
316-
llm_output: Optional[dict] = Field(
317-
default=None, description="Additional LLM output data."
318-
)
319-
output_data: Optional[dict] = Field(
320-
default=None, description="Additional output data."
321-
)
322-
log: Optional[dict] = Field(default=None, description="Generation log data.")
323-
324-
325-
class Model(BaseModel):
326-
id: str = Field(
327-
description="The model identifier, which can be referenced in the API endpoints."
328-
)
329-
object: str = Field(
330-
default="model", description="The object type, which is always 'model'."
331-
)
332-
created: int = Field(
333-
description="The Unix timestamp (in seconds) of when the model was created."
334-
)
335-
owned_by: str = Field(
336-
default="nemo-guardrails", description="The organization that owns the model."
337-
)
338-
339-
340-
class ModelsResponse(BaseModel):
341-
object: str = Field(
342-
default="list", description="The object type, which is always 'list'."
343-
)
344-
data: List[Model] = Field(description="The list of models.")
345-
346-
347241
@app.get(
348242
"/v1/models",
349243
response_model=ModelsResponse,
@@ -512,7 +406,7 @@ async def chat_completion(body: RequestBody, request: Request):
512406
choices=[
513407
Choice(
514408
index=0,
515-
messages={
409+
message={
516410
"content": f"Could not load the {config_ids} guardrails configuration. "
517411
f"An internal error has occurred.",
518412
"role": "assistant",
@@ -545,7 +439,7 @@ async def chat_completion(body: RequestBody, request: Request):
545439
choices=[
546440
Choice(
547441
index=0,
548-
messages={
442+
message={
549443
"content": "The `thread_id` must have a minimum length of 16 characters.",
550444
"role": "assistant",
551445
},
@@ -621,7 +515,7 @@ async def chat_completion(body: RequestBody, request: Request):
621515
"choices": [
622516
Choice(
623517
index=0,
624-
messages=bot_message,
518+
message=bot_message,
625519
finish_reason="stop",
626520
logprobs=None,
627521
)
@@ -647,7 +541,7 @@ async def chat_completion(body: RequestBody, request: Request):
647541
choices=[
648542
Choice(
649543
index=0,
650-
messages={
544+
message={
651545
"content": "Internal server error",
652546
"role": "assistant",
653547
},
nemoguardrails/server/schemas/openai.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""OpenAI API schema definitions for the NeMo Guardrails server."""
17+
18+
from typing import List, Optional, Union
19+
20+
from pydantic import BaseModel, Field
21+
22+
23+
class OpenAIRequestFields(BaseModel):
24+
"""OpenAI API request fields that can be mixed into other request schemas."""
25+
26+
# Standard OpenAI completion parameters
27+
model: Optional[str] = Field(
28+
default=None,
29+
description="The model to use for chat completion. Maps to config_id for backward compatibility.",
30+
)
31+
max_tokens: Optional[int] = Field(
32+
default=None,
33+
description="The maximum number of tokens to generate.",
34+
)
35+
temperature: Optional[float] = Field(
36+
default=None,
37+
description="Sampling temperature to use.",
38+
)
39+
top_p: Optional[float] = Field(
40+
default=None,
41+
description="Top-p sampling parameter.",
42+
)
43+
stop: Optional[Union[str, List[str]]] = Field(
44+
default=None,
45+
description="Stop sequences.",
46+
)
47+
presence_penalty: Optional[float] = Field(
48+
default=None,
49+
description="Presence penalty parameter.",
50+
)
51+
frequency_penalty: Optional[float] = Field(
52+
default=None,
53+
description="Frequency penalty parameter.",
54+
)
55+
function_call: Optional[dict] = Field(
56+
default=None,
57+
description="Function call parameter.",
58+
)
59+
logit_bias: Optional[dict] = Field(
60+
default=None,
61+
description="Logit bias parameter.",
62+
)
63+
log_probs: Optional[bool] = Field(
64+
default=None,
65+
description="Log probabilities parameter.",
66+
)
67+
68+
69+
class Choice(BaseModel):
70+
"""OpenAI API choice structure in chat completion responses."""
71+
72+
index: Optional[int] = Field(
73+
default=None, description="The index of the choice in the list of choices."
74+
)
75+
message: Optional[dict] = Field(
76+
default=None, description="The message of the choice"
77+
)
78+
logprobs: Optional[dict] = Field(
79+
default=None, description="The log probabilities of the choice"
80+
)
81+
finish_reason: Optional[str] = Field(
82+
default=None, description="The reason the model stopped generating tokens."
83+
)
84+
85+
86+
class ResponseBody(BaseModel):
87+
"""OpenAI API response body with NeMo-Guardrails extensions."""
88+
89+
# OpenAI API fields
90+
id: Optional[str] = Field(
91+
default=None, description="A unique identifier for the chat completion."
92+
)
93+
object: str = Field(
94+
default="chat.completion",
95+
description="The object type, which is always chat.completion",
96+
)
97+
created: Optional[int] = Field(
98+
default=None,
99+
description="The Unix timestamp (in seconds) of when the chat completion was created.",
100+
)
101+
model: Optional[str] = Field(
102+
default=None, description="The model used for the chat completion."
103+
)
104+
choices: Optional[List[Choice]] = Field(
105+
default=None, description="A list of chat completion choices."
106+
)
107+
# NeMo-Guardrails specific fields for backward compatibility
108+
state: Optional[dict] = Field(
109+
default=None, description="State object for continuing the conversation."
110+
)
111+
llm_output: Optional[dict] = Field(
112+
default=None, description="Additional LLM output data."
113+
)
114+
output_data: Optional[dict] = Field(
115+
default=None, description="Additional output data."
116+
)
117+
log: Optional[dict] = Field(default=None, description="Generation log data.")
118+
119+
120+
class Model(BaseModel):
121+
"""OpenAI API model representation."""
122+
123+
id: str = Field(
124+
description="The model identifier, which can be referenced in the API endpoints."
125+
)
126+
object: str = Field(
127+
default="model", description="The object type, which is always 'model'."
128+
)
129+
created: int = Field(
130+
description="The Unix timestamp (in seconds) of when the model was created."
131+
)
132+
owned_by: str = Field(
133+
default="nemo-guardrails", description="The organization that owns the model."
134+
)
135+
136+
137+
class ModelsResponse(BaseModel):
138+
"""OpenAI API models list response."""
139+
140+
object: str = Field(
141+
default="list", description="The object type, which is always 'list'."
142+
)
143+
data: List[Model] = Field(description="The list of models.")

tests/test_server_calls_with_state.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def _test_call(config_id):
3838
assert response.status_code == 200
3939
res = response.json()
4040
print(res)
41-
assert len(res["choices"][0]["messages"]) == 2
42-
assert res["choices"][0]["messages"]["content"] == "Hello!"
41+
assert len(res["choices"][0]["message"]) == 2
42+
assert res["choices"][0]["message"]["content"] == "Hello!"
4343
assert res.get("state")
4444

4545
# When making a second call with the returned state, the conversations should continue
@@ -60,7 +60,7 @@ def _test_call(config_id):
6060
},
6161
)
6262
res = response.json()
63-
assert res["choices"][0]["messages"]["content"] == "Hello again!"
63+
assert res["choices"][0]["message"]["content"] == "Hello again!"
6464

6565

6666
def test_1():

tests/test_threads.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def test_1():
5353
)
5454
assert response.status_code == 200
5555
res = response.json()
56-
assert len(res["choices"][0]["messages"]) == 2
57-
assert res["choices"][0]["messages"]["content"] == "Hello!"
56+
assert len(res["choices"][0]["message"]) == 2
57+
assert res["choices"][0]["message"]["content"] == "Hello!"
5858

5959
# When making a second call with the same thread_id, the conversations should continue
6060
# and we should get the "Hello again!" message.
@@ -72,7 +72,7 @@ def test_1():
7272
},
7373
)
7474
res = response.json()
75-
assert res["choices"][0]["messages"]["content"] == "Hello again!"
75+
assert res["choices"][0]["message"]["content"] == "Hello again!"
7676

7777

7878
@pytest.mark.parametrize(

0 commit comments

Comments (0)