⚡️ Speed up function ui_get_spend_by_tags by 17%
#429
📄 17% (0.17x) speedup for `ui_get_spend_by_tags` in `litellm/proxy/spend_tracking/spend_management_endpoints.py`
⏱️ Runtime: 2.01 milliseconds → 1.72 milliseconds (best of 114 runs)
📝 Explanation and details
The optimization achieves a ~16% runtime speedup by replacing `collections.defaultdict` with plain dictionaries and streamlining the data aggregation patterns.

Key optimizations:

- Eliminated defaultdict overhead: replaced `collections.defaultdict(float)` and `collections.defaultdict(int)` with plain `dict[str, float]` and `dict[str, int]`. The line profiler shows defaultdict creation taking ~700 ns in total (the lines containing `collections.defaultdict`), which is eliminated entirely.
- Optimized aggregation logic: used `dict.get(key, default) + value` followed by assignment instead of defaultdict's `+=` operator, reducing per-key lookup overhead; the "specific-tags" path drops the `+=` operations entirely.
- Reduced dictionary method calls: the line profiler shows the original code's `total_spend_per_tag[tag_name] += tag_spend` taking ~368 ns per hit, while the optimized `get()` plus assignment pattern is cheaper for this access pattern.

Performance impact: plain dictionaries have lower per-operation overhead than defaultdict for this use case, where the access patterns are known upfront. The optimization is particularly effective for the "specific-tags" code path, where direct assignment replaces unnecessary aggregation operations. A minimal sketch of the before/after pattern is shown below.
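Below is a minimal, self-contained sketch of the two aggregation patterns, assuming a simplified row shape (the loop and field names mirror the profiler lines quoted above, but are illustrative rather than the exact source):

```python
import collections

rows = [
    {"individual_request_tag": "tagA", "total_spend": 10.25, "log_count": 5},
    {"individual_request_tag": "tagA", "total_spend": 4.75, "log_count": 2},
]

# Original pattern: defaultdicts pay a construction cost plus __missing__
# handling the first time each key is seen.
total_spend_per_tag = collections.defaultdict(float)
total_requests_per_tag = collections.defaultdict(int)
for row in rows:
    tag_name = row["individual_request_tag"]
    total_spend_per_tag[tag_name] += row["total_spend"]
    total_requests_per_tag[tag_name] += row["log_count"]

# Optimized pattern: plain dicts with get()-plus-assignment, avoiding the
# defaultdict machinery while producing the same aggregates.
spend: dict = {}
requests: dict = {}
for row in rows:
    tag_name = row["individual_request_tag"]
    spend[tag_name] = spend.get(tag_name, 0.0) + row["total_spend"]
    requests[tag_name] = requests.get(tag_name, 0) + row["log_count"]

assert spend == dict(total_spend_per_tag)
assert requests == dict(total_requests_per_tag)
```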
Test case benefits: The throughput improvement of 0.9% (95,760 vs 94,920 ops/sec) shows consistent gains across concurrent workloads. The optimization performs well across all test scenarios - from basic single calls to high-load concurrent execution with 500+ simultaneous requests.
This optimization maintains identical behavior while reducing CPU overhead in data structure operations, making it especially valuable for high-frequency spend tracking operations.
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import asyncio  # used to run async functions

# function to test
import collections
from typing import TYPE_CHECKING, Any, List, Optional

import pytest  # used for our unit tests
from fastapi import HTTPException
from litellm.proxy.spend_tracking.spend_management_endpoints import (
    ui_get_spend_by_tags,
)
class DummyDB:
    """
    Dummy DB class to simulate async query_raw behavior.
    """

    def __init__(self, response_mapping):
        self.response_mapping = response_mapping
        self.calls = []

    async def query_raw(self, sql_query, *params):
        # Assumed reconstruction (the generated body was lost in extraction):
        # record the call, then pick the canned response -- calls carrying
        # tag parameters beyond the date range get "specific_tags".
        self.calls.append((sql_query, params))
        key = "specific_tags" if len(params) > 2 else "all_tags"
        return self.response_mapping.get(key, [])
class DummyPrismaClient:
    """
    Dummy PrismaClient to simulate prisma_client.db.query_raw
    """

    def __init__(self, response_mapping):
        self.db = DummyDB(response_mapping)


from litellm.proxy.spend_tracking.spend_management_endpoints import (
    ui_get_spend_by_tags,
)
########## UNIT TESTS ##########
# 1. Basic Test Cases
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_basic_all_tags():
"""
Test basic functionality for 'all-tags' (should aggregate all tags).
"""
# Simulate two tags in DB response
response_mapping = {
"all_tags": [
{"individual_request_tag": "tagA", "spend_date": "2024-06-01", "log_count": 5, "total_spend": 10.25},
{"individual_request_tag": "tagB", "spend_date": "2024-06-01", "log_count": 3, "total_spend": 5.75},
]
}
prisma_client = DummyPrismaClient(response_mapping)
result = await ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "all-tags")
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_basic_specific_tags():
"""
Test basic functionality for specific tags.
"""
# Simulate DB response for specific tags
response_mapping = {
"specific_tags": [
{"individual_request_tag": "tagB", "log_count": 7, "total_spend": 14.0},
{"individual_request_tag": "tagC", "log_count": 2, "total_spend": 2.5},
]
}
prisma_client = DummyPrismaClient(response_mapping)
result = await ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "tagB,tagC")
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_basic_empty_tags_str():
"""
Test with empty tags_str (should behave as all-tags).
"""
response_mapping = {
"all_tags": [
{"individual_request_tag": "tagA", "spend_date": "2024-06-01", "log_count": 1, "total_spend": 2.0}
]
}
prisma_client = DummyPrismaClient(response_mapping)
result = await ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "")
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_basic_none_tags_str():
"""
Test with tags_str=None (should behave as all-tags).
"""
response_mapping = {
"all_tags": [
{"individual_request_tag": "tagA", "spend_date": "2024-06-01", "log_count": 4, "total_spend": 8.0}
]
}
prisma_client = DummyPrismaClient(response_mapping)
result = await ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, None)
# 2. Edge Test Cases
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_no_db_connected():
"""
Test when prisma_client is None (should raise HTTPException).
"""
with pytest.raises(HTTPException) as exc_info:
await ui_get_spend_by_tags("2024-06-01", "2024-06-30", None, "tagA")
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_all_tags_in_tags_list_with_other_tags():
"""
Test when tags_str includes 'all-tags' and other tags (should treat as all-tags).
"""
response_mapping = {
"all_tags": [
{"individual_request_tag": "tagA", "spend_date": "2024-06-01", "log_count": 2, "total_spend": 4.0},
{"individual_request_tag": "tagB", "spend_date": "2024-06-01", "log_count": 1, "total_spend": 2.0},
]
}
prisma_client = DummyPrismaClient(response_mapping)
result = await ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "all-tags,tagA,tagB")
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_tag_with_zero_spend_and_count():
"""
Test tags with zero spend and zero log_count.
"""
response_mapping = {
"specific_tags": [
{"individual_request_tag": "tagZero", "log_count": 0, "total_spend": 0.0},
]
}
prisma_client = DummyPrismaClient(response_mapping)
result = await ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "tagZero")
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_empty_db_response():
"""
Test when DB returns empty list (no tags found).
"""
response_mapping = {
"all_tags": []
}
prisma_client = DummyPrismaClient(response_mapping)
result = await ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "all-tags")
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_concurrent_execution():
"""
Test concurrent execution of multiple calls.
"""
response_mapping = {
"all_tags": [
{"individual_request_tag": "tagA", "spend_date": "2024-06-01", "log_count": 1, "total_spend": 2.0}
],
"specific_tags": [
{"individual_request_tag": "tagB", "log_count": 2, "total_spend": 3.0}
]
}
prisma_client = DummyPrismaClient(response_mapping)
# Run two calls concurrently
results = await asyncio.gather(
ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "all-tags"),
ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "tagB")
)
# 3. Large Scale Test Cases
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_large_concurrent_calls():
"""
Test many concurrent calls to the function.
"""
# Simulate 50 concurrent calls with different tags
response_mapping = {
"specific_tags": [
{"individual_request_tag": "tagX", "log_count": 5, "total_spend": 10.0}
]
}
prisma_client = DummyPrismaClient(response_mapping)
coros = [
ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "tagX")
for _ in range(50)
]
results = await asyncio.gather(*coros)
for result in results:
pass
# 4. Throughput Test Cases
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_throughput_small_load():
"""
Throughput test with small load (10 concurrent calls).
"""
response_mapping = {
"specific_tags": [
{"individual_request_tag": "tagSmall", "log_count": 2, "total_spend": 3.0}
]
}
prisma_client = DummyPrismaClient(response_mapping)
coros = [
ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "tagSmall")
for _ in range(10)
]
results = await asyncio.gather(*coros)
for result in results:
pass
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_throughput_medium_load():
"""
Throughput test with medium load (50 concurrent calls).
"""
response_mapping = {
"specific_tags": [
{"individual_request_tag": "tagMedium", "log_count": 7, "total_spend": 14.0}
]
}
prisma_client = DummyPrismaClient(response_mapping)
coros = [
ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "tagMedium")
for _ in range(50)
]
results = await asyncio.gather(*coros)
for result in results:
pass
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_throughput_high_load():
"""
Throughput test with high load (100 concurrent calls).
"""
response_mapping = {
"specific_tags": [
{"individual_request_tag": "tagHigh", "log_count": 20, "total_spend": 40.0}
]
}
prisma_client = DummyPrismaClient(response_mapping)
coros = [
ui_get_spend_by_tags("2024-06-01", "2024-06-30", prisma_client, "tagHigh")
for _ in range(100)
]
results = await asyncio.gather(*coros)
for result in results:
pass
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
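As a rough illustration of what that check amounts to, a hypothetical side-by-side harness might look like the following (the harness name and the idea of passing both variants in are assumptions, not Codeflash's actual API):

```python
# Hypothetical equivalence harness -- illustrative only, not Codeflash's API.
async def check_equivalence(original_fn, optimized_fn, prisma_client):
    args = ("2024-06-01", "2024-06-30", prisma_client, "all-tags")
    # Both variants must return identical results for identical inputs.
    assert await original_fn(*args) == await optimized_fn(*args)
```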
#------------------------------------------------
import asyncio  # used to run async functions

# function to test
# (copied exactly as provided)
import collections
from typing import TYPE_CHECKING, Any, List, Optional

import pytest  # used for our unit tests
from fastapi import HTTPException
from litellm.proxy.spend_tracking.spend_management_endpoints import (
    ui_get_spend_by_tags,
)
from litellm.proxy.utils import PrismaClient

if TYPE_CHECKING:
    from litellm.proxy.proxy_server import PrismaClient
else:
    PrismaClient = Any

from litellm.proxy.spend_tracking.spend_management_endpoints import (
    ui_get_spend_by_tags,
)
# ---- TESTS ----

# Minimal mock for prisma_client.db.query_raw
class MockDB:
    def __init__(self, response_map=None):
        self.response_map = response_map or {}

    async def query_raw(self, sql_query, *params):
        # Assumed reconstruction (body lost in extraction): canned responses
        # are keyed on the (start_date, end_date) parameter pair.
        return self.response_map.get(params[:2], [])


class MockPrismaClient:
    def __init__(self, response_map=None):
        self.db = MockDB(response_map)
# ----------- BASIC TEST CASES -----------
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_all_tags_basic():
"""
Basic: Test spend for all tags (tags_str contains 'all-tags')
"""
mock_response = [
{"individual_request_tag": "tagA", "spend_date": "2024-06-01", "log_count": 5, "total_spend": 10.0},
{"individual_request_tag": "tagB", "spend_date": "2024-06-01", "log_count": 3, "total_spend": 5.5},
{"individual_request_tag": "tagC", "spend_date": "2024-06-01", "log_count": 2, "total_spend": 4.0},
]
prisma_client = MockPrismaClient(response_map={
("2024-06-01", "2024-06-30"): mock_response
})
result = await ui_get_spend_by_tags(
start_date="2024-06-01",
end_date="2024-06-30",
prisma_client=prisma_client,
tags_str="all-tags"
)
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_empty_tags_str_returns_all_tags():
"""
Basic: Test empty tags_str returns spend for all tags
"""
mock_response = [
{"individual_request_tag": "tagX", "spend_date": "2024-06-01", "log_count": 1, "total_spend": 2.0},
{"individual_request_tag": "tagY", "spend_date": "2024-06-01", "log_count": 2, "total_spend": 3.0},
]
prisma_client = MockPrismaClient(response_map={
("2024-06-01", "2024-06-30"): mock_response
})
result = await ui_get_spend_by_tags(
start_date="2024-06-01",
end_date="2024-06-30",
prisma_client=prisma_client,
tags_str=""
)
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_none_tags_str_returns_all_tags():
"""
Basic: Test None tags_str returns spend for all tags
"""
mock_response = [
{"individual_request_tag": "tag1", "spend_date": "2024-06-01", "log_count": 1, "total_spend": 1.0},
]
prisma_client = MockPrismaClient(response_map={
("2024-06-01", "2024-06-30"): mock_response
})
result = await ui_get_spend_by_tags(
start_date="2024-06-01",
end_date="2024-06-30",
prisma_client=prisma_client,
tags_str=None
)
# ----------- EDGE TEST CASES -----------
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_no_prisma_client_raises():
"""
Edge: Should raise HTTPException if prisma_client is None
"""
with pytest.raises(HTTPException) as exc_info:
await ui_get_spend_by_tags(
start_date="2024-06-01",
end_date="2024-06-30",
prisma_client=None,
tags_str="tagA"
)
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_concurrent_execution():
"""
Edge: Test concurrent execution of multiple calls
"""
mock_response1 = [
{"individual_request_tag": "tagA", "spend_date": "2024-06-01", "log_count": 2, "total_spend": 4.0},
]
mock_response2 = [
{"individual_request_tag": "tagB", "spend_date": "2024-06-02", "log_count": 3, "total_spend": 6.0},
]
prisma_client = MockPrismaClient(response_map={
("2024-06-01", "2024-06-30"): mock_response1,
("2024-06-02", "2024-06-30"): mock_response2,
})
# Run two calls concurrently
results = await asyncio.gather(
ui_get_spend_by_tags(
start_date="2024-06-01",
end_date="2024-06-30",
prisma_client=prisma_client,
tags_str="all-tags"
),
ui_get_spend_by_tags(
start_date="2024-06-02",
end_date="2024-06-30",
prisma_client=prisma_client,
tags_str="all-tags"
),
)
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_empty_response():
"""
Edge: Test empty response from db
"""
prisma_client = MockPrismaClient(response_map={
("2024-06-01", "2024-06-30"): [],
})
result = await ui_get_spend_by_tags(
start_date="2024-06-01",
end_date="2024-06-30",
prisma_client=prisma_client,
tags_str="all-tags"
)
# ----------- LARGE SCALE TEST CASES -----------
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_throughput_small_load():
"""
Throughput: Test small load (10 concurrent requests)
"""
mock_response = [
{"individual_request_tag": "tagA", "spend_date": "2024-06-01", "log_count": 1, "total_spend": 2.0},
]
prisma_client = MockPrismaClient(response_map={
("2024-06-01", "2024-06-30"): mock_response
})
tasks = [
ui_get_spend_by_tags(
start_date="2024-06-01",
end_date="2024-06-30",
prisma_client=prisma_client,
tags_str="all-tags"
)
for _ in range(10)
]
results = await asyncio.gather(*tasks)
for result in results:
pass
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_throughput_medium_load():
"""
Throughput: Test medium load (100 concurrent requests)
"""
mock_response = [
{"individual_request_tag": "tagB", "spend_date": "2024-06-01", "log_count": 2, "total_spend": 4.0},
]
prisma_client = MockPrismaClient(response_map={
("2024-06-01", "2024-06-30"): mock_response
})
tasks = [
ui_get_spend_by_tags(
start_date="2024-06-01",
end_date="2024-06-30",
prisma_client=prisma_client,
tags_str="all-tags"
)
for _ in range(100)
]
results = await asyncio.gather(*tasks)
for result in results:
pass
@pytest.mark.asyncio
async def test_ui_get_spend_by_tags_throughput_large_load():
"""
Throughput: Test large load (500 concurrent requests)
"""
mock_response = [
{"individual_request_tag": "tagC", "spend_date": "2024-06-01", "log_count": 3, "total_spend": 6.0},
]
prisma_client = MockPrismaClient(response_map={
("2024-06-01", "2024-06-30"): mock_response
})
tasks = [
ui_get_spend_by_tags(
start_date="2024-06-01",
end_date="2024-06-30",
prisma_client=prisma_client,
tags_str="all-tags"
)
for _ in range(500)
]
results = await asyncio.gather(*tasks)
for result in results:
pass
To edit these changes, run `git checkout codeflash/optimize-ui_get_spend_by_tags-mhu1rwwo` and push.