Skip to content

Commit c7b0c57

Browse files
authored
[Bug Fix] Azure Passthrough request with streaming (#13831)
* fix: _update_stream_param_based_on_request_body * test_update_stream_param_based_on_request_body * test_pass_through_request_stream_param_override
1 parent 433d1a4 commit c7b0c57

File tree

2 files changed

+289
-3
lines changed

2 files changed

+289
-3
lines changed

litellm/proxy/pass_through_endpoints/pass_through_endpoints.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,19 @@ def construct_target_url_with_subpath(
531531
subpath = subpath[1:]
532532

533533
return base_target + subpath
534+
535+
@staticmethod
536+
def _update_stream_param_based_on_request_body(
537+
parsed_body: dict,
538+
stream: Optional[bool] = None,
539+
) -> Optional[bool]:
540+
"""
541+
If stream is provided in the request body, use it.
542+
Otherwise, use the stream parameter passed to the `pass_through_request` function
543+
"""
544+
if "stream" in parsed_body:
545+
return parsed_body.get("stream", stream)
546+
return stream
534547

535548

536549
async def pass_through_request( # noqa: PLR0915
@@ -686,6 +699,11 @@ async def pass_through_request( # noqa: PLR0915
686699
"headers": headers,
687700
},
688701
)
702+
stream = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body(
703+
parsed_body=_parsed_body,
704+
stream=stream,
705+
)
706+
689707
if stream:
690708
req = async_client.build_request(
691709
"POST",

tests/test_litellm/passthrough/test_passthrough_main.py

Lines changed: 271 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import json
22
import os
33
import sys
4+
from unittest.mock import MagicMock, patch
45

6+
import httpx
57
import pytest
68
from fastapi.testclient import TestClient
9+
710
from litellm.llms.custom_httpx.http_handler import HTTPHandler
8-
from unittest.mock import MagicMock, patch
9-
import httpx
10-
11+
1112
sys.path.insert(
1213
0, os.path.abspath("../../..")
1314
) # Adds the parent directory to the system path
@@ -133,4 +134,271 @@ def test_bedrock_non_application_inference_profile_no_encoding():
133134
actual_url = str(call_args.kwargs["url"])
134135
assert "application-inference-profile%2F" not in actual_url
135136
assert "anthropic.claude-3-sonnet-20240229-v1:0" in actual_url
137+
assert response.status_code == 200
138+
139+
140+
def test_update_stream_param_based_on_request_body():
141+
"""
142+
Test _update_stream_param_based_on_request_body handles stream parameter correctly.
143+
"""
144+
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
145+
HttpPassThroughEndpointHelpers,
146+
)
147+
148+
# Test 1: stream in request body should take precedence
149+
parsed_body = {"stream": True, "model": "test-model"}
150+
result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body(
151+
parsed_body=parsed_body, stream=False
152+
)
153+
assert result is True
154+
155+
# Test 2: no stream in request body should return original stream param
156+
parsed_body = {"model": "test-model"}
157+
result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body(
158+
parsed_body=parsed_body, stream=False
159+
)
160+
assert result is False
161+
162+
# Test 3: stream=False in request body should return False
163+
parsed_body = {"stream": False, "model": "test-model"}
164+
result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body(
165+
parsed_body=parsed_body, stream=True
166+
)
167+
assert result is False
168+
169+
# Test 4: no stream param provided, no stream in body
170+
parsed_body = {"model": "test-model"}
171+
result = HttpPassThroughEndpointHelpers._update_stream_param_based_on_request_body(
172+
parsed_body=parsed_body, stream=None
173+
)
174+
assert result is None
175+
176+
177+
@pytest.fixture
178+
def mock_request():
179+
"""Create a mock request with headers"""
180+
from typing import Optional
181+
182+
class QueryParams:
183+
def __init__(self):
184+
self._dict = {}
185+
186+
class MockRequest:
187+
def __init__(
188+
self, headers=None, method="POST", request_body: Optional[dict] = None
189+
):
190+
self.headers = headers or {}
191+
self.query_params = QueryParams()
192+
self.method = method
193+
self.request_body = request_body or {}
194+
# Add url attribute that the actual code expects
195+
self.url = "http://localhost:8000/test"
196+
197+
async def body(self) -> bytes:
198+
return bytes(json.dumps(self.request_body), "utf-8")
199+
200+
return MockRequest
201+
202+
203+
@pytest.fixture
204+
def mock_user_api_key_dict():
205+
"""Create a mock user API key dictionary"""
206+
from litellm.proxy._types import UserAPIKeyAuth
207+
return UserAPIKeyAuth(
208+
api_key="test-key",
209+
user_id="test-user",
210+
team_id="test-team",
211+
end_user_id="test-user",
212+
)
213+
214+
215+
@pytest.mark.asyncio
216+
async def test_pass_through_request_stream_param_override(
217+
mock_request, mock_user_api_key_dict
218+
):
219+
"""
220+
Test that when stream=None is passed as parameter but stream=True
221+
is in request body, the request body value takes precedence and
222+
the eventual POST request uses streaming.
223+
"""
224+
from unittest.mock import AsyncMock, Mock, patch
225+
226+
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
227+
pass_through_request,
228+
)
229+
230+
# Create request body with stream=True
231+
request_body = {
232+
"model": "claude-3-5-sonnet-20241022",
233+
"max_tokens": 256,
234+
"messages": [{"role": "user", "content": "Hello, world"}],
235+
"stream": True # This should override the function parameter
236+
}
237+
238+
# Create a mock streaming response
239+
mock_response = AsyncMock()
240+
mock_response.status_code = 200
241+
mock_response.headers = {"content-type": "text/event-stream"}
242+
243+
# Mock the streaming response behavior
244+
async def mock_aiter_bytes():
245+
yield b'data: {"content": "Hello"}\n\n'
246+
yield b'data: {"content": "World"}\n\n'
247+
yield b'data: [DONE]\n\n'
248+
249+
mock_response.aiter_bytes = mock_aiter_bytes
250+
251+
# Create mocks for the async client
252+
mock_async_client = AsyncMock()
253+
mock_request_obj = AsyncMock()
254+
255+
# Mock build_request to return a request object (it's a sync method)
256+
mock_async_client.build_request = Mock(return_value=mock_request_obj)
257+
258+
# Mock send to return the streaming response
259+
mock_async_client.send.return_value = mock_response
260+
261+
# Mock get_async_httpx_client to return our mock client
262+
mock_client_obj = Mock()
263+
mock_client_obj.client = mock_async_client
264+
265+
# Create the request
266+
request = mock_request(
267+
headers={}, method="POST", request_body=request_body
268+
)
269+
270+
with patch(
271+
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client",
272+
return_value=mock_client_obj,
273+
), patch(
274+
"litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook",
275+
return_value=request_body, # Return the request body unchanged
276+
), patch(
277+
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler",
278+
new=AsyncMock(), # Mock the success handler
279+
):
280+
# Call pass_through_request with stream=False parameter
281+
response = await pass_through_request(
282+
request=request,
283+
target="https://api.anthropic.com/v1/messages",
284+
custom_headers={"Authorization": "Bearer test-key"},
285+
user_api_key_dict=mock_user_api_key_dict,
286+
stream=None, # This should be overridden by request body
287+
)
288+
289+
# Verify that build_request was called (indicating streaming path)
290+
mock_async_client.build_request.assert_called_once_with(
291+
"POST",
292+
httpx.URL("https://api.anthropic.com/v1/messages"),
293+
json=request_body,
294+
params=None,
295+
headers={
296+
"Authorization": "Bearer test-key"
297+
},
298+
)
299+
300+
# Verify that send was called with stream=True
301+
mock_async_client.send.assert_called_once_with(
302+
mock_request_obj,
303+
stream=True # This proves that stream=True from request body was used
304+
)
305+
306+
# Verify that the non-streaming request method was NOT called
307+
mock_async_client.request.assert_not_called()
308+
309+
# Verify response is a StreamingResponse
310+
from fastapi.responses import StreamingResponse
311+
assert isinstance(response, StreamingResponse)
312+
assert response.status_code == 200
313+
314+
315+
@pytest.mark.asyncio
316+
async def test_pass_through_request_stream_param_no_override(
317+
mock_request, mock_user_api_key_dict
318+
):
319+
"""
320+
Test that when stream=False is passed as parameter and no stream
321+
is in request body, the function parameter is used and
322+
the eventual request uses non-streaming.
323+
"""
324+
from unittest.mock import AsyncMock, Mock, patch
325+
326+
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
327+
pass_through_request,
328+
)
329+
330+
# Create request body without stream parameter
331+
request_body = {
332+
"model": "claude-3-5-sonnet-20241022",
333+
"max_tokens": 256,
334+
"messages": [{"role": "user", "content": "Hello, world"}],
335+
# No stream parameter - should use function parameter stream=False
336+
}
337+
338+
# Create a mock non-streaming response
339+
mock_response = AsyncMock()
340+
mock_response.status_code = 200
341+
mock_response.headers = {"content-type": "application/json"}
342+
mock_response._content = b'{"response": "Hello world"}'
343+
344+
async def mock_aread():
345+
return mock_response._content
346+
347+
mock_response.aread = mock_aread
348+
349+
# Create mocks for the async client
350+
mock_async_client = AsyncMock()
351+
352+
# Mock request to return the non-streaming response
353+
mock_async_client.request.return_value = mock_response
354+
355+
# Mock get_async_httpx_client to return our mock client
356+
mock_client_obj = Mock()
357+
mock_client_obj.client = mock_async_client
358+
359+
# Create the request
360+
request = mock_request(
361+
headers={}, method="POST", request_body=request_body
362+
)
363+
364+
with patch(
365+
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client",
366+
return_value=mock_client_obj,
367+
), patch(
368+
"litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook",
369+
return_value=request_body, # Return the request body unchanged
370+
), patch(
371+
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler",
372+
new=AsyncMock(), # Mock the success handler
373+
):
374+
# Call pass_through_request with stream=False parameter
375+
response = await pass_through_request(
376+
request=request,
377+
target="https://api.anthropic.com/v1/messages",
378+
custom_headers={"Authorization": "Bearer test-key"},
379+
user_api_key_dict=mock_user_api_key_dict,
380+
stream=False, # Should be used since no stream in request body
381+
)
382+
383+
# Verify that build_request was NOT called (no streaming path)
384+
mock_async_client.build_request.assert_not_called()
385+
386+
# Verify that send was NOT called (no streaming path)
387+
mock_async_client.send.assert_not_called()
388+
389+
# Verify that the non-streaming request method WAS called
390+
mock_async_client.request.assert_called_once_with(
391+
method="POST",
392+
url=httpx.URL("https://api.anthropic.com/v1/messages"),
393+
headers={
394+
"Authorization": "Bearer test-key"
395+
},
396+
params=None,
397+
json=request_body,
398+
)
399+
400+
# Verify response is a regular Response (not StreamingResponse)
401+
from fastapi.responses import Response, StreamingResponse
402+
assert not isinstance(response, StreamingResponse)
403+
assert isinstance(response, Response)
136404
assert response.status_code == 200

0 commit comments

Comments
 (0)