⚡️ Speed up method S3DataSource.get_bucket_tagging by 545%
#619
+34
−10
📄 545% (5.45x) speedup for S3DataSource.get_bucket_tagging in backend/python/app/sources/external/s3/s3.py
⏱️ Runtime: 3.94 milliseconds → 611 microseconds (best of 173 runs)
📝 Explanation and details
The optimized version achieves a 545% speedup by eliminating the expensive overhead of repeatedly creating and destroying S3 client connections for each request.
Key optimization: S3 Client Connection Pooling
The original code creates a new S3 client connection for every get_bucket_tagging() call. The optimized version introduces a cached S3 client via the new _get_s3_client() method:

- the client instance is cached in _s3_client
- __aenter__() is called once to enter the async context manager, keeping the connection alive
- an aclose() method properly cleans up the cached client when done

Performance Impact Analysis:
From the line profiler results, the original code spent significant time on:

- async with session.client('s3'): 40.8% of total time (6.71M nanoseconds)

The optimized version eliminates the repeated connection setup overhead, focusing execution time primarily on the actual S3 API calls rather than connection management.
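For contrast, the per-call pattern flagged by the profiler looks roughly like this (a sketch, not the exact original source; how the data source obtains its session is an assumption):

async def get_bucket_tagging(self, **kwargs):
    session = self._session  # however the data source holds its aioboto3 session
    async with session.client('s3') as s3:  # new connection created and torn down on every call
        return await s3.get_bucket_tagging(**kwargs)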
Throughput improvements:
Best use cases:
The optimization excels in scenarios with multiple S3 operations per S3DataSource instance, such as batch processing, concurrent bucket operations, or long-lived services making repeated S3 calls. The cached connection reduces latency and resource overhead for each subsequent operation.
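As a rough illustration, here is a minimal sketch of the cached-client pattern described above (assuming aioboto3; apart from _get_s3_client and aclose, which the explanation names, the class and attribute names are illustrative, not the actual implementation):

import aioboto3

class CachedS3DataSource:
    def __init__(self, session: aioboto3.Session) -> None:
        self._session = session
        self._s3_client = None      # cached client, created lazily
        self._s3_client_cm = None   # the async context manager we entered

    async def _get_s3_client(self):
        # Create the S3 client once and reuse it for every subsequent call.
        if self._s3_client is None:
            self._s3_client_cm = self._session.client('s3')
            # Enter the async context manager once; the connection stays alive.
            self._s3_client = await self._s3_client_cm.__aenter__()
        return self._s3_client

    async def get_bucket_tagging(self, **kwargs):
        client = await self._get_s3_client()
        return await client.get_bucket_tagging(**kwargs)

    async def aclose(self) -> None:
        # Tear down the cached client when the data source is finished.
        if self._s3_client_cm is not None:
            await self._s3_client_cm.__aexit__(None, None, None)
            self._s3_client = self._s3_client_cm = None

The caller owns cleanup: after the last operation, awaiting aclose() releases the pooled connection instead of relying on a per-call context manager.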
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import asyncio  # used to run async functions
from unittest.mock import AsyncMock, MagicMock, patch

import pytest  # used for our unit tests
from app.sources.external.s3.s3 import S3DataSource

# Import the function/class under test.
# As per the provided code, we need to import S3DataSource, S3Client, S3Response.
# For the purposes of these tests, we define minimal versions of S3Response and S3Client
# and mock aioboto3, since we do not want to hit real AWS endpoints.

# Minimal S3Response for testing
class S3Response:
    def __init__(self, success: bool, data=None, error=None):
        self.success = success
        self.data = data
        self.error = error

# Minimal S3Client for testing
class S3Client:
    def __init__(self, session):
        self._session = session

# The function under test (copied exactly, as required)
try:
    import aioboto3  # type: ignore
    from botocore.exceptions import ClientError  # type: ignore
except ImportError:
    # For testing, define dummy versions
    class ClientError(Exception):
        def __init__(self, response):
            self.response = response
# ---- TESTS ----

# Helper to create a mock session and client
class DummyAsyncContextManager:
    def __init__(self, client):
        self.client = client

    async def __aenter__(self):
        return self.client

    async def __aexit__(self, exc_type, exc, tb):
        return None

@pytest.fixture
def mock_s3_client():
    # Create a mock session object
    mock_session = MagicMock()
    # Patch session.client('s3') to return a dummy async context manager
    mock_s3_client = MagicMock()
    mock_session.client.return_value = DummyAsyncContextManager(mock_s3_client)
    return mock_s3_client, mock_session

@pytest.fixture
def datasource(mock_s3_client):
    mock_client, mock_session = mock_s3_client
    s3_client = S3Client(mock_session)
    return S3DataSource(s3_client)
# 1. Basic Test Cases

@pytest.mark.asyncio
async def test_get_bucket_tagging_success(datasource, mock_s3_client):
    """Test that get_bucket_tagging returns success when S3 returns tags."""
    mock_client, _ = mock_s3_client
    # Simulate S3 returning tags
    expected_tags = {'TagSet': [{'Key': 'env', 'Value': 'prod'}]}
    mock_client.get_bucket_tagging = AsyncMock(return_value=expected_tags)
    response = await datasource.get_bucket_tagging(Bucket='my-bucket')

@pytest.mark.asyncio
async def test_get_bucket_tagging_with_expected_owner(datasource, mock_s3_client):
    """Test get_bucket_tagging with ExpectedBucketOwner parameter."""
    mock_client, _ = mock_s3_client
    expected_tags = {'TagSet': [{'Key': 'team', 'Value': 'dev'}]}
    mock_client.get_bucket_tagging = AsyncMock(return_value=expected_tags)
    response = await datasource.get_bucket_tagging(Bucket='my-bucket', ExpectedBucketOwner='123456789012')

@pytest.mark.asyncio
async def test_get_bucket_tagging_empty_response(datasource, mock_s3_client):
    """Test get_bucket_tagging when S3 returns None (empty response)."""
    mock_client, _ = mock_s3_client
    mock_client.get_bucket_tagging = AsyncMock(return_value=None)
    response = await datasource.get_bucket_tagging(Bucket='my-bucket')
# 2. Edge Test Cases

@pytest.mark.asyncio
async def test_get_bucket_tagging_error_response(datasource, mock_s3_client):
    """Test get_bucket_tagging when S3 returns an error in the response dict."""
    mock_client, _ = mock_s3_client
    error_dict = {'Error': {'Code': 'AccessDenied', 'Message': 'You do not have permission'}}
    mock_client.get_bucket_tagging = AsyncMock(return_value=error_dict)
    response = await datasource.get_bucket_tagging(Bucket='my-bucket')

@pytest.mark.asyncio
async def test_get_bucket_tagging_client_error_exception(datasource, mock_s3_client):
    """Test get_bucket_tagging when S3 raises a ClientError exception."""
    mock_client, _ = mock_s3_client
    error_response = {'Error': {'Code': 'NoSuchBucket', 'Message': 'The specified bucket does not exist'}}
    mock_client.get_bucket_tagging = AsyncMock(side_effect=ClientError(error_response))
    response = await datasource.get_bucket_tagging(Bucket='nonexistent-bucket')

@pytest.mark.asyncio
async def test_get_bucket_tagging_unexpected_exception(datasource, mock_s3_client):
    """Test get_bucket_tagging when S3 raises a generic Exception."""
    mock_client, _ = mock_s3_client
    mock_client.get_bucket_tagging = AsyncMock(side_effect=Exception("Network failure"))
    response = await datasource.get_bucket_tagging(Bucket='my-bucket')

@pytest.mark.asyncio
async def test_get_bucket_tagging_concurrent_calls(datasource, mock_s3_client):
    """Test concurrent execution of get_bucket_tagging."""
    mock_client, _ = mock_s3_client
    # Each call returns different tags
    tag_sets = [
        {'TagSet': [{'Key': 'env', 'Value': 'prod'}]},
        {'TagSet': [{'Key': 'env', 'Value': 'dev'}]},
        {'TagSet': [{'Key': 'env', 'Value': 'test'}]},
    ]
    # Use side_effect to return different values
    mock_client.get_bucket_tagging = AsyncMock(side_effect=tag_sets)
    buckets = ['bucket1', 'bucket2', 'bucket3']
    tasks = [datasource.get_bucket_tagging(Bucket=b) for b in buckets]
    responses = await asyncio.gather(*tasks)
    for i, resp in enumerate(responses):
        pass
# 3. Large Scale Test Cases

@pytest.mark.asyncio
async def test_get_bucket_tagging_many_concurrent(datasource, mock_s3_client):
    """Test get_bucket_tagging with many concurrent calls (scalability)."""
    mock_client, _ = mock_s3_client
    # Return the same tags for all
    expected_tags = {'TagSet': [{'Key': 'scale', 'Value': 'test'}]}
    mock_client.get_bucket_tagging = AsyncMock(return_value=expected_tags)
    buckets = [f'bucket-{i}' for i in range(50)]  # 50 concurrent calls
    tasks = [datasource.get_bucket_tagging(Bucket=b) for b in buckets]
    responses = await asyncio.gather(*tasks)
    for resp in responses:
        pass

@pytest.mark.asyncio
async def test_get_bucket_tagging_large_data_response(datasource, mock_s3_client):
    """Test get_bucket_tagging with a large TagSet in response."""
    mock_client, _ = mock_s3_client
    # Simulate a large number of tags
    large_tagset = {'TagSet': [{'Key': f'k{i}', 'Value': f'v{i}'} for i in range(500)]}
    mock_client.get_bucket_tagging = AsyncMock(return_value=large_tagset)
    response = await datasource.get_bucket_tagging(Bucket='large-bucket')
# 4. Throughput Test Cases

@pytest.mark.asyncio
async def test_get_bucket_tagging_throughput_small_load(datasource, mock_s3_client):
    """Throughput: Test small load of get_bucket_tagging calls."""
    mock_client, _ = mock_s3_client
    expected_tags = {'TagSet': [{'Key': 'env', 'Value': 'prod'}]}
    mock_client.get_bucket_tagging = AsyncMock(return_value=expected_tags)
    buckets = [f'bucket-{i}' for i in range(5)]
    tasks = [datasource.get_bucket_tagging(Bucket=b) for b in buckets]
    responses = await asyncio.gather(*tasks)

@pytest.mark.asyncio
async def test_get_bucket_tagging_throughput_medium_load(datasource, mock_s3_client):
    """Throughput: Test medium load of get_bucket_tagging calls."""
    mock_client, _ = mock_s3_client
    expected_tags = {'TagSet': [{'Key': 'env', 'Value': 'prod'}]}
    mock_client.get_bucket_tagging = AsyncMock(return_value=expected_tags)
    buckets = [f'bucket-{i}' for i in range(20)]
    tasks = [datasource.get_bucket_tagging(Bucket=b) for b in buckets]
    responses = await asyncio.gather(*tasks)

@pytest.mark.asyncio
async def test_get_bucket_tagging_throughput_high_volume(datasource, mock_s3_client):
    """Throughput: Test high volume (but bounded) load of get_bucket_tagging calls."""
    mock_client, _ = mock_s3_client
    expected_tags = {'TagSet': [{'Key': 'env', 'Value': 'prod'}]}
    mock_client.get_bucket_tagging = AsyncMock(return_value=expected_tags)
    buckets = [f'bucket-{i}' for i in range(100)]
    tasks = [datasource.get_bucket_tagging(Bucket=b) for b in buckets]
    responses = await asyncio.gather(*tasks)
    # Check that all responses are S3Response and error is None
    for resp in responses:
        pass

# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
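# A minimal sketch (an assumption, not part of the generated suite) of what such
# an equivalence check can look like: run the same call against the original and
# the optimized implementation and compare the captured S3Response fields.
async def check_equivalence(original_ds, optimized_ds):
    baseline = await original_ds.get_bucket_tagging(Bucket='my-bucket')
    codeflash_output = await optimized_ds.get_bucket_tagging(Bucket='my-bucket')
    assert codeflash_output.success == baseline.success
    assert codeflash_output.data == baseline.data
    assert codeflash_output.error == baseline.error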
#------------------------------------------------
import asyncio  # used to run async functions
from typing import Optional

import pytest  # used for our unit tests
from app.sources.external.s3.s3 import S3DataSource

class FakeAioboto3Session:
    """Fake aioboto3.Session stub for testing."""
    def __init__(self, client_factory):
        self._client_factory = client_factory
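
    # The session's client() hook is elided in the report above; an assumed
    # completion that mirrors aioboto3's session.client('s3') interface:
    def client(self, service_name):
        return self._client_factory()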
class FakeS3RESTClientViaAccessKey:
    """Fake S3RESTClientViaAccessKey for S3Client stub."""
    def __init__(self, session):
        self._session = session

class S3Client:
    """Stub S3Client for testing."""
    def __init__(self, client):
        self.client = client

class ClientError(Exception):
    """Stub for botocore.exceptions.ClientError."""
    def __init__(self, error_response, operation_name):
        self.response = error_response
        self.operation_name = operation_name

# --- Fake S3 client for testing async context and method ---
class FakeS3Client:
    """Fake async S3 client for testing."""
    def __init__(self, response=None, raise_client_error=False, raise_exception=False):
        self._response = response
        self._raise_client_error = raise_client_error
        self._raise_exception = raise_exception
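
    # The context-manager and API methods are elided in the report above; an
    # assumed completion so the stub behaves like the real async client:
    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return None

    async def get_bucket_tagging(self, **kwargs):
        if self._raise_client_error:
            raise ClientError({'Error': {'Code': 'NoSuchBucket', 'Message': 'Simulated'}}, 'GetBucketTagging')
        if self._raise_exception:
            raise Exception('Simulated unexpected failure')
        return self._response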
# --- Fixtures for test setup ---

@pytest.fixture
def s3_data_source_success():
    """Fixture for a S3DataSource that always returns a successful tagging response."""
    # Simulate a successful tagging response
    response = {
        "TagSet": [
            {"Key": "env", "Value": "prod"},
            {"Key": "team", "Value": "analytics"}
        ]
    }
    session = FakeAioboto3Session(lambda: FakeS3Client(response=response))
    client = S3Client(FakeS3RESTClientViaAccessKey(session))
    return S3DataSource(client)

@pytest.fixture
def s3_data_source_error():
    """Fixture for a S3DataSource that raises a ClientError."""
    session = FakeAioboto3Session(lambda: FakeS3Client(raise_client_error=True))
    client = S3Client(FakeS3RESTClientViaAccessKey(session))
    return S3DataSource(client)

@pytest.fixture
def s3_data_source_exception():
    """Fixture for a S3DataSource that raises a generic Exception."""
    session = FakeAioboto3Session(lambda: FakeS3Client(raise_exception=True))
    client = S3Client(FakeS3RESTClientViaAccessKey(session))
    return S3DataSource(client)

@pytest.fixture
def s3_data_source_empty():
    """Fixture for a S3DataSource that returns None (empty response)."""
    session = FakeAioboto3Session(lambda: FakeS3Client(response=None))
    client = S3Client(FakeS3RESTClientViaAccessKey(session))
    return S3DataSource(client)

@pytest.fixture
def s3_data_source_error_dict():
    """Fixture for a S3DataSource that returns a dict with Error key."""
    response = {
        "Error": {"Code": "AccessDenied", "Message": "You do not have permission"}
    }
    session = FakeAioboto3Session(lambda: FakeS3Client(response=response))
    client = S3Client(FakeS3RESTClientViaAccessKey(session))
    return S3DataSource(client)
# --- Basic Test Cases ---

@pytest.mark.asyncio
async def test_get_bucket_tagging_success(s3_data_source_success):
    """Basic test: successful tagging response."""
    result = await s3_data_source_success.get_bucket_tagging(Bucket="my-bucket")

@pytest.mark.asyncio
async def test_get_bucket_tagging_success_with_expected_owner(s3_data_source_success):
    """Basic test: successful tagging response with ExpectedBucketOwner."""
    result = await s3_data_source_success.get_bucket_tagging(Bucket="my-bucket", ExpectedBucketOwner="123456789012")

# --- Edge Test Cases ---

@pytest.mark.asyncio
async def test_get_bucket_tagging_client_error(s3_data_source_error):
    """Edge case: S3 ClientError is handled and returns error response."""
    result = await s3_data_source_error.get_bucket_tagging(Bucket="nonexistent-bucket")

@pytest.mark.asyncio
async def test_get_bucket_tagging_exception(s3_data_source_exception):
    """Edge case: Unexpected exception is handled and returns error response."""
    result = await s3_data_source_exception.get_bucket_tagging(Bucket="my-bucket")

@pytest.mark.asyncio
async def test_get_bucket_tagging_empty_response(s3_data_source_empty):
    """Edge case: S3 returns None (empty response)."""
    result = await s3_data_source_empty.get_bucket_tagging(Bucket="my-bucket")

@pytest.mark.asyncio
async def test_get_bucket_tagging_error_dict(s3_data_source_error_dict):
    """Edge case: S3 returns a dict with Error key."""
    result = await s3_data_source_error_dict.get_bucket_tagging(Bucket="my-bucket")

@pytest.mark.asyncio
async def test_get_bucket_tagging_concurrent_success(s3_data_source_success):
    """Edge case: concurrent execution of multiple successful calls."""
    # Run 5 concurrent calls
    buckets = [f"bucket-{i}" for i in range(5)]
    coros = [s3_data_source_success.get_bucket_tagging(Bucket=b) for b in buckets]
    results = await asyncio.gather(*coros)
    for result in results:
        pass

@pytest.mark.asyncio
async def test_get_bucket_tagging_concurrent_mixed(s3_data_source_success, s3_data_source_error):
    """Edge case: concurrent execution with both success and error responses."""
    coros = [
        s3_data_source_success.get_bucket_tagging(Bucket="good-bucket"),
        s3_data_source_error.get_bucket_tagging(Bucket="bad-bucket"),
    ]
    results = await asyncio.gather(*coros)
# --- Large Scale Test Cases ---

@pytest.mark.asyncio
async def test_get_bucket_tagging_many_concurrent_success(s3_data_source_success):
    """Large scale test: many concurrent successful calls (up to 50)."""
    buckets = [f"bucket-{i}" for i in range(50)]
    coros = [s3_data_source_success.get_bucket_tagging(Bucket=b) for b in buckets]
    results = await asyncio.gather(*coros)
    for result in results:
        pass

@pytest.mark.asyncio
async def test_get_bucket_tagging_many_concurrent_errors(s3_data_source_error):
    """Large scale test: many concurrent error responses (up to 20)."""
    buckets = [f"bad-bucket-{i}" for i in range(20)]
    coros = [s3_data_source_error.get_bucket_tagging(Bucket=b) for b in buckets]
    results = await asyncio.gather(*coros)
    for result in results:
        pass

# --- Throughput Test Cases ---

@pytest.mark.asyncio
async def test_get_bucket_tagging_throughput_small_load(s3_data_source_success):
    """Throughput test: small load of 10 concurrent requests."""
    buckets = [f"bucket-{i}" for i in range(10)]
    coros = [s3_data_source_success.get_bucket_tagging(Bucket=b) for b in buckets]
    results = await asyncio.gather(*coros)
    for result in results:
        pass

@pytest.mark.asyncio
async def test_get_bucket_tagging_throughput_medium_load(s3_data_source_success):
    """Throughput test: medium load of 50 concurrent requests."""
    buckets = [f"bucket-{i}" for i in range(50)]
    coros = [s3_data_source_success.get_bucket_tagging(Bucket=b) for b in buckets]
    results = await asyncio.gather(*coros)
    for result in results:
        pass

@pytest.mark.asyncio
async def test_get_bucket_tagging_throughput_high_volume(s3_data_source_success):
    """Throughput test: high volume of 100 concurrent requests (upper bound for fast test)."""
    buckets = [f"bucket-{i}" for i in range(100)]
    coros = [s3_data_source_success.get_bucket_tagging(Bucket=b) for b in buckets]
    results = await asyncio.gather(*coros)
    for result in results:
        pass

# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes, run git checkout codeflash/optimize-S3DataSource.get_bucket_tagging-mhwzjfdf and push.