⚡️ Speed up method S3DataSource.restore_object by 60%
#633
📄 60% (0.60x) speedup for S3DataSource.restore_object in backend/python/app/sources/external/s3/s3.py
⏱️ Runtime: 1.57 milliseconds → 983 microseconds (best of 234 runs)
📝 Explanation and details
The optimized code achieves a 60% runtime improvement by implementing async client caching to eliminate redundant connection setup overhead in the S3 restore operation.
Key optimization applied:
- Added a _s3_async_client instance variable and a _get_s3_async_client() method that caches the async S3 client after first creation
- Rather than entering an async with session.client('s3') context on every call, the client is created once and reused
Performance impact analysis:
The line profiler shows the critical bottleneck was in the original restore_object method:
- async with session.client('s3') took 14.9% of execution time (1.63M ns)
- await self._get_s3_async_client() takes 24.7% of execution time in the optimized version, but at a much lower absolute cost (936K ns)
Why this works:
In async Python applications, aioboto3 client creation involves connection handshakes and authentication setup that is expensive to repeat. The cached client maintains the connection pool and authentication state, allowing subsequent S3 operations to bypass this setup cost entirely. The __aenter__() call on the session client is performed once and cached, eliminating the repeated async context manager entry/exit cycle.
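To make the pattern concrete, here is a minimal sketch of the cached-client idea. The class name CachedS3DataSource and the close() method are illustrative assumptions; only the _s3_async_client attribute, the _get_s3_async_client() method, and the cached __aenter__() call come from the description above, not from the repository's actual code.

import aioboto3

class CachedS3DataSource:
    """Illustrative sketch: lazily create and cache an aioboto3 S3 client."""

    def __init__(self, session: aioboto3.Session) -> None:
        self._session = session
        self._s3_async_client = None  # cached client, created on first use

    async def _get_s3_async_client(self):
        # Enter the client context once and reuse the result; this skips the
        # connection/auth setup that a fresh `async with` block would repeat.
        if self._s3_async_client is None:
            self._s3_async_client = await self._session.client("s3").__aenter__()
        return self._s3_async_client

    async def restore_object(self, bucket: str, key: str, **kwargs):
        s3 = await self._get_s3_async_client()
        return await s3.restore_object(Bucket=bucket, Key=key, **kwargs)

    async def close(self) -> None:
        # The cached client must eventually be exited explicitly.
        if self._s3_async_client is not None:
            await self._s3_async_client.__aexit__(None, None, None)
            self._s3_async_client = None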
Workload benefits:
This optimization is particularly effective for:
The optimization maintains full async compatibility and error handling while providing substantial performance gains for any application making multiple S3 restore operations through the same S3DataSource instance.
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import asyncio # used to run async functions
from typing import Any, Dict, Optional
import pytest # used for our unit tests
from app.sources.external.s3.s3 import S3DataSource
# Minimal ClientError mock for testing
class ClientError(Exception):
def __init__(self, response, operation_name):
self.response = response
self.operation_name = operation_name
super().__init__(str(response))
# --- Minimal aioboto3 session/client mocks for testing ---
class FakeS3Client:
def __init__(self, behavior=None):
# behavior: dict mapping method name to function
self.behavior = behavior or {}
class FakeSession:
def __init__(self, client_behavior=None):
self.client_behavior = client_behavior
class FakeS3ClientWrapper:
def __init__(self, session):
self._session = session
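# NOTE: the method bodies of these fakes are omitted from this dump. A
# hypothetical completion (an assumption, not the generated tests' actual
# code) could dispatch restore_object through the behavior dict roughly as:
#
#     class FakeS3Client:
#         def __init__(self, behavior=None):
#             self.behavior = behavior or {}
#         async def restore_object(self, **kwargs):
#             handler = self.behavior.get('restore_object')
#             if handler is not None:
#                 return await handler(**kwargs)
#             return {'Restored': True, 'Input': kwargs}  # default success payload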
# --- TESTS ---
# 1. Basic Test Cases
@pytest.mark.asyncio
async def test_restore_object_basic_success():
"""Test basic async restore_object returns expected S3Response on success."""
# Setup: Session returns a FakeS3Client that returns a dict
session = FakeSession()
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
@pytest.mark.asyncio
async def test_restore_object_basic_optional_params():
"""Test restore_object with all optional parameters set."""
session = FakeSession()
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
@pytest.mark.asyncio
async def test_restore_object_basic_none_response():
"""Test restore_object when S3 client returns None (empty response)."""
async def restore_object_none(**kwargs):
return None
session = FakeSession(client_behavior={'restore_object': restore_object_none})
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
result = await datasource.restore_object('bucket', 'key')
# 2. Edge Test Cases
@pytest.mark.asyncio
async def test_restore_object_error_response_dict():
"""Test restore_object returns error if S3 response dict contains 'Error'."""
async def restore_object_error(**kwargs):
return {'Error': {'Code': '404', 'Message': 'Not Found'}}
session = FakeSession(client_behavior={'restore_object': restore_object_error})
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
result = await datasource.restore_object('bucket', 'key')
@pytest.mark.asyncio
async def test_restore_object_clienterror_exception():
"""Test restore_object handles ClientError exception and returns error S3Response."""
async def restore_object_raise(**kwargs):
raise ClientError({'Error': {'Code': '403', 'Message': 'Forbidden'}}, 'restore_object')
session = FakeSession(client_behavior={'restore_object': restore_object_raise})
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
result = await datasource.restore_object('bucket', 'key')
@pytest.mark.asyncio
async def test_restore_object_unexpected_exception():
"""Test restore_object handles unexpected exceptions."""
async def restore_object_raise(**kwargs):
raise RuntimeError("Some unexpected error")
session = FakeSession(client_behavior={'restore_object': restore_object_raise})
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
result = await datasource.restore_object('bucket', 'key')
@pytest.mark.asyncio
async def test_restore_object_concurrent_execution():
"""Test concurrent execution of restore_object to ensure async correctness."""
session = FakeSession()
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
# Run 10 concurrent restore_object calls with different keys
tasks = [
datasource.restore_object('bucket', f'key_{i}')
for i in range(10)
]
results = await asyncio.gather(*tasks)
for i, result in enumerate(results):
pass
@pytest.mark.asyncio
async def test_restore_object_response_not_dict():
"""Test restore_object when S3 returns a non-dict response."""
async def restore_object_str(**kwargs):
return "RestoredObject"
session = FakeSession(client_behavior={'restore_object': restore_object_str})
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
result = await datasource.restore_object('bucket', 'key')
# 3. Large Scale Test Cases
@pytest.mark.asyncio
async def test_restore_object_large_scale_concurrent():
"""Test restore_object under large scale concurrent load (100 calls)."""
session = FakeSession()
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
num_calls = 100
tasks = [
datasource.restore_object('bucket', f'key_{i}', VersionId=f'v{i}')
for i in range(num_calls)
]
results = await asyncio.gather(*tasks)
for i, result in enumerate(results):
pass
@pytest.mark.asyncio
async def test_restore_object_large_scale_error_mix():
"""Test restore_object with a mix of success and error responses in concurrent calls."""
async def restore_object_mixed(**kwargs):
if kwargs['Key'].endswith('9'):
# Simulate error for every 10th key
return {'Error': {'Code': '500', 'Message': 'Internal Error'}}
return {'Restored': True, 'Input': kwargs}
session = FakeSession(client_behavior={'restore_object': restore_object_mixed})
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
num_calls = 50
tasks = [
datasource.restore_object('bucket', f'key_{i}')
for i in range(num_calls)
]
results = await asyncio.gather(*tasks)
for i, result in enumerate(results):
if i % 10 == 9:
pass
else:
pass
# 4. Throughput Test Cases
@pytest.mark.asyncio
async def test_restore_object_throughput_small_load():
"""Throughput test: restore_object with small load (5 calls)."""
session = FakeSession()
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
tasks = [
datasource.restore_object('bucket', f'key_{i}')
for i in range(5)
]
results = await asyncio.gather(*tasks)
@pytest.mark.asyncio
async def test_restore_object_throughput_medium_load():
"""Throughput test: restore_object with medium load (50 calls)."""
session = FakeSession()
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
tasks = [
datasource.restore_object('bucket', f'key_{i}')
for i in range(50)
]
results = await asyncio.gather(*tasks)
@pytest.mark.asyncio
async def test_restore_object_throughput_high_volume():
"""Throughput test: restore_object with high volume (200 calls)."""
session = FakeSession()
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
tasks = [
datasource.restore_object('bucket', f'key_{i}')
for i in range(200)
]
results = await asyncio.gather(*tasks)
@pytest.mark.asyncio
async def test_restore_object_throughput_mixed_load():
"""Throughput test: restore_object with mixed success/error responses."""
async def restore_object_mixed(**kwargs):
if kwargs['Key'].endswith('7'):
raise ClientError({'Error': {'Code': '429', 'Message': 'Rate Limit'}}, 'restore_object')
return {'Restored': True, 'Input': kwargs}
session = FakeSession(client_behavior={'restore_object': restore_object_mixed})
s3_client_wrapper = FakeS3ClientWrapper(session)
datasource = S3DataSource(s3_client_wrapper)
tasks = [
datasource.restore_object('bucket', f'key_{i}')
for i in range(30)
]
results = await asyncio.gather(*tasks)
for i, result in enumerate(results):
if str(i).endswith('7'):
pass
else:
pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import asyncio # used to run async functions
import pytest # used for our unit tests
from app.sources.external.s3.s3 import S3DataSource
# Mocks for aioboto3 and botocore exceptions
class MockS3Client:
# Simulate S3 client with async restore_object method
async def restore_object(self, **kwargs):
# Simulate different responses based on input
bucket = kwargs.get('Bucket')
key = kwargs.get('Key')
version_id = kwargs.get('VersionId')
restore_request = kwargs.get('RestoreRequest')
request_payer = kwargs.get('RequestPayer')
checksum_algorithm = kwargs.get('ChecksumAlgorithm')
expected_bucket_owner = kwargs.get('ExpectedBucketOwner')
class MockSession:
# Simulate aioboto3.Session
async def __aenter__(self):
return MockS3Client()
async def __aexit__(self, exc_type, exc, tb):
pass
def client(self, service_name):
return self
class MockS3RESTClientViaAccessKey:
def get_session(self):
return MockSession()
class MockClientError(Exception):
def __init__(self, response, operation_name):
self.response = response
self.operation_name = operation_name
# S3Client class as per provided code
class S3Client:
def __init__(self, client):
self.client = client
def get_session(self):
return self.client.get_session()
# ----------- UNIT TESTS BELOW ------------
@pytest.fixture
def s3_data_source():
# Provide a S3DataSource instance with mocked S3Client
client = MockS3RESTClientViaAccessKey()
s3_client = S3Client(client)
return S3DataSource(s3_client)
# 1. Basic Test Cases
@pytest.mark.asyncio
async def test_restore_object_basic_success(s3_data_source):
"""Test basic successful restore_object call."""
result = await s3_data_source.restore_object(Bucket="valid-bucket", Key="valid-key")
@pytest.mark.asyncio
async def test_restore_object_basic_error_response(s3_data_source):
"""Test restore_object returns error for S3 API error response."""
result = await s3_data_source.restore_object(Bucket="error-bucket", Key="any-key")
@pytest.mark.asyncio
async def test_restore_object_basic_none_response(s3_data_source):
"""Test restore_object handles None response gracefully."""
result = await s3_data_source.restore_object(Bucket="none-bucket", Key="any-key")
@pytest.mark.asyncio
async def test_restore_object_with_optional_parameters(s3_data_source):
"""Test restore_object with all optional parameters set."""
restore_request = {"Days": 5, "GlacierJobParameters": {"Tier": "Standard"}}
result = await s3_data_source.restore_object(
Bucket="valid-bucket",
Key="valid-key",
VersionId="v1",
RestoreRequest=restore_request,
RequestPayer="requester",
ChecksumAlgorithm="SHA256",
ExpectedBucketOwner="owner123"
)
# 2. Edge Test Cases
@pytest.mark.asyncio
async def test_restore_object_client_error(s3_data_source):
"""Test restore_object handles botocore ClientError exception."""
result = await s3_data_source.restore_object(Bucket="clienterror-bucket", Key="any-key")
@pytest.mark.asyncio
async def test_restore_object_unexpected_exception(s3_data_source):
"""Test restore_object handles unexpected exceptions gracefully."""
result = await s3_data_source.restore_object(Bucket="exception-bucket", Key="any-key")
@pytest.mark.asyncio
async def test_restore_object_concurrent_execution(s3_data_source):
"""Test concurrent restore_object calls with different buckets/keys."""
buckets = ["valid-bucket", "error-bucket", "none-bucket"]
keys = ["key1", "key2", "key3"]
coros = [s3_data_source.restore_object(Bucket=b, Key=k) for b, k in zip(buckets, keys)]
results = await asyncio.gather(*coros)
@pytest.mark.asyncio
async def test_restore_object_edge_optional_params(s3_data_source):
"""Test restore_object with only some optional parameters set."""
result = await s3_data_source.restore_object(
Bucket="valid-bucket",
Key="valid-key",
RestoreRequest={"Days": 3},
RequestPayer="requester"
)
# 3. Large Scale Test Cases
@pytest.mark.asyncio
async def test_restore_object_large_payload(s3_data_source):
"""Test restore_object returns large data payload correctly."""
result = await s3_data_source.restore_object(Bucket="large-bucket", Key="large-key")
@pytest.mark.asyncio
async def test_restore_object_many_concurrent_calls(s3_data_source):
"""Test restore_object with many concurrent calls."""
# Limit to 50 to keep test fast and bounded
coros = [s3_data_source.restore_object(Bucket="valid-bucket", Key=f"key-{i}") for i in range(50)]
results = await asyncio.gather(*coros)
for idx, result in enumerate(results):
pass
# 4. Throughput Test Cases
@pytest.mark.asyncio
async def test_restore_object_throughput_small_load(s3_data_source):
"""Throughput test: small load of 5 concurrent restore_object calls."""
coros = [s3_data_source.restore_object(Bucket="valid-bucket", Key=f"key-{i}") for i in range(5)]
results = await asyncio.gather(*coros)
@pytest.mark.asyncio
async def test_restore_object_throughput_medium_load(s3_data_source):
"""Throughput test: medium load of 20 concurrent restore_object calls."""
coros = [s3_data_source.restore_object(Bucket="valid-bucket", Key=f"key-{i}") for i in range(20)]
results = await asyncio.gather(*coros)
@pytest.mark.asyncio
async def test_restore_object_throughput_high_volume(s3_data_source):
"""Throughput test: high volume load of 100 concurrent restore_object calls."""
coros = [s3_data_source.restore_object(Bucket="valid-bucket", Key=f"key-{i}") for i in range(100)]
results = await asyncio.gather(*coros)
# Validate keys
for i, result in enumerate(results):
pass
@pytest.mark.asyncio
async def test_restore_object_throughput_mixed_load(s3_data_source):
"""Throughput test: mixed load with valid, error, and none buckets."""
buckets = ["valid-bucket"] * 40 + ["error-bucket"] * 5 + ["none-bucket"] * 5
keys = [f"key-{i}" for i in range(50)]
coros = [s3_data_source.restore_object(Bucket=b, Key=k) for b, k in zip(buckets, keys)]
results = await asyncio.gather(*coros)
for i, result in enumerate(results):
if buckets[i] == "valid-bucket":
pass
else:
pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes, run git checkout codeflash/optimize-S3DataSource.restore_object-mhxcj11m and push.