Commit 4a98503

chore(logging): Log model name and base URL before invoking LLMs
1 parent b5b7579

1 file changed: +36 -1 lines changed

1 file changed

+36
-1
lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 36 additions & 1 deletion
@@ -36,6 +36,8 @@
 from nemoguardrails.logging.callbacks import logging_callbacks
 from nemoguardrails.logging.explain import LLMCallInfo
 
+log = logging.getLogger(__name__)
+
 
 class LLMCallException(Exception):
     """A wrapper around the LLM call invocation exception.
@@ -113,7 +115,7 @@ def get_llm_provider(llm: BaseLanguageModel) -> Optional[str]:
     return _infer_provider_from_module(llm)
 
 
-def _infer_model_name(llm: BaseLanguageModel):
+def _infer_model_name(llm: Union[BaseLanguageModel, Runnable]) -> str:
     """Helper to infer the model name based from an LLM instance.
 
     Because not all models implement correctly _identifying_params from LangChain, we have to
@@ -209,13 +211,45 @@ def _prepare_callbacks(
     return logging_callbacks
 
 
+def _log_model_and_base_url(llm: Union[BaseLanguageModel, Runnable]) -> None:
+    """Extract and log the model and base URL from an LLM instance."""
+    model_name = _infer_model_name(llm)
+    base_url = None
+
+    # If llm is a `ChatNIM` instance, we expect its `client` to be an `OpenAI` client with a `base_url` attribute.
+    if hasattr(llm, "client"):
+        client = getattr(llm, "client")
+        if hasattr(client, "base_url"):
+            base_url = str(client.base_url)
+    else:
+        # If llm is a `ChatNVIDIA` instance or other provider, check common attribute names that store the base URL.
+        for attr in [
+            "base_url",
+            "openai_api_base",
+            "azure_endpoint",
+            "api_base",
+            "endpoint",
+        ]:
+            if hasattr(llm, attr):
+                value = getattr(llm, attr, None)
+                if value:
+                    base_url = str(value)
+                    break
+
+    if base_url:
+        log.info(f"Invoking LLM: model={model_name}, url={base_url}")
+    else:
+        log.info(f"Invoking LLM: model={model_name}")
+
+
 async def _invoke_with_string_prompt(
     llm: Union[BaseLanguageModel, Runnable],
     prompt: str,
     callbacks: BaseCallbackManager,
 ):
     """Invoke LLM with string prompt."""
     try:
+        _log_model_and_base_url(llm)
         return await llm.ainvoke(prompt, config=RunnableConfig(callbacks=callbacks))
     except Exception as e:
         raise LLMCallException(e)
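
For clarity, the probing logic added above can be restated as a standalone sketch; `probe_base_url` and the `SimpleNamespace` stand-ins are illustrative, not part of the commit:

    from types import SimpleNamespace

    def probe_base_url(llm):
        # Mirrors the commit's branching: a `client` attribute (the ChatNIM
        # case) takes precedence and suppresses the fallback probing, even
        # when the client itself exposes no `base_url`.
        if hasattr(llm, "client"):
            client = llm.client
            return str(client.base_url) if hasattr(client, "base_url") else None
        # Otherwise, scan common provider attribute names in order.
        for attr in ("base_url", "openai_api_base", "azure_endpoint", "api_base", "endpoint"):
            value = getattr(llm, attr, None)
            if value:
                return str(value)
        return None

    nim_like = SimpleNamespace(client=SimpleNamespace(base_url="http://localhost:8000/v1"))
    print(probe_base_url(nim_like))                                    # http://localhost:8000/v1
    print(probe_base_url(SimpleNamespace(api_base="http://host/v1")))  # http://host/v1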
@@ -230,6 +264,7 @@ async def _invoke_with_message_list(
     messages = _convert_messages_to_langchain_format(prompt)
 
     try:
+        _log_model_and_base_url(llm)
         return await llm.ainvoke(messages, config=RunnableConfig(callbacks=callbacks))
     except Exception as e:
         raise LLMCallException(e)
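
With INFO logging enabled, each invocation should now be preceded by one line in whichever of the two formats applies; for example (model name and URL are illustrative, shown with Python's default basicConfig formatting):

    INFO:nemoguardrails.actions.llm.utils:Invoking LLM: model=gpt-4o, url=https://api.openai.com/v1
    INFO:nemoguardrails.actions.llm.utils:Invoking LLM: model=gpt-4o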
