Skip to content

Commit 046b7ef

Browse files
authored
Make Bedrock image generation more consistent (#17021)
1 parent cfd35d3 commit 046b7ef

File tree

8 files changed

+164
-124
lines changed

8 files changed

+164
-124
lines changed

litellm/llms/bedrock/image/amazon_nova_canvas_transformation.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
from openai.types.image import Image
55

6+
from litellm import get_model_info
67
from litellm.types.llms.bedrock import (
78
AmazonNovaCanvasColorGuidedGenerationParams,
89
AmazonNovaCanvasColorGuidedRequest,
@@ -197,3 +198,22 @@ def transform_response_dict_to_openai_response(
197198

198199
model_response.data = openai_images
199200
return model_response
201+
202+
@classmethod
203+
def cost_calculator(
204+
cls,
205+
model: str,
206+
image_response: ImageResponse,
207+
size: Optional[str] = None,
208+
optional_params: Optional[dict] = None,
209+
) -> float:
210+
model_info = get_model_info(
211+
model=model,
212+
custom_llm_provider="bedrock",
213+
)
214+
215+
output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0
216+
num_images: int = 0
217+
if image_response.data:
218+
num_images = len(image_response.data)
219+
return output_cost_per_image * num_images

litellm/llms/bedrock/image/amazon_stability1_transformation.py

Lines changed: 59 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,11 @@
1+
import copy
2+
import os
13
import types
24
from typing import List, Optional
35

46
from openai.types.image import Image
57

8+
from litellm import get_model_info
69
from litellm.types.utils import ImageResponse
710

811

@@ -90,6 +93,31 @@ def map_openai_params(
9093

9194
return optional_params
9295

96+
@classmethod
97+
def transform_request_body(
98+
cls,
99+
text: str,
100+
optional_params: dict,
101+
) -> dict:
102+
inference_params = copy.deepcopy(optional_params)
103+
inference_params.pop(
104+
"user", None
105+
) # make sure user is not passed in for bedrock call
106+
107+
prompt = text.replace(os.linesep, " ")
108+
## LOAD CONFIG
109+
config = cls.get_config()
110+
for k, v in config.items():
111+
if (
112+
k not in inference_params
113+
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
114+
inference_params[k] = v
115+
116+
return {
117+
"text_prompts": [{"text": prompt, "weight": 1}],
118+
**inference_params,
119+
}
120+
93121
@classmethod
94122
def transform_response_dict_to_openai_response(
95123
cls, model_response: ImageResponse, response_dict: dict
@@ -102,3 +130,34 @@ def transform_response_dict_to_openai_response(
102130
model_response.data = image_list
103131

104132
return model_response
133+
134+
@classmethod
135+
def cost_calculator(
136+
cls,
137+
model: str,
138+
image_response: ImageResponse,
139+
size: Optional[str] = None,
140+
optional_params: Optional[dict] = None,
141+
) -> float:
142+
optional_params = optional_params or {}
143+
144+
# see model_prices_and_context_window.json for details on how steps is used
145+
# Reference pricing by steps for stability 1: https://aws.amazon.com/bedrock/pricing/
146+
_steps = optional_params.get("steps", 50)
147+
steps = "max-steps" if _steps > 50 else "50-steps"
148+
149+
# size is stored in model_prices_and_context_window.json as 1024-x-1024
150+
# current size has 1024x1024
151+
size = size or "1024-x-1024"
152+
model = f"{size}/{steps}/{model}"
153+
154+
model_info = get_model_info(
155+
model=model,
156+
custom_llm_provider="bedrock",
157+
)
158+
159+
output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0
160+
num_images: int = 0
161+
if image_response.data:
162+
num_images = len(image_response.data)
163+
return output_cost_per_image * num_images

litellm/llms/bedrock/image/amazon_stability3_transformation.py

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33

44
from openai.types.image import Image
55

6+
from litellm import get_model_info
7+
from litellm.llms.bedrock.common_utils import BedrockError
68
from litellm.types.llms.bedrock import (
79
AmazonStability3TextToImageRequest,
810
AmazonStability3TextToImageResponse,
@@ -66,12 +68,12 @@ def _is_stability_3_model(cls, model: Optional[str] = None) -> bool:
6668

6769
@classmethod
6870
def transform_request_body(
69-
cls, prompt: str, optional_params: dict
71+
cls, text: str, optional_params: dict
7072
) -> AmazonStability3TextToImageRequest:
7173
"""
7274
Transform the request body for the Stability 3 models
7375
"""
74-
data = AmazonStability3TextToImageRequest(prompt=prompt, **optional_params)
76+
data = AmazonStability3TextToImageRequest(prompt=text, **optional_params)
7577
return data
7678

7779
@classmethod
@@ -92,9 +94,34 @@ def transform_response_dict_to_openai_response(
9294
"""
9395

9496
stability_3_response = AmazonStability3TextToImageResponse(**response_dict)
97+
98+
finish_reasons = stability_3_response.get("finish_reasons", [])
99+
finish_reasons = [reason for reason in finish_reasons if reason]
100+
if len(finish_reasons) > 0:
101+
raise BedrockError(status_code=400, message="; ".join(finish_reasons))
102+
95103
openai_images: List[Image] = []
96104
for _img in stability_3_response.get("images", []):
97105
openai_images.append(Image(b64_json=_img))
98106

99107
model_response.data = openai_images
100108
return model_response
109+
110+
@classmethod
111+
def cost_calculator(
112+
cls,
113+
model: str,
114+
image_response: ImageResponse,
115+
size: Optional[str] = None,
116+
optional_params: Optional[dict] = None,
117+
) -> float:
118+
model_info = get_model_info(
119+
model=model,
120+
custom_llm_provider="bedrock",
121+
)
122+
123+
output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0
124+
num_images: int = 0
125+
if image_response.data:
126+
num_images = len(image_response.data)
127+
return output_cost_per_image * num_images

litellm/llms/bedrock/image/amazon_titan_transformation.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -103,16 +103,16 @@ def map_openai_params(
103103
return optional_params
104104

105105
@classmethod
106-
def _transform_request(
106+
def transform_request_body(
107107
cls,
108-
input: str,
108+
text: str,
109109
optional_params: dict,
110110
) -> AmazonTitanImageGenerationRequestBody:
111111
from typing import Any, Dict
112112

113113
image_generation_config = optional_params.pop("imageGenerationConfig", {})
114114
negative_text = optional_params.pop("negativeText", None)
115-
text_to_image_params: Dict[str, Any] = {"text": input}
115+
text_to_image_params: Dict[str, Any] = {"text": text}
116116
if negative_text:
117117
text_to_image_params["negativeText"] = negative_text
118118
task_type = optional_params.pop("taskType", "TEXT_IMAGE")
Lines changed: 6 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,6 @@
11
from typing import Optional
22

3-
import litellm
4-
from litellm.llms.bedrock.image.amazon_titan_transformation import (
5-
AmazonTitanImageGenerationConfig,
6-
)
3+
from litellm.llms.bedrock.image.image_handler import BedrockImageGeneration
74
from litellm.types.utils import ImageResponse
85

96

@@ -18,36 +15,10 @@ def cost_calculator(
1815
1916
Handles both Stability 1 and Stability 3 models
2017
"""
21-
if litellm.AmazonStability3Config()._is_stability_3_model(model=model):
22-
pass
23-
elif AmazonTitanImageGenerationConfig._is_titan_model(model=model):
24-
return AmazonTitanImageGenerationConfig.cost_calculator(
25-
model=model,
26-
image_response=image_response,
27-
size=size,
28-
optional_params=optional_params,
29-
)
30-
else:
31-
# Stability 1 models
32-
optional_params = optional_params or {}
33-
34-
# see model_prices_and_context_window.json for details on how steps is used
35-
# Reference pricing by steps for stability 1: https://aws.amazon.com/bedrock/pricing/
36-
_steps = optional_params.get("steps", 50)
37-
steps = "max-steps" if _steps > 50 else "50-steps"
38-
39-
# size is stored in model_prices_and_context_window.json as 1024-x-1024
40-
# current size has 1024x1024
41-
size = size or "1024-x-1024"
42-
model = f"{size}/{steps}/{model}"
43-
44-
_model_info = litellm.get_model_info(
18+
config_class = BedrockImageGeneration.get_config_class(model=model)
19+
return config_class.cost_calculator(
4520
model=model,
46-
custom_llm_provider="bedrock",
21+
image_response=image_response,
22+
size=size,
23+
optional_params=optional_params,
4724
)
48-
49-
output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
50-
num_images: int = 0
51-
if image_response.data:
52-
num_images = len(image_response.data)
53-
return output_cost_per_image * num_images

litellm/llms/bedrock/image/image_handler.py

Lines changed: 23 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,10 @@
1-
import copy
21
import json
3-
import os
42
from typing import TYPE_CHECKING, Any, Optional, Union
53

64
import httpx
75
from pydantic import BaseModel
86

97
import litellm
10-
from litellm import BEDROCK_INVOKE_PROVIDERS_LITERAL
118
from litellm._logging import verbose_logger
129
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
1310
from litellm.llms.bedrock.image.amazon_nova_canvas_transformation import (
@@ -47,11 +44,30 @@ class BedrockImagePreparedRequest(BaseModel):
4744
data: dict
4845

4946

47+
BedrockImageConfigClass = Union[
48+
type[AmazonTitanImageGenerationConfig],
49+
type[AmazonNovaCanvasConfig],
50+
type[AmazonStability3Config],
51+
type[litellm.AmazonStabilityConfig],
52+
]
53+
54+
5055
class BedrockImageGeneration(BaseAWSLLM):
5156
"""
5257
Bedrock Image Generation handler
5358
"""
5459

60+
@classmethod
61+
def get_config_class(cls, model: str | None) -> BedrockImageConfigClass:
62+
if AmazonTitanImageGenerationConfig._is_titan_model(model):
63+
return AmazonTitanImageGenerationConfig
64+
elif AmazonNovaCanvasConfig._is_nova_model(model):
65+
return AmazonNovaCanvasConfig
66+
elif AmazonStability3Config._is_stability_3_model(model):
67+
return AmazonStability3Config
68+
else:
69+
return litellm.AmazonStabilityConfig
70+
5571
def image_generation(
5672
self,
5773
model: str,
@@ -202,7 +218,6 @@ def _prepare_request(
202218
model=model,
203219
prompt=prompt,
204220
optional_params=optional_params,
205-
bedrock_provider=bedrock_provider,
206221
)
207222

208223
# Make POST Request
@@ -241,7 +256,6 @@ def _prepare_request(
241256
def _get_request_body(
242257
self,
243258
model: str,
244-
bedrock_provider: Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL],
245259
prompt: str,
246260
optional_params: dict,
247261
) -> dict:
@@ -253,49 +267,9 @@ def _get_request_body(
253267
Returns:
254268
dict: The request body to use for the Bedrock Image Generation API
255269
"""
256-
if bedrock_provider == "amazon" or bedrock_provider == "nova":
257-
# Handle Amazon Nova Canvas models
258-
provider = "amazon"
259-
elif bedrock_provider == "stability":
260-
provider = "stability"
261-
else:
262-
# Fallback to original logic for backward compatibility
263-
provider = model.split(".")[0]
264-
inference_params = copy.deepcopy(optional_params)
265-
inference_params.pop(
266-
"user", None
267-
) # make sure user is not passed in for bedrock call
268-
data = {}
269-
if provider == "stability":
270-
if litellm.AmazonStability3Config._is_stability_3_model(model):
271-
request_body = litellm.AmazonStability3Config.transform_request_body(
272-
prompt=prompt, optional_params=optional_params
273-
)
274-
return dict(request_body)
275-
else:
276-
prompt = prompt.replace(os.linesep, " ")
277-
## LOAD CONFIG
278-
config = litellm.AmazonStabilityConfig.get_config()
279-
for k, v in config.items():
280-
if (
281-
k not in inference_params
282-
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
283-
inference_params[k] = v
284-
data = {
285-
"text_prompts": [{"text": prompt, "weight": 1}],
286-
**inference_params,
287-
}
288-
elif provider == "amazon":
289-
return dict(
290-
litellm.AmazonNovaCanvasConfig.transform_request_body(
291-
text=prompt, optional_params=optional_params
292-
)
293-
)
294-
else:
295-
raise BedrockError(
296-
status_code=422, message=f"Unsupported model={model}, passed in"
297-
)
298-
return data
270+
config_class = self.get_config_class(model=model)
271+
request_body = config_class.transform_request_body(text=prompt, optional_params=optional_params)
272+
return dict(request_body)
299273

300274
def _transform_response_dict_to_openai_response(
301275
self,
@@ -323,20 +297,7 @@ def _transform_response_dict_to_openai_response(
323297
if response_dict is None:
324298
raise ValueError("Error in response object format, got None")
325299

326-
config_class: Union[
327-
type[AmazonTitanImageGenerationConfig],
328-
type[AmazonNovaCanvasConfig],
329-
type[AmazonStability3Config],
330-
type[litellm.AmazonStabilityConfig],
331-
]
332-
if AmazonTitanImageGenerationConfig._is_titan_model(model=model):
333-
config_class = AmazonTitanImageGenerationConfig
334-
elif AmazonNovaCanvasConfig._is_nova_model(model=model):
335-
config_class = AmazonNovaCanvasConfig
336-
elif AmazonStability3Config._is_stability_3_model(model=model):
337-
config_class = AmazonStability3Config
338-
else:
339-
config_class = litellm.AmazonStabilityConfig
300+
config_class = self.get_config_class(model=model)
340301

341302
config_class.transform_response_dict_to_openai_response(
342303
model_response=model_response,

litellm/utils.py

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2631,16 +2631,7 @@ def _check_valid_arg(supported_params):
26312631
):
26322632
optional_params = non_default_params
26332633
elif custom_llm_provider == "bedrock":
2634-
# use stability3 config class if model is a stability3 model
2635-
config_class = (
2636-
litellm.AmazonStability3Config
2637-
if litellm.AmazonStability3Config._is_stability_3_model(model=model)
2638-
else (
2639-
litellm.AmazonNovaCanvasConfig
2640-
if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
2641-
else litellm.AmazonStabilityConfig
2642-
)
2643-
)
2634+
config_class = litellm.BedrockImageGeneration.get_config_class(model=model)
26442635
supported_params = config_class.get_supported_openai_params(model=model)
26452636
_check_valid_arg(supported_params=supported_params)
26462637
optional_params = config_class.map_openai_params(

0 commit comments

Comments (0)