Skip to content

Commit 617ac25

Browse files
committed
Refactor logic to live in LoggingCallbackHandler
1 parent 4a98503 commit 617ac25

File tree

4 files changed

+216
-33
lines changed

4 files changed

+216
-33
lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -211,45 +211,13 @@ def _prepare_callbacks(
211211
return logging_callbacks
212212

213213

214-
def _log_model_and_base_url(llm: Union[BaseLanguageModel, Runnable]) -> None:
215-
"""Extract and log the model and base URL from an LLM instance."""
216-
model_name = _infer_model_name(llm)
217-
base_url = None
218-
219-
# If llm is a `ChatNIM` instance, we expect its `client` to be an `OpenAI` client with a `base_url` attribute.
220-
if hasattr(llm, "client"):
221-
client = getattr(llm, "client")
222-
if hasattr(client, "base_url"):
223-
base_url = str(client.base_url)
224-
else:
225-
# If llm is a `ChatNVIDIA` instance or other provider, check common attribute names that store the base URL.
226-
for attr in [
227-
"base_url",
228-
"openai_api_base",
229-
"azure_endpoint",
230-
"api_base",
231-
"endpoint",
232-
]:
233-
if hasattr(llm, attr):
234-
value = getattr(llm, attr, None)
235-
if value:
236-
base_url = str(value)
237-
break
238-
239-
if base_url:
240-
log.info(f"Invoking LLM: model={model_name}, url={base_url}")
241-
else:
242-
log.info(f"Invoking LLM: model={model_name}")
243-
244-
245214
async def _invoke_with_string_prompt(
246215
llm: Union[BaseLanguageModel, Runnable],
247216
prompt: str,
248217
callbacks: BaseCallbackManager,
249218
):
250219
"""Invoke LLM with string prompt."""
251220
try:
252-
_log_model_and_base_url(llm)
253221
return await llm.ainvoke(prompt, config=RunnableConfig(callbacks=callbacks))
254222
except Exception as e:
255223
raise LLMCallException(e)
@@ -264,7 +232,6 @@ async def _invoke_with_message_list(
264232
messages = _convert_messages_to_langchain_format(prompt)
265233

266234
try:
267-
_log_model_and_base_url(llm)
268235
return await llm.ainvoke(messages, config=RunnableConfig(callbacks=callbacks))
269236
except Exception as e:
270237
raise LLMCallException(e)

nemoguardrails/logging/callbacks.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from nemoguardrails.logging.explain import LLMCallInfo
3333
from nemoguardrails.logging.processing_log import processing_log_var
3434
from nemoguardrails.logging.stats import LLMStats
35+
from nemoguardrails.logging.utils import extract_model_name_and_base_url
3536
from nemoguardrails.utils import new_uuid
3637

3738
log = logging.getLogger(__name__)
@@ -64,6 +65,15 @@ async def on_llm_start(
6465
if explain_info:
6566
explain_info.llm_calls.append(llm_call_info)
6667

68+
# Log model name and base URL
69+
model_name, base_url = extract_model_name_and_base_url(serialized)
70+
if base_url:
71+
log.info(f"Invoking LLM: model={model_name}, url={base_url}")
72+
elif model_name:
73+
log.info(f"Invoking LLM: model={model_name}")
74+
else:
75+
log.info("Invoking LLM")
76+
6777
log.info("Invocation Params :: %s", kwargs.get("invocation_params", {}))
6878
log.info(
6979
"Prompt :: %s",
@@ -105,6 +115,15 @@ async def on_chat_model_start(
105115
if explain_info:
106116
explain_info.llm_calls.append(llm_call_info)
107117

118+
# Log model name and base URL
119+
model_name, base_url = extract_model_name_and_base_url(serialized)
120+
if base_url:
121+
log.info(f"Invoking LLM: model={model_name}, url={base_url}")
122+
elif model_name:
123+
log.info(f"Invoking LLM: model={model_name}")
124+
else:
125+
log.info("Invoking LLM")
126+
108127
type_map = {
109128
"human": "User",
110129
"ai": "Bot",

nemoguardrails/logging/utils.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import re
18+
from typing import Any, Dict, Optional
19+
20+
log = logging.getLogger(__name__)
21+
22+
23+
def extract_model_name_and_base_url(
    serialized: Dict[str, Any]
) -> tuple[Optional[str], Optional[str]]:
    """Extract model name and base URL from serialized LLM parameters.

    Args:
        serialized: The serialized LLM configuration, as passed by LangChain
            to callback handlers. May contain a ``kwargs`` dict and/or a
            ``repr`` string.

    Returns:
        A tuple of (model_name, base_url). Either value can be None if not found.
    """
    model_name: Optional[str] = None
    base_url: Optional[str] = None

    # Case 1: Try to extract from kwargs (we expect kwargs to be populated for
    # the `ChatOpenAI` class). Guard against `kwargs` being absent or not a dict.
    kwargs = serialized.get("kwargs")
    if isinstance(kwargs, dict):
        # Check for model_name in kwargs (ChatOpenAI attribute)
        if kwargs.get("model_name"):
            model_name = str(kwargs["model_name"])

        # Check for openai_api_base in kwargs (ChatOpenAI attribute)
        if kwargs.get("openai_api_base"):
            base_url = str(kwargs["openai_api_base"])

    # Case 2: For other providers, parse `repr`, a string representation of the
    # provider class. We don't have a reference to the actual class, so we need
    # to parse the string representation.
    repr_str = serialized.get("repr")
    if isinstance(repr_str, str):
        # Extract model name. We expect the property to be formatted like
        # model='...' or model_name='...'. The leading \b prevents matching
        # inside longer attribute names (e.g. `custom_model=...`).
        if not model_name:
            match = re.search(r"\bmodel(?:_name)?=['\"]([^'\"]+)['\"]", repr_str)
            if match:
                model_name = match.group(1)

        # Extract base URL. The property name may vary between providers, so
        # try common attribute patterns. The \b guard keeps e.g. `api_base`
        # from matching inside `openai_api_base` (underscore is a word char,
        # so there is no boundary after it).
        if not base_url:
            url_attrs = [
                "api_base",
                "api_host",
                "azure_endpoint",
                "base_url",
                "endpoint",
                "endpoint_url",
                "openai_api_base",
            ]
            for attr in url_attrs:
                match = re.search(rf"\b{attr}=['\"]([^'\"]+)['\"]", repr_str)
                if match:
                    base_url = match.group(1)
                    break

    return model_name, base_url

tests/test_callbacks.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from nemoguardrails.logging.callbacks import LoggingCallbackHandler
3232
from nemoguardrails.logging.explain import ExplainInfo, LLMCallInfo
3333
from nemoguardrails.logging.stats import LLMStats
34+
from nemoguardrails.logging.utils import extract_model_name_and_base_url
3435

3536

3637
@pytest.mark.asyncio
@@ -261,3 +262,122 @@ def __init__(self, content, msg_type):
261262
assert logged_prompt is not None
262263
assert "[cyan]Custom[/]" in logged_prompt
263264
assert "[cyan]Function[/]" in logged_prompt
265+
266+
267+
def test_extract_model_and_url_from_kwargs():
    """Test extracting model_name and openai_api_base from kwargs (ChatOpenAI case)."""
    serialized = {
        "kwargs": {
            "model_name": "gpt-4",
            "openai_api_base": "https://api.openai.com/v1",
            "temperature": 0.7,
        }
    }

    result = extract_model_name_and_base_url(serialized)

    assert result == ("gpt-4", "https://api.openai.com/v1")
281+
282+
283+
def test_extract_model_and_url_from_repr():
    """Test extracting from repr string (ChatNIM case)."""
    # Property values in single-quotes
    single_quoted = {
        "kwargs": {"temperature": 0.1},
        "repr": "ChatNIM(model='meta/llama-3.3-70b-instruct', client=<openai.OpenAI object at 0x10d8e4e90>, endpoint_url='https://nim.int.aire.nvidia.com/v1')",
    }
    assert extract_model_name_and_base_url(single_quoted) == (
        "meta/llama-3.3-70b-instruct",
        "https://nim.int.aire.nvidia.com/v1",
    )

    # Property values in double-quotes
    double_quoted = {
        "repr": 'ChatOpenAI(model="gpt-3.5-turbo", base_url="https://custom.api.com/v1")'
    }
    assert extract_model_name_and_base_url(double_quoted) == (
        "gpt-3.5-turbo",
        "https://custom.api.com/v1",
    )

    # Model is stored in the `model_name` property
    model_name_prop = {
        "repr": "SomeProvider(model_name='custom-model-v2', api_base='https://example.com')"
    }
    assert extract_model_name_and_base_url(model_name_prop) == (
        "custom-model-v2",
        "https://example.com",
    )
315+
316+
317+
def test_extract_model_and_url_from_various_url_properties():
    """Test extracting various URL property names."""
    cases = [
        ("api_base='https://api1.com'", "https://api1.com"),
        ("api_host='https://api2.com'", "https://api2.com"),
        ("azure_endpoint='https://azure.com'", "https://azure.com"),
        ("endpoint='https://endpoint.com'", "https://endpoint.com"),
        ("openai_api_base='https://openai.com'", "https://openai.com"),
    ]

    for url_pattern, expected_url in cases:
        data = {"repr": f"Provider(model='test-model', {url_pattern})"}
        _, base_url = extract_model_name_and_base_url(data)
        assert base_url == expected_url, f"Failed for pattern: {url_pattern}"
331+
332+
333+
def test_extract_model_and_url_kwargs_priority_over_repr():
    """Test that kwargs values, if present, take priority over repr values."""
    serialized = {
        "kwargs": {
            "model_name": "gpt-4-from-kwargs",
            "openai_api_base": "https://kwargs.api.com",
        },
        "repr": "ChatOpenAI(model='gpt-3.5-from-repr', base_url='https://repr.api.com')",
    }

    assert extract_model_name_and_base_url(serialized) == (
        "gpt-4-from-kwargs",
        "https://kwargs.api.com",
    )
347+
348+
349+
def test_extract_model_and_url_with_missing_values():
    """Test extraction when values are missing."""
    # No model or URL
    assert extract_model_name_and_base_url({"kwargs": {"temperature": 0.7}}) == (
        None,
        None,
    )

    # Only model, no URL
    assert extract_model_name_and_base_url({"kwargs": {"model_name": "gpt-4"}}) == (
        "gpt-4",
        None,
    )

    # Only URL, no model
    assert extract_model_name_and_base_url(
        {"repr": "Provider(endpoint_url='https://example.com')"}
    ) == (None, "https://example.com")
368+
369+
370+
def test_extract_model_and_url_with_empty_values():
    """Test extraction when values are empty strings."""
    result = extract_model_name_and_base_url(
        {"kwargs": {"model_name": "", "openai_api_base": ""}}
    )
    assert result == (None, None)
376+
377+
378+
def test_extract_model_and_url_with_empty_serialized_data():
    """Test extraction with empty or minimal serialized dict."""
    assert extract_model_name_and_base_url({}) == (None, None)

0 commit comments

Comments
 (0)