Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion nemoguardrails/actions/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from nemoguardrails.logging.callbacks import logging_callbacks
from nemoguardrails.logging.explain import LLMCallInfo

log = logging.getLogger(__name__)


class LLMCallException(Exception):
"""A wrapper around the LLM call invocation exception.
Expand Down Expand Up @@ -113,7 +115,7 @@ def get_llm_provider(llm: BaseLanguageModel) -> Optional[str]:
return _infer_provider_from_module(llm)


def _infer_model_name(llm: BaseLanguageModel):
def _infer_model_name(llm: Union[BaseLanguageModel, Runnable]) -> str:
"""Helper to infer the model name based from an LLM instance.

Because not all models implement correctly _identifying_params from LangChain, we have to
Expand Down
19 changes: 19 additions & 0 deletions nemoguardrails/logging/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.logging.processing_log import processing_log_var
from nemoguardrails.logging.stats import LLMStats
from nemoguardrails.logging.utils import extract_model_name_and_base_url
from nemoguardrails.utils import new_uuid

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,6 +65,15 @@ async def on_llm_start(
if explain_info:
explain_info.llm_calls.append(llm_call_info)

# Log model name and base URL
model_name, base_url = extract_model_name_and_base_url(serialized)
if base_url:
log.info(f"Invoking LLM: model={model_name}, url={base_url}")
elif model_name:
log.info(f"Invoking LLM: model={model_name}")
else:
log.info("Invoking LLM")
Comment on lines +68 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: The exact same logging logic appears in both on_llm_start (lines 68-75) and on_chat_model_start (lines 118-125). Extract to a helper method to reduce duplication and improve maintainability.


log.info("Invocation Params :: %s", kwargs.get("invocation_params", {}))
log.info(
"Prompt :: %s",
Expand Down Expand Up @@ -105,6 +115,15 @@ async def on_chat_model_start(
if explain_info:
explain_info.llm_calls.append(llm_call_info)

# Log model name and base URL
model_name, base_url = extract_model_name_and_base_url(serialized)
if base_url:
log.info(f"Invoking LLM: model={model_name}, url={base_url}")
elif model_name:
log.info(f"Invoking LLM: model={model_name}")
else:
log.info("Invoking LLM")

type_map = {
"human": "User",
"ai": "Bot",
Expand Down
79 changes: 79 additions & 0 deletions nemoguardrails/logging/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from typing import Any, Dict, Optional

log = logging.getLogger(__name__)


def extract_model_name_and_base_url(
    serialized: Dict[str, Any]
) -> tuple[Optional[str], Optional[str]]:
    """Extract model name and base URL from serialized LLM parameters.

    Args:
        serialized: The serialized LLM configuration, as passed by LangChain
            to the `on_llm_start` / `on_chat_model_start` callbacks.

    Returns:
        A tuple of (model_name, base_url). Either value can be None if not found.
    """
    model_name: Optional[str] = None
    base_url: Optional[str] = None

    # Case 1: Try to extract from kwargs (we expect kwargs to be populated for the `ChatOpenAI` class).
    if "kwargs" in serialized:
        kwargs = serialized["kwargs"]

        # Check for model_name in kwargs (ChatOpenAI attribute)
        if "model_name" in kwargs and kwargs["model_name"]:
            model_name = str(kwargs["model_name"])

        # Check for openai_api_base in kwargs (ChatOpenAI attribute)
        if "openai_api_base" in kwargs and kwargs["openai_api_base"]:
            base_url = str(kwargs["openai_api_base"])

    # Case 2: For other providers, parse `repr`, a string representation of the provider class. We don't have
    # a reference to the actual class, so we need to parse the string representation.
    if "repr" in serialized and isinstance(serialized["repr"], str):
        repr_str = serialized["repr"]

        # Extract model name. We expect the property to be formatted like model='...' or model_name='...',
        # and check for single and double quotes. The leading \b prevents matching
        # inside longer attribute names (e.g. `seed_model=`).
        if not model_name:
            match = re.search(r"\bmodel(?:_name)?=['\"]([^'\"]+)['\"]", repr_str)
            if match:
                model_name = match.group(1)

        # Extract base URL. The property name may vary between providers, so try common attribute patterns.
        # We expect the property to be formatted like property_name='...', and check for single and double quotes.
        if not base_url:
            url_attrs = [
                "api_base",
                "api_host",
                "azure_endpoint",
                "base_url",
                "endpoint",
                "endpoint_url",
                "openai_api_base",
            ]
            for attr in url_attrs:
                # \b keeps e.g. `api_base` from matching inside an unrelated,
                # longer attribute name.
                match = re.search(rf"\b{attr}=['\"]([^'\"]+)['\"]", repr_str)
                if match:
                    base_url = match.group(1)
                    break

    return model_name, base_url
120 changes: 120 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from nemoguardrails.logging.callbacks import LoggingCallbackHandler
from nemoguardrails.logging.explain import ExplainInfo, LLMCallInfo
from nemoguardrails.logging.stats import LLMStats
from nemoguardrails.logging.utils import extract_model_name_and_base_url


@pytest.mark.asyncio
Expand Down Expand Up @@ -261,3 +262,122 @@ def __init__(self, content, msg_type):
assert logged_prompt is not None
assert "[cyan]Custom[/]" in logged_prompt
assert "[cyan]Function[/]" in logged_prompt


def test_extract_model_and_url_from_kwargs():
    """Test extracting model_name and openai_api_base from kwargs (ChatOpenAI case)."""
    data = {
        "kwargs": {
            "model_name": "gpt-4",
            "openai_api_base": "https://api.openai.com/v1",
            "temperature": 0.7,
        }
    }

    result = extract_model_name_and_base_url(data)

    assert result == ("gpt-4", "https://api.openai.com/v1")


def test_extract_model_and_url_from_repr():
    """Test extracting from repr string (ChatNIM case)."""
    # Property values in single-quotes
    single_quoted = {
        "kwargs": {"temperature": 0.1},
        "repr": "ChatNIM(model='meta/llama-3.3-70b-instruct', client=<openai.OpenAI object at 0x10d8e4e90>, endpoint_url='https://nim.int.aire.nvidia.com/v1')",
    }
    assert extract_model_name_and_base_url(single_quoted) == (
        "meta/llama-3.3-70b-instruct",
        "https://nim.int.aire.nvidia.com/v1",
    )

    # Property values in double-quotes
    double_quoted = {
        "repr": 'ChatOpenAI(model="gpt-3.5-turbo", base_url="https://custom.api.com/v1")'
    }
    assert extract_model_name_and_base_url(double_quoted) == (
        "gpt-3.5-turbo",
        "https://custom.api.com/v1",
    )

    # Model is stored in the `model_name` property
    named_model = {
        "repr": "SomeProvider(model_name='custom-model-v2', api_base='https://example.com')"
    }
    assert extract_model_name_and_base_url(named_model) == (
        "custom-model-v2",
        "https://example.com",
    )


def test_extract_model_and_url_from_various_url_properties():
    """Test extracting various URL property names."""
    cases = {
        "api_base='https://api1.com'": "https://api1.com",
        "api_host='https://api2.com'": "https://api2.com",
        "azure_endpoint='https://azure.com'": "https://azure.com",
        "endpoint='https://endpoint.com'": "https://endpoint.com",
        "openai_api_base='https://openai.com'": "https://openai.com",
    }

    for url_pattern, expected_url in cases.items():
        serialized = {"repr": f"Provider(model='test-model', {url_pattern})"}
        _, base_url = extract_model_name_and_base_url(serialized)
        assert base_url == expected_url, f"Failed for pattern: {url_pattern}"


def test_extract_model_and_url_kwargs_priority_over_repr():
    """Test that kwargs values, if present, take priority over repr values."""
    both_sources = {
        "kwargs": {
            "model_name": "gpt-4-from-kwargs",
            "openai_api_base": "https://kwargs.api.com",
        },
        "repr": "ChatOpenAI(model='gpt-3.5-from-repr', base_url='https://repr.api.com')",
    }

    result = extract_model_name_and_base_url(both_sources)

    assert result == ("gpt-4-from-kwargs", "https://kwargs.api.com")


def test_extract_model_and_url_with_missing_values():
    """Test extraction when values are missing."""
    # No model or URL
    assert extract_model_name_and_base_url({"kwargs": {"temperature": 0.7}}) == (
        None,
        None,
    )

    # Only model, no URL
    assert extract_model_name_and_base_url({"kwargs": {"model_name": "gpt-4"}}) == (
        "gpt-4",
        None,
    )

    # Only URL, no model
    assert extract_model_name_and_base_url(
        {"repr": "Provider(endpoint_url='https://example.com')"}
    ) == (None, "https://example.com")


def test_extract_model_and_url_with_empty_values():
    """Test extraction when values are empty strings."""
    empty_strings = {"kwargs": {"model_name": "", "openai_api_base": ""}}
    assert extract_model_name_and_base_url(empty_strings) == (None, None)


def test_extract_model_and_url_with_empty_serialized_data():
    """Test extraction with empty or minimal serialized dict."""
    assert extract_model_name_and_base_url({}) == (None, None)