
Commit 433d1a4

[Bug fix] - Fix /messages fallback from Anthropic API -> Bedrock API (#13946)
* use helper get_provider_specific_headers
* fix get_provider_specific_headers
* test_anthropic_messages_fallbacks
* bedrock/us.anthropic.claude-sonnet-4
* fix: get_provider_specific_headers
* TestProviderSpecificHeaderUtils
* test_anthropic_messages_fallbacks
1 parent c1ee8c2 commit 433d1a4
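
In short: the /messages path previously forwarded `provider_specific_header["extra_headers"]` unconditionally, so when a request fell back from the Anthropic API to Bedrock, Anthropic-only headers were carried into the Bedrock call. This commit centralizes header injection in a helper that returns the extra headers only when the header's declared `custom_llm_provider` matches the provider actually being called.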

6 files changed: +148 −18 lines changed
litellm/litellm_core_utils/get_provider_specific_headers.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+from typing import Dict, Optional
+
+from litellm.types.utils import ProviderSpecificHeader
+
+
+class ProviderSpecificHeaderUtils:
+    @staticmethod
+    def get_provider_specific_headers(
+        provider_specific_header: Optional[ProviderSpecificHeader],
+        custom_llm_provider: Optional[str],
+    ) -> Dict:
+        """
+        Get the provider-specific headers for the given custom llm provider.
+
+        Returns:
+            Dict: The provider-specific headers if the provider matches, otherwise an empty dict.
+        """
+        if (
+            provider_specific_header is not None
+            and provider_specific_header.get("custom_llm_provider") == custom_llm_provider
+        ):
+            return provider_specific_header.get("extra_headers", {})
+        return {}
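
To make the helper's contract concrete, a minimal sketch (not part of the commit; the header name and value are illustrative examples only):

    from litellm.litellm_core_utils.get_provider_specific_headers import (
        ProviderSpecificHeaderUtils,
    )

    # Illustrative payload; "anthropic-beta" and its value are example data.
    header = {
        "custom_llm_provider": "anthropic",
        "extra_headers": {"anthropic-beta": "example-beta-flag"},
    }

    # Provider matches: the extra headers are forwarded.
    assert ProviderSpecificHeaderUtils.get_provider_specific_headers(
        provider_specific_header=header,
        custom_llm_provider="anthropic",
    ) == {"anthropic-beta": "example-beta-flag"}

    # Provider differs (e.g. after a fallback to Bedrock): nothing is forwarded.
    assert ProviderSpecificHeaderUtils.get_provider_specific_headers(
        provider_specific_header=header,
        custom_llm_provider="bedrock",
    ) == {}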

litellm/llms/custom_httpx/llm_http_handler.py

Lines changed: 7 additions & 4 deletions
@@ -1257,6 +1257,10 @@ async def async_anthropic_messages_handler(
         stream: Optional[bool] = False,
         kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[AnthropicMessagesResponse, AsyncIterator]:
+        from litellm.litellm_core_utils.get_provider_specific_headers import (
+            ProviderSpecificHeaderUtils,
+        )
+
         if client is None or not isinstance(client, AsyncHTTPHandler):
             async_httpx_client = get_async_httpx_client(
                 llm_provider=litellm.LlmProviders.ANTHROPIC
@@ -1270,10 +1274,9 @@
             Optional[litellm.types.utils.ProviderSpecificHeader],
             kwargs.get("provider_specific_header", None),
         )
-        extra_headers = (
-            provider_specific_header.get("extra_headers", {})
-            if provider_specific_header
-            else {}
+        extra_headers = ProviderSpecificHeaderUtils.get_provider_specific_headers(
+            provider_specific_header=provider_specific_header,
+            custom_llm_provider=custom_llm_provider,
         )
         (
             headers,
litellm/main.py

Lines changed: 8 additions & 5 deletions
@@ -61,6 +61,9 @@
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check
 from litellm.litellm_core_utils.dd_tracing import tracer
+from litellm.litellm_core_utils.get_provider_specific_headers import (
+    ProviderSpecificHeaderUtils,
+)
 from litellm.litellm_core_utils.health_check_utils import (
     _create_health_check_response,
     _filter_model_params,
@@ -1107,11 +1110,11 @@ def completion(  # type: ignore # noqa: PLR0915
         api_key=api_key,
     )

-    if (
-        provider_specific_header is not None
-        and provider_specific_header["custom_llm_provider"] == custom_llm_provider
-    ):
-        headers.update(provider_specific_header["extra_headers"])
+    if provider_specific_header is not None:
+        headers.update(ProviderSpecificHeaderUtils.get_provider_specific_headers(
+            provider_specific_header=provider_specific_header,
+            custom_llm_provider=custom_llm_provider,
+        ))

     if model_response is not None and hasattr(model_response, "_hidden_params"):
         model_response._hidden_params["custom_llm_provider"] = custom_llm_provider
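
Note that `completion()` already enforced the provider match; this hunk only routes it through the shared helper so the /messages path applies the same rule. For reference, a hedged sketch of how a caller attaches provider-scoped headers (the code above reads a `provider_specific_header` kwarg; the header value here is illustrative):

    import litellm

    # Hypothetical example values; only the provider whose name matches
    # "custom_llm_provider" will receive the extra headers.
    response = litellm.completion(
        model="anthropic/claude-opus-4-20250514",
        messages=[{"role": "user", "content": "Hello"}],
        provider_specific_header={
            "custom_llm_provider": "anthropic",
            "extra_headers": {"anthropic-beta": "example-beta-flag"},
        },
    )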

litellm/proxy/proxy_config.yaml

Lines changed: 15 additions & 9 deletions
@@ -1,17 +1,23 @@
 model_list:
-  - model_name: bedrock/converse/us.anthropic.claude-sonnet-4-20250514-v1:0
+  - model_name: anthropic/*
     litellm_params:
-      model: bedrock/converse/us.anthropic.claude-sonnet-4-20250514-v1:0
+      model: anthropic/*
+      api_key: os.environ/OPENAI_API_KEY_IJ
   - model_name: bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0
     litellm_params:
       model: bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0
+  - model_name: bedrock/converse/us.anthropic.claude-sonnet-4-20250514-v1:0
+    litellm_params:
+      model: bedrock/converse/us.anthropic.claude-sonnet-4-20250514-v1:0
+
+router_settings:
+  fallbacks: [
+    {"anthropic/claude-opus-4-20250514":
+      [
+        "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0"
+      ]
+    }
+  ]

 litellm_settings:
   callbacks: ["datadog_llm_observability"]

-guardrails:
-  - guardrail_name: "bedrock-pre-guard"
-    litellm_params:
-      guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
-      mode: "during_call"
-      guardrailIdentifier: ff6ujrregl1q
-      guardrailVersion: "DRAFT"
tests/pass_through_unit_tests/test_anthropic_messages_passthrough.py

Lines changed: 52 additions & 0 deletions
@@ -273,6 +273,58 @@ async def test_anthropic_messages_litellm_router_routing_strategy():
     print(f"Non-streaming response: {json.dumps(response, indent=2)}")
     return response

+@pytest.mark.asyncio
+async def test_anthropic_messages_fallbacks():
+    """
+    E2E test of the anthropic_messages fallbacks from the Anthropic API to Bedrock.
+    """
+    litellm._turn_on_debug()
+    router = Router(
+        model_list=[
+            {
+                "model_name": "anthropic/claude-opus-4-20250514",
+                "litellm_params": {
+                    "model": "anthropic/claude-opus-4-20250514",
+                    "api_key": "bad-key",
+                },
+            },
+            {
+                "model_name": "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
+                "litellm_params": {
+                    "model": "bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0",
+                },
+            }
+        ],
+        fallbacks=[
+            {
+                "anthropic/claude-opus-4-20250514":
+                ["bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0"]
+            }
+        ]
+    )
+
+    # Set up test parameters
+    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
+
+    # Call the handler
+    response = await router.aanthropic_messages(
+        messages=messages,
+        model="anthropic/claude-opus-4-20250514",
+        max_tokens=100,
+        metadata={
+            "user_id": "hello",
+        },
+    )
+
+    # Verify response
+    assert "id" in response
+    assert "content" in response
+    assert "model" in response
+    assert response["role"] == "assistant"
+
+    print(f"Non-streaming response: {json.dumps(response, indent=2)}")
+    return response
+

 @pytest.mark.asyncio
 async def test_anthropic_messages_litellm_router_latency_metadata_tracking():
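
Note the deliberately invalid `api_key: "bad-key"` on the primary Anthropic deployment: it guarantees the first attempt fails, so the router exercises the fallback to the Bedrock deployment, mirroring the router_settings.fallbacks pair in the proxy config above.
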
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+import pytest
+
+from litellm.litellm_core_utils.get_provider_specific_headers import (
+    ProviderSpecificHeaderUtils,
+)
+from litellm.types.utils import ProviderSpecificHeader
+
+
+class TestProviderSpecificHeaderUtils:
+    def test_get_provider_specific_headers_matching_provider(self):
+        """Test that the method returns extra_headers when custom_llm_provider matches."""
+        provider_specific_header: ProviderSpecificHeader = {
+            "custom_llm_provider": "openai",
+            "extra_headers": {"Authorization": "Bearer token123", "Custom-Header": "value"}
+        }
+        custom_llm_provider = "openai"
+
+        result = ProviderSpecificHeaderUtils.get_provider_specific_headers(
+            provider_specific_header, custom_llm_provider
+        )
+
+        expected = {"Authorization": "Bearer token123", "Custom-Header": "value"}
+        assert result == expected
+
+    def test_get_provider_specific_headers_no_match_or_none(self):
+        """Test that the method returns an empty dict when the provider doesn't match or the header is None."""
+        # Test case 1: provider doesn't match
+        provider_specific_header: ProviderSpecificHeader = {
+            "custom_llm_provider": "anthropic",
+            "extra_headers": {"Authorization": "Bearer token123"}
+        }
+        custom_llm_provider = "openai"
+
+        result = ProviderSpecificHeaderUtils.get_provider_specific_headers(
+            provider_specific_header, custom_llm_provider
+        )
+        assert result == {}
+
+        # Test case 2: provider_specific_header is None
+        result = ProviderSpecificHeaderUtils.get_provider_specific_headers(
+            None, "openai"
+        )
+        assert result == {}
