diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py
index 5ee9065f5e12..5d09b5f16024 100644
--- a/litellm/llms/azure/azure.py
+++ b/litellm/llms/azure/azure.py
@@ -198,6 +198,7 @@ def completion(  # noqa: PLR0915
         acompletion: bool = False,
         headers: Optional[dict] = None,
         client=None,
+        connect_timeout: Optional[float] = None,
     ):
         if headers:
             optional_params["extra_headers"] = headers
@@ -281,6 +282,7 @@ def completion(  # noqa: PLR0915
                     max_retries=max_retries,
                     convert_tool_call_to_json_mode=json_mode,
                     litellm_params=litellm_params,
+                    connect_timeout=connect_timeout,
                 )
             elif "stream" in optional_params and optional_params["stream"] is True:
                 return self.streaming(
@@ -317,6 +319,11 @@ def completion(  # noqa: PLR0915
                     raise AzureOpenAIError(
                         status_code=422, message="max retries must be an int"
                     )
+                # Add connect_timeout to litellm_params if provided
+                if connect_timeout is not None:
+                    if litellm_params is None:
+                        litellm_params = {}
+                    litellm_params["connect_timeout"] = connect_timeout
                 # init AzureOpenAI Client
                 azure_client = self.get_azure_openai_client(
                     api_version=api_version,
@@ -387,9 +394,15 @@ async def acompletion(
         convert_tool_call_to_json_mode: Optional[bool] = None,
         client=None,  # this is the AsyncAzureOpenAI
         litellm_params: Optional[dict] = {},
+        connect_timeout: Optional[float] = None,
     ):
         response = None
         try:
+            # Add connect_timeout to litellm_params if provided
+            if connect_timeout is not None:
+                if litellm_params is None:
+                    litellm_params = {}
+                litellm_params["connect_timeout"] = connect_timeout
             # setting Azure client
             azure_client = self.get_azure_openai_client(
                 api_version=api_version,
diff --git a/litellm/llms/azure/chat/o_series_handler.py b/litellm/llms/azure/chat/o_series_handler.py
index 2f3e9e639968..bc1044ffda21 100644
--- a/litellm/llms/azure/chat/o_series_handler.py
+++ b/litellm/llms/azure/chat/o_series_handler.py
@@ -38,7 +38,13 @@ def completion(
         organization: Optional[str] = None,
         custom_llm_provider: Optional[str] = None,
         drop_params: Optional[bool] = None,
+        connect_timeout: Optional[float] = None,
     ):
+        # Add connect_timeout to litellm_params if provided
+        if connect_timeout is not None:
+            if litellm_params is None:
+                litellm_params = {}
+            litellm_params["connect_timeout"] = connect_timeout
         client = self.get_azure_openai_client(
             litellm_params=litellm_params,
             api_key=api_key,
diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py
index 94abd2f814e7..ccc90f3d3287 100644
--- a/litellm/llms/azure/common_utils.py
+++ b/litellm/llms/azure/common_utils.py
@@ -590,7 +590,14 @@ def initialize_azure_sdk_client(
     if max_retries is not None:
         azure_client_params["max_retries"] = max_retries
     if timeout is not None:
-        azure_client_params["timeout"] = timeout
+        # Check if we have a separate connect_timeout
+        connect_timeout = litellm_params.get("connect_timeout")
+        if connect_timeout is not None and isinstance(timeout, (int, float)):
+            # Create httpx.Timeout object with separate connect timeout
+            import httpx
+            azure_client_params["timeout"] = httpx.Timeout(timeout=timeout, connect=connect_timeout)
+        else:
+            azure_client_params["timeout"] = timeout
 
     if azure_ad_token_provider is not None:
         azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
diff --git a/litellm/llms/azure/completion/handler.py b/litellm/llms/azure/completion/handler.py
index a44f90457124..0ae354d9cd72 100644
--- a/litellm/llms/azure/completion/handler.py
+++ b/litellm/llms/azure/completion/handler.py
@@ -45,6 +45,7 @@ def completion(  # noqa: PLR0915
         acompletion: bool = False,
         headers: Optional[dict] = None,
         client=None,
+        connect_timeout: Optional[float] = None,
     ):
         try:
             if model is None or messages is None:
@@ -143,6 +144,11 @@ def completion(  # noqa: PLR0915
                 raise AzureOpenAIError(
                     status_code=422, message="max retries must be an int"
                 )
+            # Add connect_timeout to litellm_params if provided
+            if connect_timeout is not None:
+                if litellm_params is None:
+                    litellm_params = {}
+                litellm_params["connect_timeout"] = connect_timeout
             # init AzureOpenAI Client
             azure_client = self.get_azure_openai_client(
                 api_key=api_key,
diff --git a/litellm/llms/openai/completion/handler.py b/litellm/llms/openai/completion/handler.py
index fa31c487cd27..7b3b3714448b 100644
--- a/litellm/llms/openai/completion/handler.py
+++ b/litellm/llms/openai/completion/handler.py
@@ -47,6 +47,7 @@ def completion(
         client=None,
         organization: Optional[str] = None,
         headers: Optional[dict] = None,
+        connect_timeout: Optional[float] = None,
     ):
         try:
             if headers is None:
@@ -94,7 +95,7 @@ def completion(
                     organization=organization,
                 )
             else:
-                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client)  # type: ignore
+                return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client, connect_timeout=connect_timeout)  # type: ignore
         elif optional_params.get("stream", False):
             return self.streaming(
                 logging_obj=logging_obj,
@@ -111,11 +112,17 @@ def completion(
             )
         else:
             if client is None:
+                # Handle connect_timeout by creating httpx.Timeout object if needed
+                _timeout = timeout
+                if connect_timeout is not None and isinstance(timeout, (int, float)):
+                    import httpx
+                    _timeout = httpx.Timeout(timeout=timeout, connect=connect_timeout)
+
                 openai_client = OpenAI(
                     api_key=api_key,
                     base_url=api_base,
                     http_client=litellm.client_session,
-                    timeout=timeout,
+                    timeout=_timeout,
                     max_retries=max_retries,  # type: ignore
                     organization=organization,
                 )
@@ -162,14 +169,21 @@ async def acompletion(
         max_retries: int,
         organization: Optional[str] = None,
         client=None,
+        connect_timeout: Optional[float] = None,
     ):
         try:
             if client is None:
+                # Handle connect_timeout by creating httpx.Timeout object if needed
+                _timeout = timeout
+                if connect_timeout is not None and isinstance(timeout, (int, float)):
+                    import httpx
+                    _timeout = httpx.Timeout(timeout=timeout, connect=connect_timeout)
+
                 openai_aclient = AsyncOpenAI(
                     api_key=api_key,
                     base_url=api_base,
                     http_client=litellm.aclient_session,
-                    timeout=timeout,
+                    timeout=_timeout,
                     max_retries=max_retries,
                     organization=organization,
                 )
diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py
index 1f3cf24457d8..40f0a7a24e17 100644
--- a/litellm/llms/openai/openai.py
+++ b/litellm/llms/openai/openai.py
@@ -355,6 +355,7 @@ def _get_openai_client(
         max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
         organization: Optional[str] = None,
         client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
+        connect_timeout: Optional[float] = None,
     ) -> Optional[Union[OpenAI, AsyncOpenAI]]:
         client_initialization_params: Dict = locals()
         if client is None:
@@ -375,12 +376,17 @@ def _get_openai_client(
                 cached_client, AsyncOpenAI
             ):
                 return cached_client
+            # Handle connect_timeout by creating httpx.Timeout object if needed
+            _timeout = timeout
+            if connect_timeout is not None and isinstance(timeout, (int, float)):
+                _timeout = httpx.Timeout(timeout=timeout, connect=connect_timeout)
+
             if is_async:
                 _new_client: Union[OpenAI, AsyncOpenAI] = AsyncOpenAI(
                     api_key=api_key,
                     base_url=api_base,
                     http_client=OpenAIChatCompletion._get_async_http_client(),
-                    timeout=timeout,
+                    timeout=_timeout,
                     max_retries=max_retries,
                     organization=organization,
                 )
@@ -389,7 +395,7 @@ def _get_openai_client(
                     api_key=api_key,
                     base_url=api_base,
                     http_client=OpenAIChatCompletion._get_sync_http_client(),
-                    timeout=timeout,
+                    timeout=_timeout,
                     max_retries=max_retries,
                     organization=organization,
                 )
@@ -522,6 +528,7 @@ def completion(  # type: ignore # noqa: PLR0915
         organization: Optional[str] = None,
         custom_llm_provider: Optional[str] = None,
         drop_params: Optional[bool] = None,
+        connect_timeout: Optional[float] = None,
     ):
         super().completion()
         try:
@@ -644,6 +651,7 @@ def completion(  # type: ignore # noqa: PLR0915
                             max_retries=max_retries,
                             organization=organization,
                             client=client,
+                            connect_timeout=connect_timeout,
                         )
 
                         ## LOGGING
@@ -793,6 +801,7 @@ async def acompletion(
                 max_retries=max_retries,
                 organization=organization,
                 client=client,
+                connect_timeout=connect_timeout,
             )
 
             ## LOGGING
@@ -964,6 +973,7 @@ async def async_streaming(
             max_retries=max_retries,
             organization=organization,
             client=client,
+            connect_timeout=connect_timeout,
         )
         ## LOGGING
         logging_obj.pre_call(
diff --git a/litellm/main.py b/litellm/main.py
index 339d9e144062..13260dd392c6 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -330,6 +330,7 @@ async def acompletion(
     functions: Optional[List] = None,
     function_call: Optional[str] = None,
     timeout: Optional[Union[float, int]] = None,
+    connect_timeout: Optional[float] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     n: Optional[int] = None,
@@ -397,6 +398,7 @@ async def acompletion(
         api_key (str, optional): API key (default is None).
         model_list (list, optional): List of api base, version, keys
         timeout (float, optional): The maximum execution time in seconds for the completion request.
+        connect_timeout (float, optional): The maximum time in seconds to wait for connection handshaking to complete.
 
         LITELLM Specific Params
         mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None).
@@ -463,6 +465,7 @@ async def acompletion(
         "functions": functions,
         "function_call": function_call,
         "timeout": timeout,
+        "connect_timeout": connect_timeout,
         "temperature": temperature,
         "top_p": top_p,
         "n": n,
@@ -877,6 +880,7 @@ def completion(  # type: ignore # noqa: PLR0915
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
     messages: List = [],
     timeout: Optional[Union[float, str, httpx.Timeout]] = None,
+    connect_timeout: Optional[float] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     n: Optional[int] = None,
@@ -1130,6 +1134,11 @@ def completion(  # type: ignore # noqa: PLR0915
     elif not isinstance(timeout, httpx.Timeout):
         timeout = float(timeout)  # type: ignore
 
+    ### CONNECT TIMEOUT LOGIC ###
+    connect_timeout = connect_timeout or kwargs.get("connect_timeout", None)
+    if connect_timeout is not None:
+        connect_timeout = float(connect_timeout)
+
     ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
     if input_cost_per_token is not None and output_cost_per_token is not None:
         litellm.register_model(
@@ -1440,6 +1449,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 logging_obj=logging,
                 acompletion=acompletion,
                 timeout=timeout,  # type: ignore
+                connect_timeout=connect_timeout,
                 client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
                 custom_llm_provider=custom_llm_provider,
             )
@@ -1472,6 +1482,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 logging_obj=logging,
                 acompletion=acompletion,
                 timeout=timeout,  # type: ignore
+                connect_timeout=connect_timeout,
                 client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
             )
 
@@ -1545,6 +1556,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 logging_obj=logging,
                 acompletion=acompletion,
                 timeout=timeout,
+                connect_timeout=connect_timeout,
                 client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
             )
 
@@ -1721,6 +1733,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 timeout=timeout,  # type: ignore
+                connect_timeout=connect_timeout,
             )
 
             if (
@@ -1839,6 +1852,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 litellm_params=litellm_params,
                 custom_llm_provider=custom_llm_provider,
                 timeout=timeout,
+                connect_timeout=connect_timeout,
                 headers=headers,
                 encoding=encoding,
                 api_key=api_key,
@@ -2000,6 +2014,7 @@ def completion(  # type: ignore # noqa: PLR0915
                     client=client,  # pass AsyncOpenAI, OpenAI client
                     organization=organization,
                     custom_llm_provider=custom_llm_provider,
+                    connect_timeout=connect_timeout,
                 )
             except Exception as e:
                 ## LOGGING - log the original exception returned
@@ -2208,6 +2223,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 timeout=timeout,
                 client=client,
                 custom_llm_provider=custom_llm_provider,
+                connect_timeout=connect_timeout,
             )
             if optional_params.get("stream", False) or acompletion is True:
                 ## LOGGING
@@ -2685,6 +2701,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 timeout=timeout,
                 custom_llm_provider=custom_llm_provider,
                 client=client,
+                connect_timeout=connect_timeout,
                 api_base=api_base,
                 extra_headers=extra_headers,
             )
@@ -2983,6 +3000,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 logging_obj=logging,
                 extra_headers=extra_headers,
                 timeout=timeout,
+                connect_timeout=connect_timeout,
                 acompletion=acompletion,
                 client=client,
                 api_base=api_base,
@@ -3001,6 +3019,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 litellm_params=litellm_params,
                 custom_llm_provider="bedrock",
                 timeout=timeout,
+                connect_timeout=connect_timeout,
                 headers=headers,
                 encoding=encoding,
                 api_key=api_key,
@@ -3019,6 +3038,7 @@ def completion(  # type: ignore # noqa: PLR0915
                 litellm_params=litellm_params,
                 custom_llm_provider="bedrock",
                 timeout=timeout,
+                connect_timeout=connect_timeout,
                 headers=headers,
                 encoding=encoding,
                 api_key=api_key,
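Taken together, the diff threads a new connect_timeout parameter from litellm.completion / litellm.acompletion down to the underlying OpenAI and Azure SDK clients, where it is combined with the overall timeout into an httpx.Timeout object. Below is a minimal usage sketch of the resulting behavior; the model name and timeout values are illustrative placeholders, not part of the diff.

    # Minimal usage sketch for the connect_timeout parameter introduced above.
    # Model name and timeout values are placeholders for illustration only.
    import litellm

    response = litellm.completion(
        model="gpt-3.5-turbo",  # placeholder model
        messages=[{"role": "user", "content": "Hello"}],
        timeout=30,         # overall request timeout, in seconds
        connect_timeout=5,  # per the diff, combined into
                            # httpx.Timeout(timeout=30, connect=5) on the SDK client
    )

Note that every handler guards the construction with isinstance(timeout, (int, float)), so connect_timeout only takes effect when timeout is a plain number; if timeout is already an httpx.Timeout (or a string), connect_timeout is silently ignored.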