diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index fda5e9414840..adedcaf781d5 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -531,6 +531,19 @@ def construct_target_url_with_subpath( subpath = subpath[1:] return base_target + subpath + + @staticmethod + def _update_stream_param_based_on_request_body( + parsed_body: dict, + stream: Optional[bool] = None, + ) -> Optional[bool]: + """ + If stream is provided in the request body, use it. + Otherwise, use the stream parameter passed to the `pass_through_request` function + """ + if "stream" in parsed_body: + return parsed_body.get("stream", stream) + return stream async def pass_through_request( # noqa: PLR0915 @@ -686,6 +699,11 @@ async def pass_through_request( # noqa: PLR0915 "headers": headers, }, ) + stream = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body( + parsed_body=_parsed_body, + stream=stream, + ) + if stream: req = async_client.build_request( "POST", diff --git a/tests/test_litellm/passthrough/test_passthrough_main.py b/tests/test_litellm/passthrough/test_passthrough_main.py index 9b7e20159e44..a2008c2f336f 100644 --- a/tests/test_litellm/passthrough/test_passthrough_main.py +++ b/tests/test_litellm/passthrough/test_passthrough_main.py @@ -1,13 +1,14 @@ import json import os import sys +from unittest.mock import MagicMock, patch +import httpx import pytest from fastapi.testclient import TestClient + from litellm.llms.custom_httpx.http_handler import HTTPHandler -from unittest.mock import MagicMock, patch -import httpx - + sys.path.insert( 0, os.path.abspath("../../..") ) # Adds the parent directory to the system path @@ -133,4 +134,271 @@ def test_bedrock_non_application_inference_profile_no_encoding(): actual_url = str(call_args.kwargs["url"]) assert "application-inference-profile%2F" 
not in actual_url assert "anthropic.claude-3-sonnet-20240229-v1:0" in actual_url + assert response.status_code == 200 + + +def test_update_stream_param_based_on_request_body(): + """ + Test _update_stream_param_based_on_request_body handles stream parameter correctly. + """ + from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + HttpPassThroughEndpointHelpers, + ) + + # Test 1: stream in request body should take precedence + parsed_body = {"stream": True, "model": "test-model"} + result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body( + parsed_body=parsed_body, stream=False + ) + assert result is True + + # Test 2: no stream in request body should return original stream param + parsed_body = {"model": "test-model"} + result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body( + parsed_body=parsed_body, stream=False + ) + assert result is False + + # Test 3: stream=False in request body should return False + parsed_body = {"stream": False, "model": "test-model"} + result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body( + parsed_body=parsed_body, stream=True + ) + assert result is False + + # Test 4: no stream param provided, no stream in body + parsed_body = {"model": "test-model"} + result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body( + parsed_body=parsed_body, stream=None + ) + assert result is None + + +@pytest.fixture +def mock_request(): + """Create a mock request with headers""" + from typing import Optional + + class QueryParams: + def __init__(self): + self._dict = {} + + class MockRequest: + def __init__( + self, headers=None, method="POST", request_body: Optional[dict] = None + ): + self.headers = headers or {} + self.query_params = QueryParams() + self.method = method + self.request_body = request_body or {} + # Add url attribute that the actual code expects + self.url = "http://localhost:8000/test" + + async def body(self) 
-> bytes: + return bytes(json.dumps(self.request_body), "utf-8") + + return MockRequest + + +@pytest.fixture +def mock_user_api_key_dict(): + """Create a mock user API key dictionary""" + from litellm.proxy._types import UserAPIKeyAuth + return UserAPIKeyAuth( + api_key="test-key", + user_id="test-user", + team_id="test-team", + end_user_id="test-user", + ) + + +@pytest.mark.asyncio +async def test_pass_through_request_stream_param_override( + mock_request, mock_user_api_key_dict +): + """ + Test that when stream=None is passed as parameter but stream=True + is in request body, the request body value takes precedence and + the eventual POST request uses streaming. + """ + from unittest.mock import AsyncMock, Mock, patch + + from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + pass_through_request, + ) + + # Create request body with stream=True + request_body = { + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Hello, world"}], + "stream": True # This should override the function parameter + } + + # Create a mock streaming response + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "text/event-stream"} + + # Mock the streaming response behavior + async def mock_aiter_bytes(): + yield b'data: {"content": "Hello"}\n\n' + yield b'data: {"content": "World"}\n\n' + yield b'data: [DONE]\n\n' + + mock_response.aiter_bytes = mock_aiter_bytes + + # Create mocks for the async client + mock_async_client = AsyncMock() + mock_request_obj = AsyncMock() + + # Mock build_request to return a request object (it's a sync method) + mock_async_client.build_request = Mock(return_value=mock_request_obj) + + # Mock send to return the streaming response + mock_async_client.send.return_value = mock_response + + # Mock get_async_httpx_client to return our mock client + mock_client_obj = Mock() + mock_client_obj.client = mock_async_client + + # Create the 
request + request = mock_request( + headers={}, method="POST", request_body=request_body + ) + + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client", + return_value=mock_client_obj, + ), patch( + "litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook", + return_value=request_body, # Return the request body unchanged + ), patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler", + new=AsyncMock(), # Mock the success handler + ): + # Call pass_through_request with stream=None parameter + response = await pass_through_request( + request=request, + target="https://api.anthropic.com/v1/messages", + custom_headers={"Authorization": "Bearer test-key"}, + user_api_key_dict=mock_user_api_key_dict, + stream=None, # This should be overridden by request body + ) + + # Verify that build_request was called (indicating streaming path) + mock_async_client.build_request.assert_called_once_with( + "POST", + httpx.URL("https://api.anthropic.com/v1/messages"), + json=request_body, + params=None, + headers={ + "Authorization": "Bearer test-key" + }, + ) + + # Verify that send was called with stream=True + mock_async_client.send.assert_called_once_with( + mock_request_obj, + stream=True # This proves that stream=True from request body was used + ) + + # Verify that the non-streaming request method was NOT called + mock_async_client.request.assert_not_called() + + # Verify response is a StreamingResponse + from fastapi.responses import StreamingResponse + assert isinstance(response, StreamingResponse) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_pass_through_request_stream_param_no_override( + mock_request, mock_user_api_key_dict +): + """ + Test that when stream=False is passed as parameter and no stream + is in request body, the function parameter is used and + the eventual request uses non-streaming. 
+ """ + from unittest.mock import AsyncMock, Mock, patch + + from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + pass_through_request, + ) + + # Create request body without stream parameter + request_body = { + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Hello, world"}], + # No stream parameter - should use function parameter stream=False + } + + # Create a mock non-streaming response + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response._content = b'{"response": "Hello world"}' + + async def mock_aread(): + return mock_response._content + + mock_response.aread = mock_aread + + # Create mocks for the async client + mock_async_client = AsyncMock() + + # Mock request to return the non-streaming response + mock_async_client.request.return_value = mock_response + + # Mock get_async_httpx_client to return our mock client + mock_client_obj = Mock() + mock_client_obj.client = mock_async_client + + # Create the request + request = mock_request( + headers={}, method="POST", request_body=request_body + ) + + with patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client", + return_value=mock_client_obj, + ), patch( + "litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook", + return_value=request_body, # Return the request body unchanged + ), patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler", + new=AsyncMock(), # Mock the success handler + ): + # Call pass_through_request with stream=False parameter + response = await pass_through_request( + request=request, + target="https://api.anthropic.com/v1/messages", + custom_headers={"Authorization": "Bearer test-key"}, + user_api_key_dict=mock_user_api_key_dict, + stream=False, # Should be used since no stream in request body + ) + + # 
Verify that build_request was NOT called (no streaming path) + mock_async_client.build_request.assert_not_called() + + # Verify that send was NOT called (no streaming path) + mock_async_client.send.assert_not_called() + + # Verify that the non-streaming request method WAS called + mock_async_client.request.assert_called_once_with( + method="POST", + url=httpx.URL("https://api.anthropic.com/v1/messages"), + headers={ + "Authorization": "Bearer test-key" + }, + params=None, + json=request_body, + ) + + # Verify response is a regular Response (not StreamingResponse) + from fastapi.responses import Response, StreamingResponse + assert not isinstance(response, StreamingResponse) + assert isinstance(response, Response) assert response.status_code == 200 \ No newline at end of file