Skip to content

Commit 06b15fd

Browse files
Transaction updates (#1464)
* More transaction integration tests * support for distributed transactions * Transaction context manager - [ ] I have reviewed the [Guidelines for Contributing](CONTRIBUTING.md) and the [Code of Conduct](CODE_OF_CONDUCT.md).
1 parent 926af60 commit 06b15fd

File tree

2 files changed

+225
-9
lines changed

2 files changed

+225
-9
lines changed

python/tests/api/writer/test_whylabs_integration.py

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import logging
12
import os
23
import time
34
from uuid import uuid4
45

6+
import numpy as np
57
import pandas as pd
68
import pytest
79
from whylabs_client.api.dataset_profile_api import DatasetProfileApi
@@ -13,7 +15,7 @@
1315
)
1416

1517
import whylogs as why
16-
from whylogs.api.writer.whylabs import WhyLabsWriter
18+
from whylogs.api.writer.whylabs import WhyLabsTransaction, WhyLabsWriter
1719
from whylogs.core import DatasetProfileView
1820
from whylogs.core.feature_weights import FeatureWeights
1921
from whylogs.core.schema import DatasetSchema
@@ -28,6 +30,8 @@
2830

2931
SLEEP_TIME = 30
3032

33+
logger = logging.getLogger(__name__)
34+
3135

3236
@pytest.mark.load
3337
def test_whylabs_writer():
@@ -212,3 +216,134 @@ def test_transactions():
212216
downloaded_profile = writer._s3_pool.request("GET", download_url, headers=headers, timeout=writer._timeout_seconds)
213217
deserialized_view = DatasetProfileView.deserialize(downloaded_profile.data)
214218
assert deserialized_view.get_columns().keys() == data.keys()
219+
220+
221+
@pytest.mark.load
222+
def test_transaction_context():
223+
ORG_ID = os.environ.get("WHYLABS_DEFAULT_ORG_ID")
224+
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
225+
why.init(force_local=True)
226+
schema = DatasetSchema()
227+
csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
228+
df = pd.read_csv(csv_url)
229+
pdfs = np.array_split(df, 7)
230+
writer = WhyLabsWriter(dataset_id=MODEL_ID)
231+
tids = list()
232+
try:
233+
with WhyLabsTransaction(writer):
234+
for data in pdfs:
235+
trace_id = str(uuid4())
236+
tids.append(trace_id)
237+
result = why.log(data, schema=schema, trace_id=trace_id)
238+
status, id = writer.write(result.profile())
239+
if not status:
240+
raise Exception() # or retry the profile...
241+
242+
except Exception:
243+
# The start_transaction() or commit_transaction() in the
244+
# WhyLabsTransaction context manager will throw on failure.
245+
# Or retry the commit
246+
logger.exception("Logging transaction failed")
247+
248+
time.sleep(SLEEP_TIME) # platform needs time to become aware of the profile
249+
dataset_api = DatasetProfileApi(writer._api_client)
250+
for trace_id in tids:
251+
response: ProfileTracesResponse = dataset_api.get_profile_traces(
252+
org_id=ORG_ID,
253+
dataset_id=MODEL_ID,
254+
trace_id=trace_id,
255+
)
256+
download_url = response.get("traces")[0]["download_url"]
257+
headers = {"Content-Type": "application/octet-stream"}
258+
downloaded_profile = writer._s3_pool.request(
259+
"GET", download_url, headers=headers, timeout=writer._timeout_seconds
260+
)
261+
deserialized_view = DatasetProfileView.deserialize(downloaded_profile.data)
262+
assert deserialized_view is not None
263+
264+
265+
@pytest.mark.load
266+
def test_transaction_segmented():
267+
ORG_ID = os.environ.get("WHYLABS_DEFAULT_ORG_ID")
268+
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
269+
why.init(force_local=True)
270+
schema = DatasetSchema(segments=segment_on_column("Gender"))
271+
csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
272+
data = pd.read_csv(csv_url)
273+
writer = WhyLabsWriter(dataset_id=MODEL_ID)
274+
trace_id = str(uuid4())
275+
try:
276+
writer.start_transaction()
277+
result = why.log(data, schema=schema, trace_id=trace_id)
278+
status, id = writer.write(result)
279+
if not status:
280+
raise Exception() # or retry the profile...
281+
282+
except Exception:
283+
# The start_transaction() or commit_transaction() in the
284+
# WhyLabsTransaction context manager will throw on failure.
285+
# Or retry the commit
286+
logger.exception("Logging transaction failed")
287+
288+
writer.commit_transaction()
289+
time.sleep(SLEEP_TIME) # platform needs time to become aware of the profile
290+
dataset_api = DatasetProfileApi(writer._api_client)
291+
response: ProfileTracesResponse = dataset_api.get_profile_traces(
292+
org_id=ORG_ID,
293+
dataset_id=MODEL_ID,
294+
trace_id=trace_id,
295+
)
296+
assert len(response.get("traces")) == 2
297+
for trace in response.get("traces"):
298+
download_url = trace.get("download_url")
299+
headers = {"Content-Type": "application/octet-stream"}
300+
downloaded_profile = writer._s3_pool.request(
301+
"GET", download_url, headers=headers, timeout=writer._timeout_seconds
302+
)
303+
assert downloaded_profile is not None
304+
305+
306+
@pytest.mark.load
307+
def test_transaction_distributed():
308+
ORG_ID = os.environ.get("WHYLABS_DEFAULT_ORG_ID")
309+
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
310+
why.init(force_local=True)
311+
schema = DatasetSchema()
312+
csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
313+
df = pd.read_csv(csv_url)
314+
pdfs = np.array_split(df, 7)
315+
writer = WhyLabsWriter(dataset_id=MODEL_ID)
316+
tids = list()
317+
try:
318+
transaction_id = writer.start_transaction()
319+
for data in pdfs: # pretend each iteration is run on a different machine
320+
dist_writer = WhyLabsWriter(dataset_id=MODEL_ID)
321+
dist_writer.start_transaction(transaction_id)
322+
trace_id = str(uuid4())
323+
tids.append(trace_id)
324+
result = why.log(data, schema=schema, trace_id=trace_id)
325+
status, id = dist_writer.write(result.profile())
326+
if not status:
327+
raise Exception() # or retry the profile...
328+
writer.commit_transaction()
329+
except Exception:
330+
# The start_transaction() or commit_transaction() in the
331+
# WhyLabsTransaction context manager will throw on failure.
332+
# Or retry the commit
333+
logger.exception("Logging transaction failed")
334+
335+
time.sleep(SLEEP_TIME) # platform needs time to become aware of the profile
336+
dataset_api = DatasetProfileApi(writer._api_client)
337+
for trace_id in tids:
338+
response: ProfileTracesResponse = dataset_api.get_profile_traces(
339+
org_id=ORG_ID,
340+
dataset_id=MODEL_ID,
341+
trace_id=trace_id,
342+
)
343+
download_url = response.get("traces")[0]["download_url"]
344+
headers = {"Content-Type": "application/octet-stream"}
345+
downloaded_profile = writer._s3_pool.request(
346+
"GET", download_url, headers=headers, timeout=writer._timeout_seconds
347+
)
348+
deserialized_view = DatasetProfileView.deserialize(downloaded_profile.data)
349+
assert deserialized_view is not None

python/whylogs/api/writer/whylabs.py

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def _check_whylabs_condition_count_uncompound() -> bool:
112112
else:
113113
logger.info(f"Got response code {response.status_code} but expected 200, so running uncompound")
114114
except Exception:
115-
logger.warning("Error trying to read whylabs config, falling back to defaults for uncompounding")
115+
pass
116+
116117
_WHYLABS_SKIP_CONFIG_READ = True
117118
return True
118119

@@ -573,6 +574,65 @@ def _write_segmented_reference_result_set(self, file: SegmentedResultSet, **kwar
573574
else:
574575
return False, "Failed to upload all segments"
575576

577+
def _flatten_tags(self, tags: Union[List, Dict]) -> List[SegmentTag]:
578+
if type(tags[0]) == list:
579+
result: List[SegmentTag] = []
580+
for t in tags:
581+
result.append(self._flatten_tags(t))
582+
return result
583+
584+
return [SegmentTag(t["key"], t["value"]) for t in tags]
585+
586+
def _write_segmented_result_set_transaction(self, file: SegmentedResultSet, **kwargs: Any) -> Tuple[bool, str]:
587+
utc_now = datetime.datetime.now(datetime.timezone.utc)
588+
589+
files = file.get_writables()
590+
partitions = file.partitions
591+
if len(partitions) > 1:
592+
logger.warning(
593+
"SegmentedResultSet contains more than one partition. Only the first partition will be uploaded. "
594+
)
595+
partition = partitions[0]
596+
whylabs_tags = list()
597+
for view in files:
598+
view_tags = list()
599+
dataset_timestamp = view.dataset_timestamp or utc_now
600+
if view.partition.id != partition.id:
601+
continue
602+
_, segment_tags, _ = _generate_segment_tags_metadata(view.segment, view.partition)
603+
for segment_tag in segment_tags:
604+
tag_key = segment_tag.key.replace("whylogs.tag.", "")
605+
tag_value = segment_tag.value
606+
view_tags.append({"key": tag_key, "value": tag_value})
607+
whylabs_tags.append(view_tags)
608+
stamp = dataset_timestamp.timestamp()
609+
dataset_timestamp_epoch = int(stamp * 1000)
610+
611+
region = os.getenv("WHYLABS_UPLOAD_REGION", None)
612+
client: TransactionsApi = self._get_or_create_transaction_client()
613+
messages: List[str] = list()
614+
and_status: bool = True
615+
for view, tags in zip(files, self._flatten_tags(whylabs_tags)):
616+
with tempfile.NamedTemporaryFile() as tmp_file:
617+
view.write(file=tmp_file)
618+
tmp_file.flush()
619+
tmp_file.seek(0)
620+
request = TransactionLogRequest(
621+
dataset_timestamp=dataset_timestamp_epoch, segment_tags=tags, region=region
622+
)
623+
result: AsyncLogResponse = client.log_transaction(self._transaction_id, request, **kwargs)
624+
logger.info(f"Added profile {result.id} to transaction {self._transaction_id}")
625+
bool_status, message = self._do_upload(
626+
dataset_timestamp=dataset_timestamp_epoch,
627+
upload_url=result.upload_url,
628+
profile_id=result.id,
629+
profile_file=tmp_file,
630+
)
631+
and_status = and_status and bool_status
632+
messages.append(message)
633+
634+
return and_status, "; ".join(messages)
635+
576636
def _write_segmented_result_set(self, file: SegmentedResultSet, **kwargs: Any) -> Tuple[bool, str]:
577637
"""Put segmented result set for the specified dataset.
578638
@@ -585,6 +645,9 @@ def _write_segmented_result_set(self, file: SegmentedResultSet, **kwargs: Any) -
585645
-------
586646
Tuple[bool, str]
587647
"""
648+
if self._transaction_id is not None:
649+
return self._write_segmented_result_set_transaction(file, **kwargs)
650+
588651
# multi-profile writer
589652
files = file.get_writables()
590653
messages: List[str] = list()
@@ -617,39 +680,46 @@ def _get_or_create_transaction_client(self) -> TransactionsApi:
617680
self._refresh_client()
618681
return TransactionsApi(self._api_client)
619682

620-
def start_transaction(self, **kwargs) -> None:
683+
def start_transaction(self, transaction_id: Optional[str] = None, **kwargs) -> str:
621684
"""
622685
Initiates a transaction -- any profiles subsequently written by calling write()
623-
will be uploaded to WhyLabs atomically when commit_transaction() is called. Throws
686+
will be uploaded to WhyLabs, but not ingested until commit_transaction() is called. Throws
624687
on failure.
625688
"""
626689
if self._transaction_id is not None:
627690
logger.error("Must end current transaction with commit_transaction() before starting another")
628-
return
691+
return self._transaction_id
629692

630693
if kwargs.get("dataset_id") is not None:
631694
self._dataset_id = kwargs.get("dataset_id")
632695

696+
if transaction_id is not None:
697+
self._transaction_id = transaction_id # type: ignore
698+
return transaction_id
699+
633700
client: TransactionsApi = self._get_or_create_transaction_client()
634701
request = TransactionStartRequest(dataset_id=self._dataset_id)
635702
result: LogTransactionMetadata = client.start_transaction(request, **kwargs)
636703
self._transaction_id = result["transaction_id"]
637704
logger.info(f"Starting transaction {self._transaction_id}, expires {result['expiration_time']}")
705+
return self._transaction_id # type: ignore
638706

639707
def commit_transaction(self, **kwargs) -> None:
640708
"""
641-
Atomically upload any profiles written since the previous start_transaction().
709+
Ingest any profiles written since the previous start_transaction().
642710
Throws on failure.
643711
"""
644712
if self._transaction_id is None:
645713
logger.error("Must call start_transaction() before commit_transaction()")
646714
return
647715

648-
logger.info(f"Committing transaction {self._transaction_id}")
716+
id = self._transaction_id
717+
self._transaction_id = None
718+
logger.info(f"Committing transaction {id}")
649719
client = self._get_or_create_transaction_client()
650720
request = TransactionCommitRequest(verbose=True)
651-
client.commit_transaction(self._transaction_id, request, **kwargs)
652-
self._transaction_id = None
721+
# We abandon the transaction if this throws
722+
client.commit_transaction(id, request, **kwargs)
653723

654724
@deprecated_alias(profile="file")
655725
def write(self, file: Writable, **kwargs: Any) -> Tuple[bool, str]:
@@ -1079,3 +1149,14 @@ def _get_upload_url(self, dataset_timestamp: int) -> Tuple[str, str]:
10791149
logger.debug(f"Replaced URL with our private domain. New URL: {upload_url}")
10801150

10811151
return upload_url, profile_id
1152+
1153+
1154+
class WhyLabsTransaction:
1155+
def __init__(self, writer: WhyLabsWriter):
1156+
self._writer = writer
1157+
1158+
def __enter__(self) -> None:
1159+
self._writer.start_transaction()
1160+
1161+
def __exit__(self, exc_type, exc_value, exc_tb) -> None:
1162+
self._writer.commit_transaction()

0 commit comments

Comments
 (0)