11import hashlib
22import json
33import os
4+ import urllib .parse
45from datetime import datetime
56from typing import (
67 TYPE_CHECKING ,
@@ -331,16 +332,61 @@ def get_bedrock_invoke_provider(
331332 return provider
332333 return None
333334
@staticmethod
def get_bedrock_model_id(
    optional_params: dict,
    provider: Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL],
    model: str,
) -> str:
    """
    Resolve the model ID string to use for a Bedrock invoke call.

    An explicit `model_id` entry in `optional_params` wins over `model`. It is
    removed from `optional_params` (the caller's dict is mutated) and
    percent-encoded. When no such entry exists, `model` is used as-is.

    The first `invoke/` occurrence is then dropped, and for the `llama` and
    `deepseek_r1` providers a `<provider>/` spec marker, if present, is
    stripped and the remainder is percent-encoded.

    Args:
        optional_params (dict): Request params; `model_id` is popped from it.
        provider: The bedrock invoke provider, or None.
        model (str): The model string to fall back to.
    Returns:
        str: The resolved model ID.
    """
    explicit_id = optional_params.pop("model_id", None)
    if explicit_id is None:
        resolved = model
    else:
        resolved = BaseAWSLLM.encode_model_id(model_id=explicit_id)

    # Only the first "invoke/" marker is removed; later occurrences are kept.
    resolved = resolved.replace("invoke/", "", 1)

    # Providers whose name can also appear as a "<spec>/" marker in the model string.
    for spec in ("llama", "deepseek_r1"):
        if provider == spec and spec + "/" in resolved:
            return BaseAWSLLM._get_model_id_from_model_with_spec(
                resolved, spec=spec
            )
    return resolved
357+
@staticmethod
def _get_model_id_from_model_with_spec(
    model: str,
    spec: str,
) -> str:
    """
    Strip the `<spec>/` marker from the model string and encode what is left.

    The spec (e.g. `llama`) only names the format to follow for custom bedrock
    models, so it is not part of the model ID itself. Every occurrence of
    `<spec>/` is removed before encoding.

    Args:
        model (str): The model string containing the spec marker.
        spec (str): The spec name, without the trailing slash.
    Returns:
        str: The percent-encoded model ID with the spec marker removed.
    """
    spec_marker = spec + "/"
    bare_model_id = model.replace(spec_marker, "")
    return BaseAWSLLM.encode_model_id(model_id=bare_model_id)
368+
@staticmethod
def encode_model_id(model_id: str) -> str:
    """
    Percent-encode the model ID, treating no character as safe.

    With `safe=""`, reserved characters such as `/` and `:` are escaped too
    (e.g. `/` becomes `%2F`).
    NOTE(review): this performs a single `quote` pass; the "double-encoded"
    format is presumably reached when the URL is encoded again downstream -
    confirm against the caller.

    Args:
        model_id (str): The model ID to encode.
    Returns:
        str: The percent-encoded model ID.
    """
    encoded_id = urllib.parse.quote(model_id, safe="")
    return encoded_id
379+
334380 @staticmethod
335381 def get_bedrock_embedding_provider (
336382 model : str ,
337383 ) -> Optional [BEDROCK_EMBEDDING_PROVIDERS_LITERAL ]:
338384 """
339385 Helper function to get the bedrock embedding provider from the model
340-
386+
341387 Handles scenarios like:
342388 1. model=cohere.embed-english-v3:0 -> Returns `cohere`
343- 2. model=amazon.titan-embed-text-v1 -> Returns `amazon`
389+ 2. model=amazon.titan-embed-text-v1 -> Returns `amazon`
344390 3. model=us.twelvelabs.marengo-embed-2-7-v1:0 -> Returns `twelvelabs`
345391 4. model=twelvelabs.marengo-embed-2-7-v1:0 -> Returns `twelvelabs`
346392 """
@@ -349,20 +395,24 @@ def get_bedrock_embedding_provider(
349395 parts = model .split ("." )
350396 # Check if the second part (after potential region) is a known provider
351397 if len (parts ) >= 2 :
352- potential_provider = parts [1 ] # e.g., "twelvelabs" from "us.twelvelabs.marengo-embed-2-7-v1:0"
398+ potential_provider = parts [
399+ 1
400+ ] # e.g., "twelvelabs" from "us.twelvelabs.marengo-embed-2-7-v1:0"
353401 if potential_provider in get_args (BEDROCK_EMBEDDING_PROVIDERS_LITERAL ):
354402 return cast (BEDROCK_EMBEDDING_PROVIDERS_LITERAL , potential_provider )
355-
403+
356404 # Check if the first part is a known provider (standard format)
357- potential_provider = parts [0 ] # e.g., "cohere" from "cohere.embed-english-v3:0"
405+ potential_provider = parts [
406+ 0
407+ ] # e.g., "cohere" from "cohere.embed-english-v3:0"
358408 if potential_provider in get_args (BEDROCK_EMBEDDING_PROVIDERS_LITERAL ):
359409 return cast (BEDROCK_EMBEDDING_PROVIDERS_LITERAL , potential_provider )
360-
410+
361411 # Fallback: check if any provider name appears in the model string
362412 for provider in get_args (BEDROCK_EMBEDDING_PROVIDERS_LITERAL ):
363413 if provider in model :
364414 return cast (BEDROCK_EMBEDDING_PROVIDERS_LITERAL , provider )
365-
415+
366416 return None
367417
368418 def _get_aws_region_name (
@@ -984,20 +1034,23 @@ def get_request_headers(
9841034 raise ImportError (
9851035 "Missing boto3 to call bedrock. Run 'pip install boto3'."
9861036 )
987-
1037+
9881038 # Filter headers for AWS signature calculation
9891039 # AWS SigV4 only includes specific headers in signature calculation
9901040 aws_signature_headers = self ._filter_headers_for_aws_signature (headers )
9911041 sigv4 = SigV4Auth (credentials , "bedrock" , aws_region_name )
9921042 request = AWSRequest (
993- method = "POST" , url = endpoint_url , data = data , headers = aws_signature_headers
1043+ method = "POST" ,
1044+ url = endpoint_url ,
1045+ data = data ,
1046+ headers = aws_signature_headers ,
9941047 )
9951048 sigv4 .add_auth (request )
996-
1049+
9971050 # Add back all original headers (including forwarded ones) after signature calculation
9981051 for header_name , header_value in headers .items ():
9991052 request .headers [header_name ] = header_value
1000-
1053+
10011054 if (
10021055 extra_headers is not None and "Authorization" in extra_headers
10031056 ): # prevent sigv4 from overwriting the auth header
@@ -1013,16 +1066,27 @@ def _filter_headers_for_aws_signature(self, headers: dict) -> dict:
10131066 """
10141067 aws_signature_headers = {}
10151068 aws_headers = {
1016- 'host' , 'content-type' , 'date' , 'x-amz-date' , 'x-amz-security-token' ,
1017- 'x-amz-content-sha256' , 'x-amz-algorithm' , 'x-amz-credential' ,
1018- 'x-amz-signedheaders' , 'x-amz-signature'
1069+ "host" ,
1070+ "content-type" ,
1071+ "date" ,
1072+ "x-amz-date" ,
1073+ "x-amz-security-token" ,
1074+ "x-amz-content-sha256" ,
1075+ "x-amz-algorithm" ,
1076+ "x-amz-credential" ,
1077+ "x-amz-signedheaders" ,
1078+ "x-amz-signature" ,
10191079 }
1020-
1080+
10211081 for header_name , header_value in headers .items ():
10221082 header_lower = header_name .lower ()
1023- if header_lower in aws_headers or header_lower .startswith ('x-amz-' ) or header_lower .startswith ('x-amzn-' ):
1083+ if (
1084+ header_lower in aws_headers
1085+ or header_lower .startswith ("x-amz-" )
1086+ or header_lower .startswith ("x-amzn-" )
1087+ ):
10241088 aws_signature_headers [header_name ] = header_value
1025-
1089+
10261090 return aws_signature_headers
10271091
10281092 def _sign_request (
0 commit comments