11import json
22import os
33import sys
4+ from unittest .mock import MagicMock , patch
45
6+ import httpx
57import pytest
68from fastapi .testclient import TestClient
9+
710from litellm .llms .custom_httpx .http_handler import HTTPHandler
8- from unittest .mock import MagicMock , patch
9- import httpx
10-
11+
1112sys .path .insert (
1213 0 , os .path .abspath ("../../.." )
1314) # Adds the parent directory to the system path
@@ -133,4 +134,271 @@ def test_bedrock_non_application_inference_profile_no_encoding():
133134 actual_url = str (call_args .kwargs ["url" ])
134135 assert "application-inference-profile%2F" not in actual_url
135136 assert "anthropic.claude-3-sonnet-20240229-v1:0" in actual_url
137+ assert response .status_code == 200
138+
139+
def test_update_stream_param_based_on_request_body():
    """
    Check how _update_stream_param_based_on_request_body resolves the stream flag.

    Each scenario feeds a parsed request body and a ``stream`` argument to the
    helper and asserts the exact object (True / False / None) that comes back.
    """
    from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
        HttpPassThroughEndpointHelpers,
    )

    resolve_stream = (
        HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body
    )

    # Each entry: (parsed request body, stream argument, expected result)
    scenarios = [
        # Test 1: stream in request body should take precedence
        ({"stream": True, "model": "test-model"}, False, True),
        # Test 2: no stream in request body should return original stream param
        ({"model": "test-model"}, False, False),
        # Test 3: stream=False in request body should return False
        ({"stream": False, "model": "test-model"}, True, False),
        # Test 4: no stream param provided, no stream in body
        ({"model": "test-model"}, None, None),
    ]

    for body, stream_arg, expected in scenarios:
        outcome = resolve_stream(parsed_body=body, stream=stream_arg)
        # Identity check on purpose: the result must be the exact singleton.
        assert outcome is expected
175+
176+
@pytest.fixture
def mock_request():
    """
    Provide the MockRequest class, a minimal stand-in for an incoming request.

    The fixture returns the class itself (not an instance), so each test can
    build its own request with the headers, method and body it needs.
    """
    from typing import Optional

    class QueryParams:
        """Empty query-parameter holder exposing only a ``_dict`` attribute."""

        def __init__(self):
            self._dict = {}

    class MockRequest:
        """Request double carrying headers, method, URL and a JSON-able body."""

        def __init__(
            self, headers=None, method="POST", request_body: Optional[dict] = None
        ):
            # url attribute that the actual code expects
            self.url = "http://localhost:8000/test"
            self.method = method
            self.query_params = QueryParams()
            # Falsy inputs (None or empty) are replaced by a fresh empty dict.
            self.headers = headers if headers else {}
            self.request_body = request_body if request_body else {}

        async def body(self) -> bytes:
            """Return the stored request body serialized as UTF-8 encoded JSON."""
            serialized = json.dumps(self.request_body)
            return serialized.encode("utf-8")

    return MockRequest
201+
202+
@pytest.fixture
def mock_user_api_key_dict():
    """Build the UserAPIKeyAuth object that tests pass as the authenticated caller."""
    from litellm.proxy._types import UserAPIKeyAuth

    auth_fields = {
        "api_key": "test-key",
        "user_id": "test-user",
        "team_id": "test-team",
        "end_user_id": "test-user",
    }
    return UserAPIKeyAuth(**auth_fields)
213+
214+
@pytest.mark.asyncio
async def test_pass_through_request_stream_param_override(
    mock_request, mock_user_api_key_dict
):
    """
    Test that when stream=None is passed as parameter but stream=True
    is in request body, the request body value takes precedence and
    the eventual POST request uses streaming.

    The HTTP client, the pre-call hook and the success handler are all
    replaced with mocks, so the test only observes which client methods
    pass_through_request invokes and what kind of response it returns.
    """
    from unittest.mock import AsyncMock, Mock, patch

    from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
        pass_through_request,
    )

    # Create request body with stream=True
    request_body = {
        "model": "claude-3-5-sonnet-20241022",
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hello, world"}],
        "stream": True,  # This should override the function parameter
    }

    # Create a mock streaming response
    mock_response = AsyncMock()
    mock_response.status_code = 200
    mock_response.headers = {"content-type": "text/event-stream"}

    # Mock the streaming response behavior. Each server-sent event is
    # terminated by a blank line, i.e. exactly "\n\n" with no padding.
    async def mock_aiter_bytes():
        yield b'data: {"content": "Hello"}\n\n'
        yield b'data: {"content": "World"}\n\n'
        yield b"data: [DONE]\n\n"

    mock_response.aiter_bytes = mock_aiter_bytes

    # Create mocks for the async client
    mock_async_client = AsyncMock()
    mock_request_obj = AsyncMock()

    # Mock build_request to return a request object (it's a sync method)
    mock_async_client.build_request = Mock(return_value=mock_request_obj)

    # Mock send to return the streaming response
    mock_async_client.send.return_value = mock_response

    # Mock get_async_httpx_client to return our mock client
    mock_client_obj = Mock()
    mock_client_obj.client = mock_async_client

    # Create the request
    request = mock_request(headers={}, method="POST", request_body=request_body)

    with patch(
        "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client",
        return_value=mock_client_obj,
    ), patch(
        "litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook",
        return_value=request_body,  # Return the request body unchanged
    ), patch(
        "litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler",
        new=AsyncMock(),  # Mock the success handler
    ):
        # Call pass_through_request with stream=None; the body's stream=True
        # is expected to decide the code path.
        response = await pass_through_request(
            request=request,
            target="https://api.anthropic.com/v1/messages",
            custom_headers={"Authorization": "Bearer test-key"},
            user_api_key_dict=mock_user_api_key_dict,
            stream=None,  # This should be overridden by request body
        )

        # Verify that build_request was called (indicating streaming path)
        mock_async_client.build_request.assert_called_once_with(
            "POST",
            httpx.URL("https://api.anthropic.com/v1/messages"),
            json=request_body,
            params=None,
            headers={"Authorization": "Bearer test-key"},
        )

        # Verify that send was called with stream=True
        mock_async_client.send.assert_called_once_with(
            mock_request_obj,
            stream=True,  # This proves that stream=True from request body was used
        )

        # Verify that the non-streaming request method was NOT called
        mock_async_client.request.assert_not_called()

        # Verify response is a StreamingResponse
        from fastapi.responses import StreamingResponse

        assert isinstance(response, StreamingResponse)
        assert response.status_code == 200
313+
314+
@pytest.mark.asyncio
async def test_pass_through_request_stream_param_no_override(
    mock_request, mock_user_api_key_dict
):
    """
    Verify the non-streaming path: with stream=False passed as the parameter
    and no "stream" key in the request body, the function parameter is used
    and the upstream call goes through the client's plain ``request`` method.
    """
    from unittest.mock import AsyncMock, Mock, patch

    from fastapi.responses import Response, StreamingResponse
    from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
        pass_through_request,
    )

    target_url = "https://api.anthropic.com/v1/messages"
    auth_headers = {"Authorization": "Bearer test-key"}

    # Request payload deliberately lacks a "stream" key, so the stream=False
    # argument below is the only source for the flag.
    payload = {
        "model": "claude-3-5-sonnet-20241022",
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hello, world"}],
    }

    # Fake non-streaming upstream response.
    upstream_response = AsyncMock()
    upstream_response.status_code = 200
    upstream_response.headers = {"content-type": "application/json"}
    upstream_response._content = b'{"response": "Hello world"}'

    async def read_content():
        return upstream_response._content

    upstream_response.aread = read_content

    # Fake async HTTP client whose request() yields the response above.
    http_client = AsyncMock()
    http_client.request.return_value = upstream_response

    # Wrapper object returned by the patched get_async_httpx_client.
    client_holder = Mock()
    client_holder.client = http_client

    incoming = mock_request(headers={}, method="POST", request_body=payload)

    client_patch = patch(
        "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client",
        return_value=client_holder,
    )
    # The pre-call hook hands the request body back unchanged.
    hook_patch = patch(
        "litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook",
        return_value=payload,
    )
    # The success handler is swapped for a no-op async mock.
    success_patch = patch(
        "litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler",
        new=AsyncMock(),
    )

    with client_patch, hook_patch, success_patch:
        # stream=False should be honoured since the body has no stream key.
        response = await pass_through_request(
            request=incoming,
            target=target_url,
            custom_headers=auth_headers,
            user_api_key_dict=mock_user_api_key_dict,
            stream=False,
        )

        # Streaming path must be untouched: no build_request, no send.
        http_client.build_request.assert_not_called()
        http_client.send.assert_not_called()

        # The plain request method must have been used exactly once.
        http_client.request.assert_called_once_with(
            method="POST",
            url=httpx.URL("https://api.anthropic.com/v1/messages"),
            headers={"Authorization": "Bearer test-key"},
            params=None,
            json=payload,
        )

        # Result is a regular Response, not a StreamingResponse.
        assert not isinstance(response, StreamingResponse)
        assert isinstance(response, Response)
        assert response.status_code == 200
0 commit comments