Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions litellm/llms/azure/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def completion( # noqa: PLR0915
acompletion: bool = False,
headers: Optional[dict] = None,
client=None,
connect_timeout: Optional[float] = None,
):
if headers:
optional_params["extra_headers"] = headers
Expand Down Expand Up @@ -281,6 +282,7 @@ def completion( # noqa: PLR0915
max_retries=max_retries,
convert_tool_call_to_json_mode=json_mode,
litellm_params=litellm_params,
connect_timeout=connect_timeout,
)
elif "stream" in optional_params and optional_params["stream"] is True:
return self.streaming(
Expand Down Expand Up @@ -317,6 +319,11 @@ def completion( # noqa: PLR0915
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# Add connect_timeout to litellm_params if provided
if connect_timeout is not None:
if litellm_params is None:
litellm_params = {}
litellm_params["connect_timeout"] = connect_timeout
# init AzureOpenAI Client
azure_client = self.get_azure_openai_client(
api_version=api_version,
Expand Down Expand Up @@ -387,9 +394,15 @@ async def acompletion(
convert_tool_call_to_json_mode: Optional[bool] = None,
client=None, # this is the AsyncAzureOpenAI
litellm_params: Optional[dict] = {},
connect_timeout: Optional[float] = None,
):
response = None
try:
# Add connect_timeout to litellm_params if provided
if connect_timeout is not None:
if litellm_params is None:
litellm_params = {}
litellm_params["connect_timeout"] = connect_timeout
# setting Azure client
azure_client = self.get_azure_openai_client(
api_version=api_version,
Expand Down
6 changes: 6 additions & 0 deletions litellm/llms/azure/chat/o_series_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ def completion(
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
connect_timeout: Optional[float] = None,
):
# Add connect_timeout to litellm_params if provided
if connect_timeout is not None:
if litellm_params is None:
litellm_params = {}
litellm_params["connect_timeout"] = connect_timeout
client = self.get_azure_openai_client(
litellm_params=litellm_params,
api_key=api_key,
Expand Down
9 changes: 8 additions & 1 deletion litellm/llms/azure/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,14 @@ def initialize_azure_sdk_client(
if max_retries is not None:
azure_client_params["max_retries"] = max_retries
if timeout is not None:
azure_client_params["timeout"] = timeout
# Check if we have a separate connect_timeout
connect_timeout = litellm_params.get("connect_timeout")
if connect_timeout is not None and isinstance(timeout, (int, float)):
# Create httpx.Timeout object with separate connect timeout
import httpx
azure_client_params["timeout"] = httpx.Timeout(timeout=timeout, connect=connect_timeout)
else:
azure_client_params["timeout"] = timeout

if azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
Expand Down
6 changes: 6 additions & 0 deletions litellm/llms/azure/completion/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def completion( # noqa: PLR0915
acompletion: bool = False,
headers: Optional[dict] = None,
client=None,
connect_timeout: Optional[float] = None,
):
try:
if model is None or messages is None:
Expand Down Expand Up @@ -143,6 +144,11 @@ def completion( # noqa: PLR0915
raise AzureOpenAIError(
status_code=422, message="max retries must be an int"
)
# Add connect_timeout to litellm_params if provided
if connect_timeout is not None:
if litellm_params is None:
litellm_params = {}
litellm_params["connect_timeout"] = connect_timeout
# init AzureOpenAI Client
azure_client = self.get_azure_openai_client(
api_key=api_key,
Expand Down
20 changes: 17 additions & 3 deletions litellm/llms/openai/completion/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def completion(
client=None,
organization: Optional[str] = None,
headers: Optional[dict] = None,
connect_timeout: Optional[float] = None,
):
try:
if headers is None:
Expand Down Expand Up @@ -94,7 +95,7 @@ def completion(
organization=organization,
)
else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client) # type: ignore
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client, connect_timeout=connect_timeout) # type: ignore
elif optional_params.get("stream", False):
return self.streaming(
logging_obj=logging_obj,
Expand All @@ -111,11 +112,17 @@ def completion(
)
else:
if client is None:
# Handle connect_timeout by creating httpx.Timeout object if needed
_timeout = timeout
if connect_timeout is not None and isinstance(timeout, (int, float)):
import httpx
_timeout = httpx.Timeout(timeout=timeout, connect=connect_timeout)

openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
timeout=_timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
Expand Down Expand Up @@ -162,14 +169,21 @@ async def acompletion(
max_retries: int,
organization: Optional[str] = None,
client=None,
connect_timeout: Optional[float] = None,
):
try:
if client is None:
# Handle connect_timeout by creating httpx.Timeout object if needed
_timeout = timeout
if connect_timeout is not None and isinstance(timeout, (int, float)):
import httpx
_timeout = httpx.Timeout(timeout=timeout, connect=connect_timeout)

openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
timeout=_timeout,
max_retries=max_retries,
organization=organization,
)
Expand Down
14 changes: 12 additions & 2 deletions litellm/llms/openai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def _get_openai_client(
max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
organization: Optional[str] = None,
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
connect_timeout: Optional[float] = None,
) -> Optional[Union[OpenAI, AsyncOpenAI]]:
client_initialization_params: Dict = locals()
if client is None:
Expand All @@ -375,12 +376,17 @@ def _get_openai_client(
cached_client, AsyncOpenAI
):
return cached_client
# Handle connect_timeout by creating httpx.Timeout object if needed
_timeout = timeout
if connect_timeout is not None and isinstance(timeout, (int, float)):
_timeout = httpx.Timeout(timeout=timeout, connect=connect_timeout)

if is_async:
_new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=OpenAIChatCompletion._get_async_http_client(),
timeout=timeout,
timeout=_timeout,
max_retries=max_retries,
organization=organization,
)
Expand All @@ -389,7 +395,7 @@ def _get_openai_client(
api_key=api_key,
base_url=api_base,
http_client=OpenAIChatCompletion._get_sync_http_client(),
timeout=timeout,
timeout=_timeout,
max_retries=max_retries,
organization=organization,
)
Expand Down Expand Up @@ -522,6 +528,7 @@ def completion( # type: ignore # noqa: PLR0915
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
connect_timeout: Optional[float] = None,
):
super().completion()
try:
Expand Down Expand Up @@ -644,6 +651,7 @@ def completion( # type: ignore # noqa: PLR0915
max_retries=max_retries,
organization=organization,
client=client,
connect_timeout=connect_timeout,
)

## LOGGING
Expand Down Expand Up @@ -793,6 +801,7 @@ async def acompletion(
max_retries=max_retries,
organization=organization,
client=client,
connect_timeout=connect_timeout,
)

## LOGGING
Expand Down Expand Up @@ -964,6 +973,7 @@ async def async_streaming(
max_retries=max_retries,
organization=organization,
client=client,
connect_timeout=connect_timeout,
)
## LOGGING
logging_obj.pre_call(
Expand Down
20 changes: 20 additions & 0 deletions litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ async def acompletion(
functions: Optional[List] = None,
function_call: Optional[str] = None,
timeout: Optional[Union[float, int]] = None,
connect_timeout: Optional[float] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
Expand Down Expand Up @@ -397,6 +398,7 @@ async def acompletion(
api_key (str, optional): API key (default is None).
model_list (list, optional): List of api base, version, keys
timeout (float, optional): The maximum execution time in seconds for the completion request.
connect_timeout (float, optional): The maximum time in seconds to wait for connection handshaking to complete.

LITELLM Specific Params
mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None).
Expand Down Expand Up @@ -463,6 +465,7 @@ async def acompletion(
"functions": functions,
"function_call": function_call,
"timeout": timeout,
"connect_timeout": connect_timeout,
"temperature": temperature,
"top_p": top_p,
"n": n,
Expand Down Expand Up @@ -877,6 +880,7 @@ def completion( # type: ignore # noqa: PLR0915
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
timeout: Optional[Union[float, str, httpx.Timeout]] = None,
connect_timeout: Optional[float] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
Expand Down Expand Up @@ -1130,6 +1134,11 @@ def completion( # type: ignore # noqa: PLR0915
elif not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore

### CONNECT TIMEOUT LOGIC ###
connect_timeout = connect_timeout or kwargs.get("connect_timeout", None)
if connect_timeout is not None:
connect_timeout = float(connect_timeout)

### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None:
litellm.register_model(
Expand Down Expand Up @@ -1440,6 +1449,7 @@ def completion( # type: ignore # noqa: PLR0915
logging_obj=logging,
acompletion=acompletion,
timeout=timeout, # type: ignore
connect_timeout=connect_timeout,
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
custom_llm_provider=custom_llm_provider,
)
Expand Down Expand Up @@ -1472,6 +1482,7 @@ def completion( # type: ignore # noqa: PLR0915
logging_obj=logging,
acompletion=acompletion,
timeout=timeout, # type: ignore
connect_timeout=connect_timeout,
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
)

Expand Down Expand Up @@ -1545,6 +1556,7 @@ def completion( # type: ignore # noqa: PLR0915
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
connect_timeout=connect_timeout,
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
)

Expand Down Expand Up @@ -1721,6 +1733,7 @@ def completion( # type: ignore # noqa: PLR0915
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
connect_timeout=connect_timeout,
)

if (
Expand Down Expand Up @@ -1839,6 +1852,7 @@ def completion( # type: ignore # noqa: PLR0915
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
timeout=timeout,
connect_timeout=connect_timeout,
headers=headers,
encoding=encoding,
api_key=api_key,
Expand Down Expand Up @@ -2000,6 +2014,7 @@ def completion( # type: ignore # noqa: PLR0915
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
custom_llm_provider=custom_llm_provider,
connect_timeout=connect_timeout,
)
except Exception as e:
## LOGGING - log the original exception returned
Expand Down Expand Up @@ -2208,6 +2223,7 @@ def completion( # type: ignore # noqa: PLR0915
timeout=timeout,
client=client,
custom_llm_provider=custom_llm_provider,
connect_timeout=connect_timeout,
)
if optional_params.get("stream", False) or acompletion is True:
## LOGGING
Expand Down Expand Up @@ -2685,6 +2701,7 @@ def completion( # type: ignore # noqa: PLR0915
timeout=timeout,
custom_llm_provider=custom_llm_provider,
client=client,
connect_timeout=connect_timeout,
api_base=api_base,
extra_headers=extra_headers,
)
Expand Down Expand Up @@ -2983,6 +3000,7 @@ def completion( # type: ignore # noqa: PLR0915
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
connect_timeout=connect_timeout,
acompletion=acompletion,
client=client,
api_base=api_base,
Expand All @@ -3001,6 +3019,7 @@ def completion( # type: ignore # noqa: PLR0915
litellm_params=litellm_params,
custom_llm_provider="bedrock",
timeout=timeout,
connect_timeout=connect_timeout,
headers=headers,
encoding=encoding,
api_key=api_key,
Expand All @@ -3019,6 +3038,7 @@ def completion( # type: ignore # noqa: PLR0915
litellm_params=litellm_params,
custom_llm_provider="bedrock",
timeout=timeout,
connect_timeout=connect_timeout,
headers=headers,
encoding=encoding,
api_key=api_key,
Expand Down
Loading