@@ -112,7 +112,8 @@ def _check_whylabs_condition_count_uncompound() -> bool:
112112 else :
113113 logger .info (f"Got response code { response .status_code } but expected 200, so running uncompound" )
114114 except Exception :
115- logger .warning ("Error trying to read whylabs config, falling back to defaults for uncompounding" )
115+ pass
116+
116117 _WHYLABS_SKIP_CONFIG_READ = True
117118 return True
118119
@@ -573,6 +574,65 @@ def _write_segmented_reference_result_set(self, file: SegmentedResultSet, **kwar
573574 else :
574575 return False , "Failed to upload all segments"
575576
577+ def _flatten_tags (self , tags : Union [List , Dict ]) -> List [SegmentTag ]:
578+ if type (tags [0 ]) == list :
579+ result : List [SegmentTag ] = []
580+ for t in tags :
581+ result .append (self ._flatten_tags (t ))
582+ return result
583+
584+ return [SegmentTag (t ["key" ], t ["value" ]) for t in tags ]
585+
def _write_segmented_result_set_transaction(self, file: SegmentedResultSet, **kwargs: Any) -> Tuple[bool, str]:
    """Upload the segments of a result set as part of the current transaction.

    Only views that belong to the first partition are uploaded. Each such view
    is registered with the transaction through ``log_transaction`` and then
    uploaded with ``_do_upload``, using that view's own tags and dataset
    timestamp (falling back to the current UTC time when the view has none).

    Parameters
    ----------
    file : SegmentedResultSet
        The segmented result set whose views should be uploaded.
    **kwargs : Any
        Forwarded to ``TransactionsApi.log_transaction``.

    Returns
    -------
    Tuple[bool, str]
        Whether every upload succeeded, and the upload messages joined by "; ".
    """
    partitions = file.partitions
    if not partitions:
        return False, "SegmentedResultSet contains no partitions"
    if len(partitions) > 1:
        logger.warning(
            "SegmentedResultSet contains more than one partition. Only the first partition will be uploaded. "
        )
    partition = partitions[0]

    utc_now = datetime.datetime.now(datetime.timezone.utc)

    # Keep each view together with its own tags and timestamp so that views
    # skipped for belonging to another partition cannot shift the pairing.
    uploads = []
    for view in file.get_writables():
        if view.partition.id != partition.id:
            continue
        _, segment_tags, _ = _generate_segment_tags_metadata(view.segment, view.partition)
        tags = [SegmentTag(tag.key.replace("whylogs.tag.", ""), tag.value) for tag in segment_tags]
        dataset_timestamp = view.dataset_timestamp or utc_now
        uploads.append((view, tags, int(dataset_timestamp.timestamp() * 1000)))

    region = os.getenv("WHYLABS_UPLOAD_REGION", None)
    client: TransactionsApi = self._get_or_create_transaction_client()
    messages: List[str] = []
    and_status: bool = True
    for view, tags, dataset_timestamp_epoch in uploads:
        with tempfile.NamedTemporaryFile() as tmp_file:
            view.write(file=tmp_file)
            tmp_file.flush()
            # Rewind so the upload reads the profile from its first byte.
            tmp_file.seek(0)
            request = TransactionLogRequest(
                dataset_timestamp=dataset_timestamp_epoch, segment_tags=tags, region=region
            )
            result: AsyncLogResponse = client.log_transaction(self._transaction_id, request, **kwargs)
            logger.info(f"Added profile {result.id} to transaction {self._transaction_id}")
            bool_status, message = self._do_upload(
                dataset_timestamp=dataset_timestamp_epoch,
                upload_url=result.upload_url,
                profile_id=result.id,
                profile_file=tmp_file,
            )
            and_status = and_status and bool_status
            messages.append(message)

    return and_status, "; ".join(messages)
635+
576636 def _write_segmented_result_set (self , file : SegmentedResultSet , ** kwargs : Any ) -> Tuple [bool , str ]:
577637 """Put segmented result set for the specified dataset.
578638
@@ -585,6 +645,9 @@ def _write_segmented_result_set(self, file: SegmentedResultSet, **kwargs: Any) -
585645 -------
586646 Tuple[bool, str]
587647 """
648+ if self ._transaction_id is not None :
649+ return self ._write_segmented_result_set_transaction (file , ** kwargs )
650+
588651 # multi-profile writer
589652 files = file .get_writables ()
590653 messages : List [str ] = list ()
@@ -617,39 +680,46 @@ def _get_or_create_transaction_client(self) -> TransactionsApi:
617680 self ._refresh_client ()
618681 return TransactionsApi (self ._api_client )
619682
def start_transaction(self, transaction_id: Optional[str] = None, **kwargs) -> str:
    """
    Begin a transaction and return its ID. Profiles subsequently written by
    calling write() will be uploaded to WhyLabs, but not ingested until
    commit_transaction() is called. Throws on failure.

    If a transaction is already active, an error is logged and the active
    transaction's ID is returned unchanged. If ``transaction_id`` is given,
    it is adopted as the active transaction without contacting the service;
    otherwise a new transaction is requested for the writer's dataset.
    A ``dataset_id`` keyword argument, when present and not None, replaces
    the writer's dataset ID first.
    """
    active_id = self._transaction_id
    if active_id is not None:
        logger.error("Must end current transaction with commit_transaction() before starting another")
        return active_id

    dataset_id = kwargs.get("dataset_id")
    if dataset_id is not None:
        self._dataset_id = dataset_id

    if transaction_id is None:
        # No ID supplied by the caller: ask the service to open a new transaction.
        api: TransactionsApi = self._get_or_create_transaction_client()
        result: LogTransactionMetadata = api.start_transaction(
            TransactionStartRequest(dataset_id=self._dataset_id), **kwargs
        )
        self._transaction_id = result["transaction_id"]
        logger.info(f"Starting transaction {self._transaction_id}, expires {result['expiration_time']}")
        return self._transaction_id  # type: ignore

    # Join the transaction identified by the caller.
    self._transaction_id = transaction_id  # type: ignore
    return transaction_id
638706
def commit_transaction(self, **kwargs) -> None:
    """
    Ingest any profiles written since the previous start_transaction().
    Throws on failure.

    If no transaction is active, an error is logged and nothing is committed.
    Keyword arguments are forwarded to ``TransactionsApi.commit_transaction``.
    """
    if self._transaction_id is None:
        logger.error("Must call start_transaction() before commit_transaction()")
        return

    # Clear the active transaction before contacting the service: if the
    # commit throws, the transaction is abandoned rather than left active.
    transaction_id = self._transaction_id
    self._transaction_id = None
    logger.info(f"Committing transaction {transaction_id}")
    client = self._get_or_create_transaction_client()
    request = TransactionCommitRequest(verbose=True)
    client.commit_transaction(transaction_id, request, **kwargs)
653723
654724 @deprecated_alias (profile = "file" )
655725 def write (self , file : Writable , ** kwargs : Any ) -> Tuple [bool , str ]:
@@ -1079,3 +1149,14 @@ def _get_upload_url(self, dataset_timestamp: int) -> Tuple[str, str]:
10791149 logger .debug (f"Replaced URL with our private domain. New URL: { upload_url } " )
10801150
10811151 return upload_url , profile_id
1152+
1153+
class WhyLabsTransaction:
    """Context manager that wraps a block of writes in a WhyLabs transaction.

    Entering the context calls ``start_transaction()`` on the writer; leaving
    it calls ``commit_transaction()``. The commit happens whether or not the
    block raised, and an exception raised inside the block is not suppressed.
    """

    def __init__(self, writer: "WhyLabsWriter"):
        self._writer = writer

    def __enter__(self) -> str:
        # Hand the transaction ID to the caller so ``with ... as tx`` is useful.
        return self._writer.start_transaction()

    def __exit__(self, exc_type, exc_value, exc_tb) -> None:
        self._writer.commit_transaction()
0 commit comments