Skip to content

Commit ad6a0f4

Browse files
authored
Update perplexity cost tracking (#15743)
* Update perplexity cost tracking * fix lint errors * fix code * fix tests in perplexity * fix test related to api call * fix exception test
1 parent 396ab80 commit ad6a0f4

File tree

8 files changed

+599
-79
lines changed

8 files changed

+599
-79
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Perplexity chat completion transformations."""

litellm/llms/perplexity/chat/transformation.py

Lines changed: 103 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,32 @@
1-
"""
2-
Translate from OpenAI's `/v1/chat/completions` to Perplexity's `/v1/chat/completions`
3-
"""
1+
"""Translate from OpenAI's `/v1/chat/completions` to Perplexity's `/v1/chat/completions`."""
42

5-
from typing import Any, List, Optional, Tuple
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
66

7-
import httpx
87
import litellm
98
from litellm._logging import verbose_logger
10-
from litellm.secret_managers.main import get_secret_str
11-
from litellm.types.llms.openai import AllMessageValues
12-
from litellm.types.utils import Usage, PromptTokensDetailsWrapper
13-
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
149
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
15-
from litellm.types.utils import ModelResponse
16-
from litellm.types.llms.openai import ChatCompletionAnnotation
17-
from litellm.types.llms.openai import ChatCompletionAnnotationURLCitation
10+
from litellm.secret_managers.main import get_secret_str
11+
from litellm.types.utils import ModelResponse, PromptTokensDetailsWrapper, Usage
12+
13+
if TYPE_CHECKING:
14+
import httpx
15+
16+
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
17+
from litellm.types.llms.openai import (
18+
AllMessageValues,
19+
ChatCompletionAnnotation,
20+
ChatCompletionAnnotationURLCitation,
21+
)
1822

1923

2024
class PerplexityChatConfig(OpenAIGPTConfig):
25+
"""Configuration for Perplexity chat completions."""
26+
2127
@property
22-
def custom_llm_provider(self) -> Optional[str]:
28+
def custom_llm_provider(self) -> str | None:
29+
"""Return the custom LLM provider name."""
2330
return "perplexity"
2431

2532
def _get_openai_compatible_provider_info(
@@ -33,6 +40,38 @@ def _get_openai_compatible_provider_info(
3340
)
3441
return api_base, dynamic_api_key
3542

43+
def validate_environment(
44+
self,
45+
headers: dict,
46+
model: str,
47+
messages: list,
48+
optional_params: dict,
49+
litellm_params: dict,
50+
api_key: Optional[str] = None,
51+
api_base: Optional[str] = None,
52+
) -> dict:
53+
"""Validate Perplexity environment and set headers."""
54+
# Get API key from environment if not provided
55+
if api_key is None:
56+
_, api_key = self._get_openai_compatible_provider_info(
57+
api_base=api_base, api_key=api_key
58+
)
59+
60+
# Validate API key is present
61+
if api_key is None:
62+
raise ValueError(
63+
"The api_key client option must be set either by passing api_key to the client or by setting the PERPLEXITY_API_KEY environment variable"
64+
)
65+
66+
# Set authorization header
67+
headers["Authorization"] = f"Bearer {api_key}"
68+
69+
# Ensure Content-Type is set to application/json
70+
if "content-type" not in headers and "Content-Type" not in headers:
71+
headers["Content-Type"] = "application/json"
72+
73+
return headers
74+
3675
def get_supported_openai_params(self, model: str) -> list:
3776
"""
3877
Perplexity supports a subset of OpenAI params
@@ -72,7 +111,8 @@ def get_supported_openai_params(self, model: str) -> list:
72111

73112
return base_openai_params
74113

75-
def transform_response(
114+
115+
def transform_response( # noqa: PLR0913
76116
self,
77117
model: str,
78118
raw_response: httpx.Response,
@@ -82,10 +122,11 @@ def transform_response(
82122
messages: List[AllMessageValues],
83123
optional_params: dict,
84124
litellm_params: dict,
85-
encoding: Any,
125+
encoding: Any,
86126
api_key: Optional[str] = None,
87-
json_mode: Optional[bool] = None,
127+
json_mode: Optional[bool] = None,
88128
) -> ModelResponse:
129+
"""Transform Perplexity response to standard format."""
89130
# Call the parent transform_response first to handle the standard transformation
90131
model_response = super().transform_response(
91132
model=model,
@@ -104,28 +145,29 @@ def transform_response(
104145
# Extract and enhance usage with Perplexity-specific fields
105146
try:
106147
raw_response_json = raw_response.json()
148+
self.add_cost_to_usage(model_response, raw_response_json)
107149
self._enhance_usage_with_perplexity_fields(
108-
model_response, raw_response_json
150+
model_response, raw_response_json,
109151
)
110152
self._add_citations_as_annotations(model_response, raw_response_json)
111-
except Exception as e:
153+
except (ValueError, TypeError, KeyError) as e:
112154
verbose_logger.debug(f"Error extracting Perplexity-specific usage fields: {e}")
113155

114156
return model_response
115157

116-
def _enhance_usage_with_perplexity_fields(
117-
self, model_response: ModelResponse, raw_response_json: dict
158+
def _enhance_usage_with_perplexity_fields(
159+
self, model_response: ModelResponse, raw_response_json: dict,
118160
) -> None:
119-
"""
120-
Extract citation tokens and search queries from Perplexity API response
121-
and add them to the usage object using standard LiteLLM fields.
161+
"""Extract citation tokens and search queries from Perplexity API response.
162+
163+
Add them to the usage object using standard LiteLLM fields.
122164
"""
123165
if not hasattr(model_response, "usage") or model_response.usage is None:
124166
# Create a usage object if it doesn't exist (when usage was None)
125167
model_response.usage = Usage( # type: ignore[attr-defined]
126168
prompt_tokens=0,
127169
completion_tokens=0,
128-
total_tokens=0
170+
total_tokens=0,
129171
)
130172

131173
usage = model_response.usage # type: ignore[attr-defined]
@@ -146,7 +188,7 @@ def _enhance_usage_with_perplexity_fields(
146188
# Extract search queries count from usage or response metadata
147189
# Perplexity might include this in the usage object or as separate metadata
148190
perplexity_usage = raw_response_json.get("usage", {})
149-
191+
150192
# Try to extract search queries from usage field first, then root level
151193
num_search_queries = perplexity_usage.get("num_search_queries")
152194
if num_search_queries is None:
@@ -155,18 +197,18 @@ def _enhance_usage_with_perplexity_fields(
155197
num_search_queries = perplexity_usage.get("search_queries")
156198
if num_search_queries is None:
157199
num_search_queries = raw_response_json.get("search_queries")
158-
200+
159201
# Create or update prompt_tokens_details to include web search requests and citation tokens
160202
if citation_tokens > 0 or (
161203
num_search_queries is not None and num_search_queries > 0
162204
):
163205
if usage.prompt_tokens_details is None:
164206
usage.prompt_tokens_details = PromptTokensDetailsWrapper()
165-
207+
166208
# Store citation tokens count for cost calculation
167209
if citation_tokens > 0:
168-
setattr(usage, "citation_tokens", citation_tokens)
169-
210+
usage.citation_tokens = citation_tokens
211+
170212
# Store search queries count in the standard web_search_requests field
171213
if num_search_queries is not None and num_search_queries > 0:
172214
usage.prompt_tokens_details.web_search_requests = num_search_queries
@@ -248,4 +290,35 @@ def _add_citations_as_annotations(
248290
if citations:
249291
setattr(model_response, "citations", citations)
250292
if search_results:
251-
setattr(model_response, "search_results", search_results)
293+
setattr(model_response, "search_results", search_results)
294+
295+
def add_cost_to_usage(self, model_response: ModelResponse, raw_response_json: dict) -> None:
296+
"""Add the cost to the usage object."""
297+
try:
298+
usage_data = raw_response_json.get("usage")
299+
if usage_data:
300+
# Try different possible cost field locations
301+
response_cost = None
302+
303+
# Check if cost is directly in usage (flat structure)
304+
if "total_cost" in usage_data:
305+
response_cost = usage_data["total_cost"]
306+
# Check if cost is nested (cost.total_cost structure)
307+
elif "cost" in usage_data and isinstance(usage_data["cost"], dict):
308+
response_cost = usage_data["cost"].get("total_cost")
309+
# Check if cost is a simple value
310+
elif "cost" in usage_data:
311+
response_cost = usage_data["cost"]
312+
313+
if response_cost is not None:
314+
# Store cost in hidden params for the cost calculator to use
315+
if not hasattr(model_response, "_hidden_params"):
316+
model_response._hidden_params = {}
317+
if "additional_headers" not in model_response._hidden_params:
318+
model_response._hidden_params["additional_headers"] = {}
319+
model_response._hidden_params["additional_headers"][
320+
"llm_provider-x-litellm-response-cost"
321+
] = float(response_cost)
322+
except (ValueError, TypeError, KeyError) as e:
323+
verbose_logger.debug(f"Error adding cost to usage: {e}")
324+
# If we can't extract cost, continue without it - don't fail the response

litellm/main.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2033,11 +2033,36 @@ def completion( # type: ignore # noqa: PLR0915
20332033
logging.post_call(
20342034
input=messages, api_key=api_key, original_response=response
20352035
)
2036+
elif custom_llm_provider == "perplexity":
2037+
response = base_llm_http_handler.completion(
2038+
model=model,
2039+
messages=messages,
2040+
headers=headers,
2041+
model_response=model_response,
2042+
api_key=api_key,
2043+
api_base=api_base,
2044+
acompletion=acompletion,
2045+
logging_obj=logging,
2046+
optional_params=optional_params,
2047+
litellm_params=litellm_params,
2048+
shared_session=shared_session,
2049+
timeout=timeout,
2050+
client=client,
2051+
custom_llm_provider=custom_llm_provider,
2052+
encoding=encoding,
2053+
stream=stream,
2054+
provider_config=provider_config,
2055+
)
2056+
2057+
## LOGGING - Call after response has been processed by transform_response
2058+
logging.post_call(
2059+
input=messages, api_key=api_key, original_response=response
2060+
)
2061+
20362062
elif (
20372063
model in litellm.open_ai_chat_completion_models
20382064
or custom_llm_provider == "custom_openai"
20392065
or custom_llm_provider == "deepinfra"
2040-
or custom_llm_provider == "perplexity"
20412066
or custom_llm_provider == "nvidia_nim"
20422067
or custom_llm_provider == "cerebras"
20432068
or custom_llm_provider == "baseten"

tests/llm_translation/test_perplexity_reasoning.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,12 @@ def test_perplexity_reasoning_effort_mock_completion(self, model):
6262
"""
6363
Test that reasoning_effort is correctly passed in actual completion call (mocked)
6464
"""
65-
from openai import OpenAI
66-
from openai.types.chat.chat_completion import ChatCompletion
65+
import httpx
6766

6867
litellm.set_verbose = True
6968

7069
# Mock successful response with reasoning content
71-
response_object = {
70+
response_json = {
7271
"id": "cmpl-test",
7372
"object": "chat.completion",
7473
"created": 1677652288,
@@ -94,35 +93,37 @@ def test_perplexity_reasoning_effort_mock_completion(self, model):
9493
},
9594
}
9695

97-
pydantic_obj = ChatCompletion(**response_object)
98-
99-
def _return_pydantic_obj(*args, **kwargs):
100-
new_response = MagicMock()
101-
new_response.headers = {"content-type": "application/json"}
102-
new_response.parse.return_value = pydantic_obj
103-
return new_response
104-
105-
openai_client = OpenAI(api_key="fake-api-key")
96+
def mock_post(*args, **kwargs):
97+
# Create a mock response
98+
mock_response = MagicMock(spec=httpx.Response)
99+
mock_response.status_code = 200
100+
mock_response.headers = {"content-type": "application/json"}
101+
mock_response.json.return_value = response_json
102+
mock_response.text = json.dumps(response_json)
103+
104+
# Store the request data for verification
105+
mock_post.last_request_data = kwargs.get("data")
106+
if isinstance(mock_post.last_request_data, (str, bytes)):
107+
mock_post.last_request_data = json.loads(mock_post.last_request_data)
108+
109+
return mock_response
106110

107-
with patch.object(
108-
openai_client.chat.completions.with_raw_response, "create", side_effect=_return_pydantic_obj
109-
) as mock_client:
111+
# Mock at the HTTP handler level
112+
with patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post", side_effect=mock_post) as mock_http:
110113

111114
response = completion(
112115
model=model,
113116
messages=[{"role": "user", "content": "Hello, please think about this carefully."}],
114117
reasoning_effort="high",
115-
client=openai_client,
118+
api_key="fake-api-key",
116119
)
117120

118121
# Verify the call was made
119-
assert mock_client.called
120-
121-
# Get the request data from the mock call
122-
call_args = mock_client.call_args
123-
request_data = call_args.kwargs
122+
assert mock_http.called
124123

125124
# Verify reasoning_effort was included in the request
125+
request_data = mock_post.last_request_data
126+
assert request_data is not None
126127
assert "reasoning_effort" in request_data
127128
assert request_data["reasoning_effort"] == "high"
128129

tests/local_testing/test_completion.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,9 @@ def test_completion_fireworks_ai_dynamic_params(api_key, api_base):
12441244
# @pytest.mark.skip(reason="this test is flaky")
12451245
def test_completion_perplexity_api():
12461246
try:
1247+
import httpx
1248+
import json
1249+
12471250
response_object = {
12481251
"id": "a8f37485-026e-45da-81a9-cf0184896840",
12491252
"model": "llama-3-sonar-small-32k-online",
@@ -1270,25 +1273,17 @@ def test_completion_perplexity_api():
12701273
],
12711274
}
12721275

1273-
from openai import OpenAI
1274-
from openai.types.chat.chat_completion import ChatCompletion
1275-
1276-
pydantic_obj = ChatCompletion(**response_object)
1277-
1278-
def _return_pydantic_obj(*args, **kwargs):
1279-
new_response = MagicMock()
1280-
new_response.headers = {"hello": "world"}
1281-
1282-
new_response.parse.return_value = pydantic_obj
1283-
return new_response
1284-
1285-
openai_client = OpenAI()
1286-
1287-
with patch.object(
1288-
openai_client.chat.completions.with_raw_response,
1289-
"create",
1290-
side_effect=_return_pydantic_obj,
1291-
) as mock_client:
1276+
def mock_post(*args, **kwargs):
1277+
# Create a mock response
1278+
mock_response = MagicMock(spec=httpx.Response)
1279+
mock_response.status_code = 200
1280+
mock_response.headers = {"content-type": "application/json"}
1281+
mock_response.json.return_value = response_object
1282+
mock_response.text = json.dumps(response_object)
1283+
return mock_response
1284+
1285+
# Mock at the HTTP handler level
1286+
with patch("litellm.llms.custom_httpx.http_handler.HTTPHandler.post", side_effect=mock_post):
12921287
# litellm.set_verbose= True
12931288
messages = [
12941289
{"role": "system", "content": "You're a good bot"},
@@ -1302,10 +1297,9 @@ def _return_pydantic_obj(*args, **kwargs):
13021297
},
13031298
]
13041299
response = completion(
1305-
model="mistral-7b-instruct",
1300+
model="perplexity/llama-3-sonar-small-32k-online",
13061301
messages=messages,
1307-
api_base="https://api.perplexity.ai",
1308-
client=openai_client,
1302+
api_key="fake-api-key",
13091303
)
13101304
print(response)
13111305
assert hasattr(response, "citations")

0 commit comments

Comments
 (0)