Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,19 @@ def construct_target_url_with_subpath(
subpath = subpath[1:]

return base_target + subpath

@staticmethod
def _update_stream_param_based_on_request_body(
parsed_body: dict,
stream: Optional[bool] = None,
) -> Optional[bool]:
"""
If stream is provided in the request body, use it.
Otherwise, use the stream parameter passed to the `pass_through_request` function
"""
if "stream" in parsed_body:
return parsed_body.get("stream", stream)
return stream


async def pass_through_request( # noqa: PLR0915
Expand Down Expand Up @@ -686,6 +699,11 @@ async def pass_through_request( # noqa: PLR0915
"headers": headers,
},
)
stream = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body(
parsed_body=_parsed_body,
stream=stream,
)

if stream:
req = async_client.build_request(
"POST",
Expand Down
274 changes: 271 additions & 3 deletions tests/test_litellm/passthrough/test_passthrough_main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json
import os
import sys
from unittest.mock import MagicMock, patch

import httpx
import pytest
from fastapi.testclient import TestClient

from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import MagicMock, patch
import httpx


sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
Expand Down Expand Up @@ -133,4 +134,271 @@ def test_bedrock_non_application_inference_profile_no_encoding():
actual_url = str(call_args.kwargs["url"])
assert "application-inference-profile%2F" not in actual_url
assert "anthropic.claude-3-sonnet-20240229-v1:0" in actual_url
assert response.status_code == 200


def test_update_stream_param_based_on_request_body():
"""
Test _update_stream_param_based_on_request_body handles stream parameter correctly.
"""
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
HttpPassThroughEndpointHelpers,
)

# Test 1: stream in request body should take precedence
parsed_body = {"stream": True, "model": "test-model"}
result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body(
parsed_body=parsed_body, stream=False
)
assert result is True

# Test 2: no stream in request body should return original stream param
parsed_body = {"model": "test-model"}
result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body(
parsed_body=parsed_body, stream=False
)
assert result is False

# Test 3: stream=False in request body should return False
parsed_body = {"stream": False, "model": "test-model"}
result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body(
parsed_body=parsed_body, stream=True
)
assert result is False

# Test 4: no stream param provided, no stream in body
parsed_body = {"model": "test-model"}
result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body(
parsed_body=parsed_body, stream=None
)
assert result is None


@pytest.fixture
def mock_request():
"""Create a mock request with headers"""
from typing import Optional

class QueryParams:
def __init__(self):
self._dict = {}

class MockRequest:
def __init__(
self, headers=None, method="POST", request_body: Optional[dict] = None
):
self.headers = headers or {}
self.query_params = QueryParams()
self.method = method
self.request_body = request_body or {}
# Add url attribute that the actual code expects
self.url = "http://localhost:8000/test"

async def body(self) -> bytes:
return bytes(json.dumps(self.request_body), "utf-8")

return MockRequest


@pytest.fixture
def mock_user_api_key_dict():
"""Create a mock user API key dictionary"""
from litellm.proxy._types import UserAPIKeyAuth
return UserAPIKeyAuth(
api_key="test-key",
user_id="test-user",
team_id="test-team",
end_user_id="test-user",
)


@pytest.mark.asyncio
async def test_pass_through_request_stream_param_override(
mock_request, mock_user_api_key_dict
):
"""
Test that when stream=None is passed as parameter but stream=True
is in request body, the request body value takes precedence and
the eventual POST request uses streaming.
"""
from unittest.mock import AsyncMock, Mock, patch

from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
pass_through_request,
)

# Create request body with stream=True
request_body = {
"model": "claude-3-5-sonnet-20241022",
"max_tokens": 256,
"messages": [{"role": "user", "content": "Hello, world"}],
"stream": True # This should override the function parameter
}

# Create a mock streaming response
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.headers = {"content-type": "text/event-stream"}

# Mock the streaming response behavior
async def mock_aiter_bytes():
yield b'data: {"content": "Hello"}\n\n'
yield b'data: {"content": "World"}\n\n'
yield b'data: [DONE]\n\n'

mock_response.aiter_bytes = mock_aiter_bytes

# Create mocks for the async client
mock_async_client = AsyncMock()
mock_request_obj = AsyncMock()

# Mock build_request to return a request object (it's a sync method)
mock_async_client.build_request = Mock(return_value=mock_request_obj)

# Mock send to return the streaming response
mock_async_client.send.return_value = mock_response

# Mock get_async_httpx_client to return our mock client
mock_client_obj = Mock()
mock_client_obj.client = mock_async_client

# Create the request
request = mock_request(
headers={}, method="POST", request_body=request_body
)

with patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client",
return_value=mock_client_obj,
), patch(
"litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook",
return_value=request_body, # Return the request body unchanged
), patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler",
new=AsyncMock(), # Mock the success handler
):
# Call pass_through_request with stream=False parameter
response = await pass_through_request(
request=request,
target="https://api.anthropic.com/v1/messages",
custom_headers={"Authorization": "Bearer test-key"},
user_api_key_dict=mock_user_api_key_dict,
stream=None, # This should be overridden by request body
)

# Verify that build_request was called (indicating streaming path)
mock_async_client.build_request.assert_called_once_with(
"POST",
httpx.URL("https://api.anthropic.com/v1/messages"),
json=request_body,
params=None,
headers={
"Authorization": "Bearer test-key"
},
)

# Verify that send was called with stream=True
mock_async_client.send.assert_called_once_with(
mock_request_obj,
stream=True # This proves that stream=True from request body was used
)

# Verify that the non-streaming request method was NOT called
mock_async_client.request.assert_not_called()

# Verify response is a StreamingResponse
from fastapi.responses import StreamingResponse
assert isinstance(response, StreamingResponse)
assert response.status_code == 200


@pytest.mark.asyncio
async def test_pass_through_request_stream_param_no_override(
mock_request, mock_user_api_key_dict
):
"""
Test that when stream=False is passed as parameter and no stream
is in request body, the function parameter is used and
the eventual request uses non-streaming.
"""
from unittest.mock import AsyncMock, Mock, patch

from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
pass_through_request,
)

# Create request body without stream parameter
request_body = {
"model": "claude-3-5-sonnet-20241022",
"max_tokens": 256,
"messages": [{"role": "user", "content": "Hello, world"}],
# No stream parameter - should use function parameter stream=False
}

# Create a mock non-streaming response
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.headers = {"content-type": "application/json"}
mock_response._content = b'{"response": "Hello world"}'

async def mock_aread():
return mock_response._content

mock_response.aread = mock_aread

# Create mocks for the async client
mock_async_client = AsyncMock()

# Mock request to return the non-streaming response
mock_async_client.request.return_value = mock_response

# Mock get_async_httpx_client to return our mock client
mock_client_obj = Mock()
mock_client_obj.client = mock_async_client

# Create the request
request = mock_request(
headers={}, method="POST", request_body=request_body
)

with patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client",
return_value=mock_client_obj,
), patch(
"litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook",
return_value=request_body, # Return the request body unchanged
), patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler",
new=AsyncMock(), # Mock the success handler
):
# Call pass_through_request with stream=False parameter
response = await pass_through_request(
request=request,
target="https://api.anthropic.com/v1/messages",
custom_headers={"Authorization": "Bearer test-key"},
user_api_key_dict=mock_user_api_key_dict,
stream=False, # Should be used since no stream in request body
)

# Verify that build_request was NOT called (no streaming path)
mock_async_client.build_request.assert_not_called()

# Verify that send was NOT called (no streaming path)
mock_async_client.send.assert_not_called()

# Verify that the non-streaming request method WAS called
mock_async_client.request.assert_called_once_with(
method="POST",
url=httpx.URL("https://api.anthropic.com/v1/messages"),
headers={
"Authorization": "Bearer test-key"
},
params=None,
json=request_body,
)

# Verify response is a regular Response (not StreamingResponse)
from fastapi.responses import Response, StreamingResponse
assert not isinstance(response, StreamingResponse)
assert isinstance(response, Response)
assert response.status_code == 200
Loading