Skip to content

Commit a61275f

Browse files
Merge pull request #13796 from 0x-fang/fix_aip_pass_through_08_20
Fix(Bedrock): fix application inference profile for pass-through endpoints for bedrock
2 parents 7c2c1c1 + d2b943f commit a61275f

File tree

6 files changed

+272
-4
lines changed

6 files changed

+272
-4
lines changed

litellm/passthrough/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
2626
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
2727
from litellm.utils import client
28+
from litellm.proxy.pass_through_endpoints.common_utils import encode_bedrock_runtime_modelid_arn
2829

2930
base_llm_http_handler = BaseLLMHTTPHandler()
3031
from .utils import BasePassthroughUtils
@@ -241,6 +242,12 @@ def llm_passthrough_route(
241242
request_query_params=request_query_params,
242243
litellm_params=litellm_params_dict,
243244
)
245+
246+
# need to encode the id of application-inference-profile for bedrock
247+
if custom_llm_provider == "bedrock" and "application-inference-profile" in endpoint:
248+
encoded_url_str = encode_bedrock_runtime_modelid_arn(str(updated_url))
249+
updated_url = httpx.URL(encoded_url_str)
250+
244251
# Add or update query parameters
245252
provider_api_key = provider_config.get_api_key(api_key)
246253

litellm/proxy/pass_through_endpoints/common_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,55 @@ def get_litellm_virtual_key(request: Request) -> str:
1414
if litellm_api_key:
1515
return f"Bearer {litellm_api_key}"
1616
return request.headers.get("Authorization", "")
17+
18+
19+
def encode_bedrock_runtime_modelid_arn(endpoint: str) -> str:
20+
"""
21+
Encodes any "/" found in the modelId of an AWS Bedrock Runtime Endpoint when arns are passed in.
22+
- modelID value can be an ARN which contains slashes that SHOULD NOT be treated as path separators.
23+
e.g. endpoint: /model/<modelId>/invoke
24+
<modelId> containing arns with slashes need to be encoded from
25+
arn:aws:bedrock:ap-southeast-1:123456789012:application-inference-profile/abdefg12334 =>
26+
arn:aws:bedrock:ap-southeast-1:123456789012:application-inference-profile%2Fabdefg12334
27+
so that it is treated as one part of the path.
28+
Otherwise, the encoded endpoint will return 500 error when passed to Bedrock endpoint.
29+
30+
See the apis in https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Operations_Amazon_Bedrock_Runtime.html
31+
for more details on the regex patterns of modelId which we use in the regex logic below.
32+
33+
Args:
34+
endpoint (str): The original endpoint string which may contain ARNs that contain slashes.
35+
36+
Returns:
37+
str: The endpoint with properly encoded ARN slashes
38+
"""
39+
import re
40+
41+
# Early exit: if no ARN detected, return unchanged
42+
if 'arn:aws:' not in endpoint:
43+
return endpoint
44+
45+
# Handle all patterns in one go - more efficient and cleaner
46+
patterns = [
47+
# Custom model with 2 slashes (order matters - do this first)
48+
(r'(custom-model)/([a-z0-9.-]+)/([a-z0-9]+)', r'\1%2F\2%2F\3'),
49+
50+
# All other resource types with 1 slash
51+
(r'(:application-inference-profile)/', r'\1%2F'),
52+
(r'(:inference-profile)/', r'\1%2F'),
53+
(r'(:foundation-model)/', r'\1%2F'),
54+
(r'(:imported-model)/', r'\1%2F'),
55+
(r'(:provisioned-model)/', r'\1%2F'),
56+
(r'(:prompt)/', r'\1%2F'),
57+
(r'(:endpoint)/', r'\1%2F'),
58+
(r'(:prompt-router)/', r'\1%2F'),
59+
(r'(:default-prompt-router)/', r'\1%2F'),
60+
]
61+
62+
for pattern, replacement in patterns:
63+
# Check if pattern exists before applying regex (early exit optimization)
64+
if re.search(pattern, endpoint):
65+
endpoint = re.sub(pattern, replacement, endpoint)
66+
break # Exit after first match since each ARN has only one resource type
67+
68+
return endpoint

litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -492,20 +492,26 @@ async def bedrock_llm_proxy_route(
492492
data: Dict[str, Any] = {}
493493
base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
494494
try:
495-
model = endpoint.split("/")[1]
495+
endpoint_parts = endpoint.split("/")
496+
if "application-inference-profile" in endpoint:
497+
# For application-inference-profile, include the profile ID part as well
498+
model = "/".join(endpoint_parts[1:3])
499+
else:
500+
model = endpoint_parts[1]
496501
except Exception:
497502
raise HTTPException(
498503
status_code=400,
499504
detail={
500505
"error": "Model missing from endpoint. Expected format: /model/<Model>/<endpoint>. Got: "
501506
+ endpoint,
502507
},
503-
)
508+
)
504509

505510
data["method"] = request.method
506511
data["endpoint"] = endpoint
507512
data["data"] = request_body
508-
513+
data["custom_llm_provider"] = "bedrock"
514+
509515
try:
510516
result = await base_llm_response_processor.base_passthrough_process_llm_request(
511517
request=request,

tests/test_litellm/passthrough/test_passthrough_main.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44

55
import pytest
66
from fastapi.testclient import TestClient
7-
7+
from litellm.llms.custom_httpx.http_handler import HTTPHandler
8+
from unittest.mock import MagicMock, patch
9+
import httpx
10+
811
sys.path.insert(
912
0, os.path.abspath("../../..")
1013
) # Adds the parent directory to the system path
@@ -45,3 +48,89 @@ def test_llm_passthrough_route():
4548

4649
assert response.status_code == 200
4750
assert response.json == {"message": "Hello, world!"}
51+
52+
53+
def test_bedrock_application_inference_profile_url_encoding():
54+
client = HTTPHandler()
55+
56+
mock_provider_config = MagicMock()
57+
mock_provider_config.get_complete_url.return_value = (
58+
httpx.URL("https://bedrock-runtime.us-east-1.amazonaws.com/model/arn:aws:bedrock:us-east-1:123456789123:application-inference-profile/r742sbn2zckd/converse"),
59+
"https://bedrock-runtime.us-east-1.amazonaws.com"
60+
)
61+
mock_provider_config.get_api_key.return_value = "test-key"
62+
mock_provider_config.validate_environment.return_value = {}
63+
mock_provider_config.sign_request.return_value = ({}, None)
64+
mock_provider_config.is_streaming_request.return_value = False
65+
66+
with patch("litellm.utils.ProviderConfigManager.get_provider_passthrough_config", return_value=mock_provider_config), \
67+
patch("litellm.litellm_core_utils.get_litellm_params.get_litellm_params", return_value={}), \
68+
patch("litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider", return_value=("test-model", "bedrock", "test-key", "test-base")), \
69+
patch.object(client.client, "send", return_value=MagicMock(status_code=200)) as mock_send, \
70+
patch.object(client.client, "build_request") as mock_build_request:
71+
72+
# Mock logging object
73+
mock_logging_obj = MagicMock()
74+
mock_logging_obj.update_environment_variables = MagicMock()
75+
76+
response = llm_passthrough_route(
77+
model="arn:aws:bedrock:us-east-1:123456789123:application-inference-profile/r742sbn2zckd",
78+
endpoint="model/arn:aws:bedrock:us-east-1:123456789123:application-inference-profile/r742sbn2zckd/converse",
79+
method="POST",
80+
custom_llm_provider="bedrock",
81+
client=client,
82+
litellm_logging_obj=mock_logging_obj,
83+
)
84+
85+
# Verify that build_request was called with the encoded URL
86+
mock_build_request.assert_called_once()
87+
call_args = mock_build_request.call_args
88+
89+
# The URL should have the application-inference-profile ID encoded
90+
actual_url = str(call_args.kwargs["url"])
91+
assert "application-inference-profile%2Fr742sbn2zckd" in actual_url
92+
assert response.status_code == 200
93+
94+
95+
def test_bedrock_non_application_inference_profile_no_encoding():
96+
client = HTTPHandler()
97+
98+
# Mock the provider config and its methods
99+
mock_provider_config = MagicMock()
100+
mock_provider_config.get_complete_url.return_value = (
101+
httpx.URL("https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/converse"),
102+
"https://bedrock-runtime.us-east-1.amazonaws.com"
103+
)
104+
mock_provider_config.get_api_key.return_value = "test-key"
105+
mock_provider_config.validate_environment.return_value = {}
106+
mock_provider_config.sign_request.return_value = ({}, None)
107+
mock_provider_config.is_streaming_request.return_value = False
108+
109+
with patch("litellm.utils.ProviderConfigManager.get_provider_passthrough_config", return_value=mock_provider_config), \
110+
patch("litellm.litellm_core_utils.get_litellm_params.get_litellm_params", return_value={}), \
111+
patch("litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider", return_value=("test-model", "bedrock", "test-key", "test-base")), \
112+
patch.object(client.client, "send", return_value=MagicMock(status_code=200)) as mock_send, \
113+
patch.object(client.client, "build_request") as mock_build_request:
114+
115+
# Mock logging object
116+
mock_logging_obj = MagicMock()
117+
mock_logging_obj.update_environment_variables = MagicMock()
118+
119+
response = llm_passthrough_route(
120+
model="anthropic.claude-3-sonnet-20240229-v1:0",
121+
endpoint="model/anthropic.claude-3-sonnet-20240229-v1:0/converse",
122+
method="POST",
123+
custom_llm_provider="bedrock",
124+
client=client,
125+
litellm_logging_obj=mock_logging_obj,
126+
)
127+
128+
# Verify that build_request was called with the original URL (no encoding)
129+
mock_build_request.assert_called_once()
130+
call_args = mock_build_request.call_args
131+
132+
# The URL should NOT have application-inference-profile encoding
133+
actual_url = str(call_args.kwargs["url"])
134+
assert "application-inference-profile%2F" not in actual_url
135+
assert "anthropic.claude-3-sonnet-20240229-v1:0" in actual_url
136+
assert response.status_code == 200

tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
create_pass_through_route,
2222
vertex_discovery_proxy_route,
2323
vertex_proxy_route,
24+
bedrock_llm_proxy_route,
2425
)
2526
from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials
2627

@@ -853,3 +854,63 @@ async def test_is_streaming_request_fn():
853854
mock_request.headers = {"content-type": "multipart/form-data"}
854855
mock_request.form = AsyncMock(return_value={"stream": "true"})
855856
assert await is_streaming_request_fn(mock_request) is True
857+
858+
class TestBedrockLLMProxyRoute:
859+
@pytest.mark.asyncio
860+
async def test_bedrock_llm_proxy_route_application_inference_profile(self):
861+
mock_request = Mock()
862+
mock_request.method = "POST"
863+
mock_response = Mock()
864+
mock_user_api_key_dict = Mock()
865+
mock_request_body = {"messages": [{"role": "user", "content": "test"}]}
866+
mock_processor = Mock()
867+
mock_processor.base_passthrough_process_llm_request = AsyncMock(return_value="success")
868+
869+
with patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints._read_request_body", return_value=mock_request_body), \
870+
patch("litellm.proxy.common_request_processing.ProxyBaseLLMRequestProcessing", return_value=mock_processor):
871+
872+
# Test application-inference-profile endpoint
873+
endpoint = "model/arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/r742sbn2zckd/converse"
874+
875+
result = await bedrock_llm_proxy_route(
876+
endpoint=endpoint,
877+
request=mock_request,
878+
fastapi_response=mock_response,
879+
user_api_key_dict=mock_user_api_key_dict,
880+
)
881+
882+
mock_processor.base_passthrough_process_llm_request.assert_called_once()
883+
call_kwargs = mock_processor.base_passthrough_process_llm_request.call_args.kwargs
884+
885+
# For application-inference-profile, model should be "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/r742sbn2zckd"
886+
assert call_kwargs["model"] == "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/r742sbn2zckd"
887+
assert result == "success"
888+
889+
@pytest.mark.asyncio
890+
async def test_bedrock_llm_proxy_route_regular_model(self):
891+
mock_request = Mock()
892+
mock_request.method = "POST"
893+
mock_response = Mock()
894+
mock_user_api_key_dict = Mock()
895+
mock_request_body = {"messages": [{"role": "user", "content": "test"}]}
896+
mock_processor = Mock()
897+
mock_processor.base_passthrough_process_llm_request = AsyncMock(return_value="success")
898+
899+
with patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints._read_request_body", return_value=mock_request_body), \
900+
patch("litellm.proxy.common_request_processing.ProxyBaseLLMRequestProcessing", return_value=mock_processor):
901+
902+
# Test regular model endpoint
903+
endpoint = "model/anthropic.claude-3-sonnet-20240229-v1:0/converse"
904+
905+
result = await bedrock_llm_proxy_route(
906+
endpoint=endpoint,
907+
request=mock_request,
908+
fastapi_response=mock_response,
909+
user_api_key_dict=mock_user_api_key_dict,
910+
)
911+
mock_processor.base_passthrough_process_llm_request.assert_called_once()
912+
call_kwargs = mock_processor.base_passthrough_process_llm_request.call_args.kwargs
913+
914+
# For regular models, model should be just the model ID
915+
assert call_kwargs["model"] == "anthropic.claude-3-sonnet-20240229-v1:0"
916+
assert result == "success"

tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_endpoints_common_utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010
from fastapi import Request, Response
1111
from fastapi.testclient import TestClient
12+
from litellm.proxy.pass_through_endpoints.common_utils import encode_bedrock_runtime_modelid_arn
1213

1314
sys.path.insert(
1415
0, os.path.abspath("../../../..")
@@ -42,3 +43,55 @@ async def test_get_litellm_virtual_key():
4243
}
4344
result = get_litellm_virtual_key(mock_request)
4445
assert result == "Bearer test-key-123"
46+
47+
def test_encode_bedrock_runtime_modelid_arn():
48+
# Test application-inference-profile ARN
49+
endpoint = "model/arn:aws:bedrock:us-east-1:123456789123:application-inference-profile/r742sbn2zckd/converse"
50+
expected = "model/arn:aws:bedrock:us-east-1:123456789123:application-inference-profile%2Fr742sbn2zckd/converse"
51+
result = encode_bedrock_runtime_modelid_arn(endpoint)
52+
assert result == expected
53+
54+
# Test inference-profile ARN
55+
endpoint = "model/arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-profile/invoke"
56+
expected = "model/arn:aws:bedrock:us-east-1:123456789012:inference-profile%2Ftest-profile/invoke"
57+
result = encode_bedrock_runtime_modelid_arn(endpoint)
58+
assert result == expected
59+
60+
# Test foundation-model ARN
61+
endpoint = "model/arn:aws:bedrock:us-east-1:123456789012:foundation-model/anthropic.claude-3/converse"
62+
expected = "model/arn:aws:bedrock:us-east-1:123456789012:foundation-model%2Fanthropic.claude-3/converse"
63+
result = encode_bedrock_runtime_modelid_arn(endpoint)
64+
assert result == expected
65+
66+
# Test custom-model ARN (2 slashes)
67+
endpoint = "model/arn:aws:bedrock:us-east-1:123456789012:custom-model/my-model.fine-tuned/abc123/invoke"
68+
expected = "model/arn:aws:bedrock:us-east-1:123456789012:custom-model%2Fmy-model.fine-tuned%2Fabc123/invoke"
69+
result = encode_bedrock_runtime_modelid_arn(endpoint)
70+
assert result == expected
71+
72+
# Test provisioned-model ARN
73+
endpoint = "model/arn:aws:bedrock:us-east-1:123456789012:provisioned-model/test-model/converse"
74+
expected = "model/arn:aws:bedrock:us-east-1:123456789012:provisioned-model%2Ftest-model/converse"
75+
result = encode_bedrock_runtime_modelid_arn(endpoint)
76+
assert result == expected
77+
78+
79+
def test_encode_bedrock_runtime_modelid_arn_no_arn():
80+
# Test regular model ID (no ARN)
81+
endpoint = "model/anthropic.claude-3-sonnet-20240229-v1:0/converse"
82+
result = encode_bedrock_runtime_modelid_arn(endpoint)
83+
assert result == endpoint
84+
85+
86+
def test_encode_bedrock_runtime_modelid_arn_edge_cases():
87+
# Test that only the first matching resource-type pattern is encoded (loop breaks after one match)
88+
endpoint = "model/arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test1/converse"
89+
expected = "model/arn:aws:bedrock:us-east-1:123456789012:application-inference-profile%2Ftest1/converse"
90+
result = encode_bedrock_runtime_modelid_arn(endpoint)
91+
assert result == expected
92+
93+
# Test ARN with special characters in resource ID
94+
endpoint = "model/arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile.v1/invoke"
95+
expected = "model/arn:aws:bedrock:us-east-1:123456789012:application-inference-profile%2Ftest-profile.v1/invoke"
96+
result = encode_bedrock_runtime_modelid_arn(endpoint)
97+
assert result == expected

0 commit comments

Comments
 (0)