Skip to content

Commit 8d809d4

Browse files
Minor fix
1 parent 7b9654b commit 8d809d4

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

nemoguardrails/server/api.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -486,18 +486,24 @@ async def chat_completion(body: RequestBody, request: Request):
486486
messages = thread_messages + messages
487487

488488
generation_options = body.options
489+
490+
# Initialize llm_params if not already set
491+
if generation_options.llm_params is None:
492+
generation_options.llm_params = {}
493+
494+
# Set OpenAI-compatible parameters in llm_params
489495
if body.max_tokens:
490-
generation_options.max_tokens = body.max_tokens
496+
generation_options.llm_params["max_tokens"] = body.max_tokens
491497
if body.temperature is not None:
492-
generation_options.temperature = body.temperature
498+
generation_options.llm_params["temperature"] = body.temperature
493499
if body.top_p is not None:
494-
generation_options.top_p = body.top_p
500+
generation_options.llm_params["top_p"] = body.top_p
495501
if body.stop:
496-
generation_options.stop = body.stop
502+
generation_options.llm_params["stop"] = body.stop
497503
if body.presence_penalty is not None:
498-
generation_options.presence_penalty = body.presence_penalty
504+
generation_options.llm_params["presence_penalty"] = body.presence_penalty
499505
if body.frequency_penalty is not None:
500-
generation_options.frequency_penalty = body.frequency_penalty
506+
generation_options.llm_params["frequency_penalty"] = body.frequency_penalty
501507

502508
if (
503509
body.stream

tests/test_threads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,4 +140,4 @@ def test_with_redis():
140140
},
141141
)
142142
res = response.json()
143-
assert res["choices"]["message"][0]["content"] == "Hello again!"
143+
assert res["choices"][0]["message"]["content"] == "Hello again!"

0 commit comments

Comments
 (0)