Commit c7d0af0

Merge pull request #2426 from BerriAI/litellm_whisper_cost_tracking
feat: add cost tracking + caching for `/audio/transcription` calls
2 parents 34ad958 + 1d15dde

File tree

11 files changed (+247, -41 lines)

litellm/caching.py

Lines changed: 73 additions & 9 deletions
```diff
@@ -10,7 +10,7 @@
 import litellm
 import time, logging, asyncio
 import json, traceback, ast, hashlib
-from typing import Optional, Literal, List, Union, Any
+from typing import Optional, Literal, List, Union, Any, BinaryIO
 from openai._models import BaseModel as OpenAIObject
 from litellm._logging import verbose_logger
```

```diff
@@ -765,8 +765,24 @@ def __init__(
         password: Optional[str] = None,
         similarity_threshold: Optional[float] = None,
         supported_call_types: Optional[
-            List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-        ] = ["completion", "acompletion", "embedding", "aembedding"],
+            List[
+                Literal[
+                    "completion",
+                    "acompletion",
+                    "embedding",
+                    "aembedding",
+                    "atranscription",
+                    "transcription",
+                ]
+            ]
+        ] = [
+            "completion",
+            "acompletion",
+            "embedding",
+            "aembedding",
+            "atranscription",
+            "transcription",
+        ],
         # s3 Bucket, boto3 configuration
         s3_bucket_name: Optional[str] = None,
         s3_region_name: Optional[str] = None,
```

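The expanded `supported_call_types` means transcription results can now be cached alongside chat and embedding results. A minimal opt-in sketch (not part of this diff; the in-memory `local` cache type keeps it self-contained):

```python
import litellm
from litellm.caching import Cache

# Sketch: cache transcription results alongside completion results.
# supported_call_types now accepts the two new transcription literals.
litellm.cache = Cache(
    type="local",
    supported_call_types=["completion", "acompletion", "atranscription", "transcription"],
)
```
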
```diff
@@ -881,9 +897,14 @@ def get_cache_key(self, *args, **kwargs):
             "input",
             "encoding_format",
         ]  # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
-
+        transcription_only_kwargs = [
+            "file",
+            "language",
+        ]
         # combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
-        combined_kwargs = completion_kwargs + embedding_only_kwargs
+        combined_kwargs = (
+            completion_kwargs + embedding_only_kwargs + transcription_only_kwargs
+        )
         for param in combined_kwargs:
             # ignore litellm params here
             if param in kwargs:
```

```diff
@@ -915,6 +936,17 @@ def get_cache_key(self, *args, **kwargs):
                     param_value = (
                         caching_group or model_group or kwargs[param]
                     )  # use caching_group, if set then model_group if it exists, else use kwargs["model"]
+                elif param == "file":
+                    metadata_file_name = kwargs.get("metadata", {}).get(
+                        "file_name", None
+                    )
+                    litellm_params_file_name = kwargs.get("litellm_params", {}).get(
+                        "file_name", None
+                    )
+                    if metadata_file_name is not None:
+                        param_value = metadata_file_name
+                    elif litellm_params_file_name is not None:
+                        param_value = litellm_params_file_name
                 else:
                     if kwargs[param] is None:
                         continue  # ignore None params
```

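A file handle has no stable textual representation, so the `file` branch above keys on a `file_name` carried in `metadata` (or `litellm_params`) rather than the file object itself. A rough illustration, assuming `get_cache_key` is invoked with the same kwargs the transcription path passes internally:

```python
import io
import litellm
from litellm.caching import Cache

litellm.cache = Cache(type="local")

# Illustration: two requests naming the same file hash to the same cache key,
# even though the BinaryIO objects themselves differ.
key_a = litellm.cache.get_cache_key(
    model="whisper-1", file=io.BytesIO(b"fake-audio"), metadata={"file_name": "meeting.wav"}
)
key_b = litellm.cache.get_cache_key(
    model="whisper-1", file=io.BytesIO(b"fake-audio"), metadata={"file_name": "meeting.wav"}
)
assert key_a == key_b
```
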
```diff
@@ -1144,8 +1176,24 @@ def enable_cache(
     port: Optional[str] = None,
     password: Optional[str] = None,
     supported_call_types: Optional[
-        List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-    ] = ["completion", "acompletion", "embedding", "aembedding"],
+        List[
+            Literal[
+                "completion",
+                "acompletion",
+                "embedding",
+                "aembedding",
+                "atranscription",
+                "transcription",
+            ]
+        ]
+    ] = [
+        "completion",
+        "acompletion",
+        "embedding",
+        "aembedding",
+        "atranscription",
+        "transcription",
+    ],
     **kwargs,
 ):
     """
```

```diff
@@ -1193,8 +1241,24 @@ def update_cache(
     port: Optional[str] = None,
     password: Optional[str] = None,
     supported_call_types: Optional[
-        List[Literal["completion", "acompletion", "embedding", "aembedding"]]
-    ] = ["completion", "acompletion", "embedding", "aembedding"],
+        List[
+            Literal[
+                "completion",
+                "acompletion",
+                "embedding",
+                "aembedding",
+                "atranscription",
+                "transcription",
+            ]
+        ]
+    ] = [
+        "completion",
+        "acompletion",
+        "embedding",
+        "aembedding",
+        "atranscription",
+        "transcription",
+    ],
     **kwargs,
 ):
     """
```

litellm/llms/azure.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -861,7 +861,8 @@ def audio_transcriptions(
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
+            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return final_response

     async def async_audio_transcriptions(
@@ -921,7 +922,8 @@ async def async_audio_transcriptions(
                 },
                 original_response=stringified_response,
             )
-            response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
+            response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return response
         except Exception as e:
             ## LOGGING
```

litellm/llms/openai.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -753,6 +753,7 @@ def image_generation(
             # return response
             return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
         except OpenAIError as e:
+
             exception_mapping_worked = True
             ## LOGGING
             logging_obj.post_call(
@@ -824,7 +825,8 @@ def audio_transcriptions(
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
+            final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
             return final_response

     async def async_audio_transcriptions(
@@ -862,7 +864,8 @@ async def async_audio_transcriptions(
                 additional_args={"complete_input_dict": data},
                 original_response=stringified_response,
             )
-            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="audio_transcription")  # type: ignore
+            hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
+            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
         except Exception as e:
             ## LOGGING
             logging_obj.post_call(
```

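With the provider and model stamped into `_hidden_params`, `litellm.completion_cost()` can price a transcription by audio duration rather than by tokens. A minimal end-to-end sketch (assumes `OPENAI_API_KEY` is set and `audio.mp3` exists locally):

```python
import litellm

# Sketch: the response now carries provider info in _hidden_params, which
# completion_cost uses to look up whisper-1's per-second rate.
with open("audio.mp3", "rb") as f:
    resp = litellm.transcription(model="whisper-1", file=f)

print(resp._hidden_params)  # includes "custom_llm_provider": "openai"
print(litellm.completion_cost(model="whisper-1", completion_response=resp))
```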

litellm/proxy/proxy_server.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -3295,6 +3295,7 @@ async def audio_transcriptions(
             user_api_key_dict, "team_id", None
         )
         data["metadata"]["endpoint"] = str(request.url)
+        data["metadata"]["file_name"] = file.filename

         ### TEAM-SPECIFIC PARAMS ###
         if user_api_key_dict.team_id is not None:
@@ -3329,7 +3330,7 @@ async def audio_transcriptions(
         data = await proxy_logging_obj.pre_call_hook(
             user_api_key_dict=user_api_key_dict,
             data=data,
-            call_type="moderation",
+            call_type="audio_transcription",
        )

        ## ROUTE TO CORRECT ENDPOINT ##
```

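On the proxy side, `file.filename` from the multipart upload now lands in `data["metadata"]["file_name"]`, feeding both the cache key and spend logs. A hedged sketch of exercising the endpoint with the OpenAI SDK (the proxy URL and key below are placeholders):

```python
from openai import OpenAI

# Sketch: route a transcription through a LiteLLM proxy. The proxy records
# the uploaded file's name as metadata["file_name"].
client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

with open("meeting.wav", "rb") as audio:
    transcript = client.audio.transcriptions.create(model="whisper-1", file=audio)

print(transcript.text)
```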

litellm/proxy/utils.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -96,7 +96,11 @@ async def pre_call_hook(
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
         call_type: Literal[
-            "completion", "embeddings", "image_generation", "moderation"
+            "completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
         ],
     ):
         """
```


litellm/tests/test_completion_cost.py

Lines changed: 60 additions & 1 deletion
```diff
@@ -6,7 +6,12 @@
 )  # Adds the parent directory to the system path
 import time
 import litellm
-from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models
+from litellm import (
+    get_max_tokens,
+    model_cost,
+    open_ai_chat_completion_models,
+    TranscriptionResponse,
+)
 import pytest


@@ -238,3 +243,57 @@ def test_cost_bedrock_pricing_actual_calls():
         messages=[{"role": "user", "content": "Hey, how's it going?"}],
     )
     assert cost > 0
+
+
+def test_whisper_openai():
+    litellm.set_verbose = True
+    transcription = TranscriptionResponse(
+        text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure."
+    )
+    transcription._hidden_params = {
+        "model": "whisper-1",
+        "custom_llm_provider": "openai",
+        "optional_params": {},
+        "model_id": None,
+    }
+    _total_time_in_seconds = 3
+
+    transcription._response_ms = _total_time_in_seconds * 1000
+    cost = litellm.completion_cost(model="whisper-1", completion_response=transcription)
+
+    print(f"cost: {cost}")
+    print(f"whisper dict: {litellm.model_cost['whisper-1']}")
+    expected_cost = round(
+        litellm.model_cost["whisper-1"]["output_cost_per_second"]
+        * _total_time_in_seconds,
+        5,
+    )
+    assert cost == expected_cost
+
+
+def test_whisper_azure():
+    litellm.set_verbose = True
+    transcription = TranscriptionResponse(
+        text="Four score and seven years ago, our fathers brought forth on this continent a new nation, conceived in liberty and dedicated to the proposition that all men are created equal. Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure."
+    )
+    transcription._hidden_params = {
+        "model": "whisper-1",
+        "custom_llm_provider": "azure",
+        "optional_params": {},
+        "model_id": None,
+    }
+    _total_time_in_seconds = 3
+
+    transcription._response_ms = _total_time_in_seconds * 1000
+    cost = litellm.completion_cost(
+        model="azure/azure-whisper", completion_response=transcription
+    )
+
+    print(f"cost: {cost}")
+    print(f"whisper dict: {litellm.model_cost['whisper-1']}")
+    expected_cost = round(
+        litellm.model_cost["whisper-1"]["output_cost_per_second"]
+        * _total_time_in_seconds,
+        5,
+    )
+    assert cost == expected_cost
```

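Both tests assert against the model-cost map rather than a hard-coded figure. As a sanity check on the arithmetic: if `whisper-1`'s `output_cost_per_second` were, say, 0.0001, a 3-second clip would come out to `round(0.0001 * 3, 5) = 0.0003`.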

litellm/tests/test_custom_callback_input.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -973,6 +973,7 @@ def test_image_generation_openai():

     print(f"customHandler_success.errors: {customHandler_success.errors}")
     print(f"customHandler_success.states: {customHandler_success.states}")
+    time.sleep(2)
     assert len(customHandler_success.errors) == 0
     assert len(customHandler_success.states) == 3  # pre, post, success
     # test failure callback
```

litellm/tests/test_custom_logger.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -100,7 +100,7 @@ async def async_test_logging_fn(self, kwargs, completion_obj, start_time, end_ti
 def test_async_chat_openai_stream():
     try:
         tmp_function = TmpFunction()
-        # litellm.set_verbose = True
+        litellm.set_verbose = True
         litellm.success_callback = [tmp_function.async_test_logging_fn]
         complete_streaming_response = ""

```

litellm/tests/test_proxy_server.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -336,6 +336,8 @@ def test_load_router_config():
         "acompletion",
         "embedding",
         "aembedding",
+        "atranscription",
+        "transcription",
     ]  # init with all call types

     litellm.disable_cache()
```
