diff --git a/langextract/annotation.py b/langextract/annotation.py
index 4f0ec7bd..4447153d 100644
--- a/langextract/annotation.py
+++ b/langextract/annotation.py
@@ -35,10 +35,12 @@
 from langextract import progress
 from langextract import prompting
 from langextract import resolver as resolver_lib
+from langextract import retry_utils
 from langextract.core import base_model
 from langextract.core import data
 from langextract.core import exceptions
 from langextract.core import format_handler as fh
+from langextract.core import types as core_types
 
 
 class DocumentRepeatError(exceptions.LangExtractError):
@@ -202,6 +204,139 @@ def __init__(
         "Annotator initialized with format_handler: %s", format_handler
     )
 
+  def _process_batch_with_retry(
+      self,
+      batch_prompts: list[str],
+      batch: list[chunking.TextChunk],
+      retry_transient_errors: bool = True,
+      max_retries: int = 3,
+      retry_initial_delay: float = 1.0,
+      retry_backoff_factor: float = 2.0,
+      retry_max_delay: float = 60.0,
+      **kwargs,
+  ) -> Iterator[list[core_types.ScoredOutput]]:
+    """Processes a batch of prompts with per-chunk retry capability.
+
+    The batch is first submitted as a whole; if it fails with a transient
+    error (like a 503 "model overloaded"), each chunk is retried
+    individually so that successful chunks from the batch are preserved.
+
+    Args:
+      batch_prompts: List of prompts for the batch.
+      batch: List of TextChunk objects corresponding to the prompts.
+      retry_transient_errors: Whether to retry on transient errors.
+      max_retries: Maximum number of retry attempts.
+      retry_initial_delay: Initial delay before retry.
+      retry_backoff_factor: Backoff multiplier for retries.
+      retry_max_delay: Maximum delay between retries.
+      **kwargs: Additional arguments passed to the language model.
+
+    Yields:
+      Lists of ScoredOutputs, with retries for failed chunks.
+    """
+    try:
+      batch_results = list(
+          self._language_model.infer(
+              batch_prompts=batch_prompts,
+              **kwargs,
+          )
+      )
+
+      yield from batch_results
+      return
+
+    except Exception as e:
+      if not retry_utils.is_transient_error(e):
+        raise
+
+      logging.warning(
+          "Batch processing failed with transient error: %s. "
+          "Falling back to individual chunk processing with retry.",
+          str(e),
+      )
+
+      individual_results = []
+
+      for i, (prompt, chunk) in enumerate(zip(batch_prompts, batch)):
+        try:
+          chunk_result = self._process_single_chunk_with_retry(
+              prompt=prompt,
+              chunk=chunk,
+              retry_transient_errors=retry_transient_errors,
+              max_retries=max_retries,
+              retry_initial_delay=retry_initial_delay,
+              retry_backoff_factor=retry_backoff_factor,
+              retry_max_delay=retry_max_delay,
+              **kwargs,
+          )
+          individual_results.append(chunk_result)
+
+        except Exception as e:
+          logging.error(
+              "Failed to process chunk %d after retries: %s. "
+              "Chunk info: document_id=%s, text_length=%d. "
+              "Stopping document processing.",
+              i,
+              str(e),
+              chunk.document_id,
+              len(chunk.chunk_text),
+          )
+          raise
+
+      yield from individual_results
+
+  def _process_single_chunk_with_retry(
+      self,
+      prompt: str,
+      chunk: chunking.TextChunk,
+      retry_transient_errors: bool = True,
+      max_retries: int = 3,
+      retry_initial_delay: float = 1.0,
+      retry_backoff_factor: float = 2.0,
+      retry_max_delay: float = 60.0,
+      **kwargs,
+  ) -> list[core_types.ScoredOutput]:
+    """Processes a single chunk with retry logic.
+
+    Args:
+      prompt: The prompt for this chunk.
+      chunk: The TextChunk object.
+      retry_transient_errors: Whether to retry on transient errors.
+      max_retries: Maximum number of retry attempts.
+      retry_initial_delay: Initial delay before retry.
+      retry_backoff_factor: Backoff multiplier for retries.
+      retry_max_delay: Maximum delay between retries.
+      **kwargs: Additional arguments for the language model.
+
+    Returns:
+      List containing a single ScoredOutput for this chunk.
+    """
+
+    # Use the retry decorator with custom parameters.
+    @retry_utils.retry_chunk_processing(
+        max_retries=max_retries,
+        initial_delay=retry_initial_delay,
+        backoff_factor=retry_backoff_factor,
+        max_delay=retry_max_delay,
+        enabled=retry_transient_errors,
+    )
+    def _process_chunk():
+      batch_results = list(
+          self._language_model.infer(
+              batch_prompts=[prompt],
+              **kwargs,
+          )
+      )
+
+      if not batch_results:
+        raise exceptions.InferenceOutputError(
+            f"No results returned for chunk in document {chunk.document_id}"
+        )
+
+      return batch_results[0]
+
+    return _process_chunk()
+
   def annotate_documents(
       self,
       documents: Iterable[data.Document],
@@ -211,6 +346,11 @@
       debug: bool = True,
       extraction_passes: int = 1,
       show_progress: bool = True,
+      retry_transient_errors: bool = True,
+      max_retries: int = 3,
+      retry_initial_delay: float = 1.0,
+      retry_backoff_factor: float = 2.0,
+      retry_max_delay: float = 60.0,
       **kwargs,
   ) -> Iterator[data.AnnotatedDocument]:
     """Annotates a sequence of documents with NLP extractions.
@@ -234,6 +374,11 @@
       Values > 1 reprocess tokens multiple times, potentially increasing
         costs with the potential for a more thorough extraction.
       show_progress: Whether to show progress bar. Defaults to True.
+      retry_transient_errors: Whether to retry on transient errors. Defaults to True.
+      max_retries: Maximum number of retry attempts. Defaults to 3.
+      retry_initial_delay: Initial delay before retry in seconds. Defaults to 1.0.
+      retry_backoff_factor: Backoff multiplier for retries. Defaults to 2.0.
+      retry_max_delay: Maximum delay between retries in seconds. Defaults to 60.0.
       **kwargs: Additional arguments passed to LanguageModel.infer and Resolver.
 
     Yields:
@@ -253,6 +398,11 @@
           batch_length,
           debug,
           show_progress,
+          retry_transient_errors,
+          max_retries,
+          retry_initial_delay,
+          retry_backoff_factor,
+          retry_max_delay,
           **kwargs,
       )
     else:
@@ -264,6 +414,11 @@
           debug,
           extraction_passes,
           show_progress,
+          retry_transient_errors,
+          max_retries,
+          retry_initial_delay,
+          retry_backoff_factor,
+          retry_max_delay,
           **kwargs,
       )
 
@@ -275,9 +430,32 @@
       batch_length: int,
       debug: bool,
       show_progress: bool = True,
+      retry_transient_errors: bool = True,
+      max_retries: int = 3,
+      retry_initial_delay: float = 1.0,
+      retry_backoff_factor: float = 2.0,
+      retry_max_delay: float = 60.0,
       **kwargs,
   ) -> Iterator[data.AnnotatedDocument]:
-    """Single-pass annotation logic (original implementation)."""
+    """Single-pass annotation logic (original implementation).
+
+    Args:
+      documents: Iterable of documents to annotate.
+      resolver: Resolver for processing inference results.
+      max_char_buffer: Maximum character buffer for chunking.
+      batch_length: Number of chunks to process in each batch.
+      debug: Whether to enable debug logging.
+      show_progress: Whether to show progress bar.
+      retry_transient_errors: Whether to retry on transient errors.
+      max_retries: Maximum number of retry attempts.
+      retry_initial_delay: Initial delay before retry.
+      retry_backoff_factor: Backoff multiplier for retries.
+      retry_max_delay: Maximum delay between retries.
+      **kwargs: Additional arguments passed to language model.
+
+    Yields:
+      AnnotatedDocument objects with extracted data.
+    """
     logging.info("Starting document annotation.")
 
     doc_iter, doc_iter_for_chunks = itertools.tee(documents, 2)
@@ -321,8 +499,15 @@
         )
         progress_bar.set_description(desc)
 
-      batch_scored_outputs = self._language_model.infer(
+      # Process batch with individual chunk retry capability.
+      batch_scored_outputs = self._process_batch_with_retry(
           batch_prompts=batch_prompts,
+          batch=batch,
+          retry_transient_errors=retry_transient_errors,
+          max_retries=max_retries,
+          retry_initial_delay=retry_initial_delay,
+          retry_backoff_factor=retry_backoff_factor,
+          retry_max_delay=retry_max_delay,
           **kwargs,
       )
 
@@ -419,9 +604,33 @@
       debug: bool,
       extraction_passes: int,
       show_progress: bool = True,
+      retry_transient_errors: bool = True,
+      max_retries: int = 3,
+      retry_initial_delay: float = 1.0,
+      retry_backoff_factor: float = 2.0,
+      retry_max_delay: float = 60.0,
       **kwargs,
   ) -> Iterator[data.AnnotatedDocument]:
-    """Sequential extraction passes logic for improved recall."""
+    """Sequential extraction passes logic for improved recall.
+
+    Args:
+      documents: Iterable of documents to annotate.
+      resolver: Resolver for processing inference results.
+      max_char_buffer: Maximum character buffer for chunking.
+      batch_length: Number of chunks to process in each batch.
+      debug: Whether to enable debug logging.
+      extraction_passes: Number of extraction passes to perform.
+      show_progress: Whether to show progress bar.
+      retry_transient_errors: Whether to retry on transient errors.
+      max_retries: Maximum number of retry attempts.
+      retry_initial_delay: Initial delay before retry.
+      retry_backoff_factor: Backoff multiplier for retries.
+      retry_max_delay: Maximum delay between retries.
+      **kwargs: Additional arguments passed to language model.
+
+    Yields:
+      AnnotatedDocument objects with merged extracted data.
+    """
 
     logging.info(
         "Starting sequential extraction passes for improved recall with %d"
@@ -446,6 +655,11 @@
           batch_length,
           debug=(debug and pass_num == 0),
           show_progress=show_progress if pass_num == 0 else False,
+          retry_transient_errors=retry_transient_errors,
+          max_retries=max_retries,
+          retry_initial_delay=retry_initial_delay,
+          retry_backoff_factor=retry_backoff_factor,
+          retry_max_delay=retry_max_delay,
           **kwargs,
       ):
         doc_id = annotated_doc.document_id
@@ -494,6 +708,11 @@
       debug: bool = True,
       extraction_passes: int = 1,
       show_progress: bool = True,
+      retry_transient_errors: bool = True,
+      max_retries: int = 3,
+      retry_initial_delay: float = 1.0,
+      retry_backoff_factor: float = 2.0,
+      retry_max_delay: float = 60.0,
       **kwargs,
   ) -> data.AnnotatedDocument:
     """Annotates text with NLP extractions for text input.
@@ -511,6 +730,11 @@
         standard single extraction.
         Values > 1 reprocess tokens multiple times, potentially increasing costs.
       show_progress: Whether to show progress bar. Defaults to True.
+      retry_transient_errors: Whether to retry on transient errors. Defaults to True.
+      max_retries: Maximum number of retry attempts. Defaults to 3.
+      retry_initial_delay: Initial delay before retry in seconds. Defaults to 1.0.
+      retry_backoff_factor: Backoff multiplier for retries. Defaults to 2.0.
+      retry_max_delay: Maximum delay between retries in seconds. Defaults to 60.0.
       **kwargs: Additional arguments for inference and resolver_lib.
 
     Returns:
@@ -540,6 +764,11 @@
           debug,
           extraction_passes,
           show_progress,
+          retry_transient_errors,
+          max_retries,
+          retry_initial_delay,
+          retry_backoff_factor,
+          retry_max_delay,
           **kwargs,
       )
     )
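Usage sketch for the annotator-level knobs above, mirroring the test setup in tests/annotation_test.py; the mocked model is a stand-in for any configured provider, and the document contents are illustrative:

    from unittest import mock

    from langextract import annotation, prompting
    from langextract.core import data, types

    model = mock.MagicMock()  # stand-in for a configured language model
    model.infer.return_value = iter(
        [[types.ScoredOutput(score=1.0, output='{"extractions": []}')]]
    )

    annotator = annotation.Annotator(
        language_model=model,
        prompt_template=prompting.PromptTemplateStructured(description=""),
    )

    # Chunks that fail with a transient error (503/429/timeout) now back off
    # and retry; non-transient errors still raise immediately.
    annotated = list(
        annotator.annotate_documents(
            documents=[data.Document(text="Test document", document_id="doc_1")],
            retry_transient_errors=True,
            max_retries=5,
            retry_initial_delay=2.0,
            retry_backoff_factor=2.0,
            retry_max_delay=30.0,
        )
    )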
diff --git a/langextract/extraction.py b/langextract/extraction.py
index e019ab09..ed2466a9 100644
--- a/langextract/extraction.py
+++ b/langextract/extraction.py
@@ -59,6 +59,11 @@ def extract(
     prompt_validation_level: pv.PromptValidationLevel = pv.PromptValidationLevel.WARNING,
     prompt_validation_strict: bool = False,
     show_progress: bool = True,
+    retry_transient_errors: bool = True,
+    max_retries: int = 3,
+    retry_initial_delay: float = 1.0,
+    retry_backoff_factor: float = 2.0,
+    retry_max_delay: float = 60.0,
 ) -> typing.Any:
   """Extracts structured information from text.
 
@@ -150,6 +155,12 @@
     prompt_validation_strict: When True and prompt_validation_level is ERROR,
       raises on non-exact matches (MATCH_FUZZY, MATCH_LESSER). Defaults to False.
     show_progress: Whether to show progress bar during extraction. Defaults to True.
+    retry_transient_errors: Whether to automatically retry on transient errors
+      like 503 "model overloaded". Defaults to True.
+    max_retries: Maximum number of retry attempts for transient errors. Defaults to 3.
+    retry_initial_delay: Initial delay in seconds before first retry. Defaults to 1.0.
+    retry_backoff_factor: Multiplier for exponential backoff between retries. Defaults to 2.0.
+    retry_max_delay: Maximum delay between retries in seconds. Defaults to 60.0.
 
   Returns:
     An AnnotatedDocument with the extracted information when input is a
@@ -320,6 +331,16 @@
       format_handler=format_handler,
   )
 
+  # Add retry parameters to the kwargs forwarded to the annotator.
+  retry_kwargs = {
+      "retry_transient_errors": retry_transient_errors,
+      "max_retries": max_retries,
+      "retry_initial_delay": retry_initial_delay,
+      "retry_backoff_factor": retry_backoff_factor,
+      "retry_max_delay": retry_max_delay,
+  }
+  alignment_kwargs.update(retry_kwargs)
+
   if isinstance(text_or_documents, str):
     return annotator.annotate_text(
         text=text_or_documents,
diff --git a/langextract/providers/gemini.py b/langextract/providers/gemini.py
index cb62cfbb..8ceab448 100644
--- a/langextract/providers/gemini.py
+++ b/langextract/providers/gemini.py
@@ -23,6 +23,7 @@
 
 from absl import logging
 
+from langextract import retry_utils
 from langextract.core import base_model
 from langextract.core import data
 from langextract.core import exceptions
@@ -179,6 +180,7 @@ def __init__(
         k: v for k, v in (kwargs or {}).items() if k in _API_CONFIG_KEYS
     }
 
+  @retry_utils.retry_chunk_processing()
   def _process_single_prompt(
       self, prompt: str, config: dict
   ) -> core_types.ScoredOutput:
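With the decorator applied to _process_single_prompt, each Gemini API call retries itself with the decorator defaults (3 attempts, 1.0 s initial delay) before the annotator-level fallback in annotation.py is ever reached. A self-contained sketch of the same behavior on a stand-in function; flaky_call and its failure sequence are invented for illustration:

    from langextract import retry_utils

    attempts = 0

    @retry_utils.retry_chunk_processing(max_retries=2, initial_delay=0.1)
    def flaky_call() -> str:
      # Fails twice with a transient 503, then succeeds on the third attempt.
      global attempts
      attempts += 1
      if attempts < 3:
        raise RuntimeError("503 The model is overloaded")
      return "ok"

    assert flaky_call() == "ok"  # two backoff sleeps, then success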
diff --git a/langextract/retry_utils.py b/langextract/retry_utils.py
new file mode 100644
index 00000000..8e7e850d
--- /dev/null
+++ b/langextract/retry_utils.py
@@ -0,0 +1,281 @@
+# Copyright 2025 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Retry utilities for handling transient errors in LangExtract."""
+
+from __future__ import annotations
+
+import functools
+import random
+import time
+from typing import Any, Callable, TypeVar
+
+from absl import logging
+
+T = TypeVar("T")
+
+# Transient error patterns that should trigger retries.
+TRANSIENT_ERROR_PATTERNS = [
+    "503",
+    "service unavailable",
+    "temporarily unavailable",
+    "rate limit",
+    "429",
+    "too many requests",
+    "connection reset",
+    "timeout",
+    "timed out",
+    "deadline exceeded",
+    "model is overloaded",
+    "quota exceeded",
+    "resource exhausted",
+    "internal server error",
+    "502",
+    "504",
+    "gateway timeout",
+    "bad gateway",
+]
+
+# Exception types that indicate transient errors.
+TRANSIENT_EXCEPTION_TYPES = [
+    "ServiceUnavailable",
+    "RateLimitError",
+    "Timeout",
+    "ConnectionError",
+    "TimeoutError",
+    "OSError",
+    "RuntimeError",
+]
+
+
+def is_transient_error(error: Exception) -> bool:
+  """Checks whether an error is transient and should be retried.
+
+  Args:
+    error: The exception to check.
+
+  Returns:
+    True if the error is transient and should be retried.
+  """
+  error_str = str(error).lower()
+  error_type = type(error).__name__
+
+  # Check for transient error patterns in the error message.
+  is_transient_pattern = any(
+      pattern in error_str for pattern in TRANSIENT_ERROR_PATTERNS
+  )
+
+  # Check for transient exception types.
+  is_transient_type = error_type in TRANSIENT_EXCEPTION_TYPES
+  return is_transient_pattern or is_transient_type
+
+
+def execute_retry_with_backoff(
+    attempt: int,
+    max_retries: int,
+    delay: float,
+    max_delay: float,
+    backoff_factor: float,
+    error: Exception,
+    operation_name: str = "operation",
+) -> float:
+  """Executes one retry step with exponential backoff and jitter.
+
+  Args:
+    attempt: Current attempt number (0-based).
+    max_retries: Maximum number of retries.
+    delay: Current delay value.
+    max_delay: Maximum delay value.
+    backoff_factor: Factor to multiply delay by.
+    error: The exception that occurred.
+    operation_name: Name of the operation for logging.
+
+  Returns:
+    New delay value for the next iteration.
+
+  Raises:
+    Exception: Re-raises the original error once max_retries is exceeded.
+  """
+  if attempt >= max_retries:
+    logging.error(
+        "%s failed after %d retries: %s",
+        operation_name,
+        max_retries,
+        str(error),
+    )
+    raise error
+
+  current_delay = min(delay, max_delay)
+
+  jitter_amount = current_delay * 0.1 * random.random()
+  current_delay += jitter_amount
+
+  logging.warning(
+      "%s failed on attempt %d/%d due to transient error: %s. "
+      "Retrying in %.2f seconds...",
+      operation_name,
+      attempt + 1,
+      max_retries + 1,
+      str(error),
+      current_delay,
+  )
+
+  time.sleep(current_delay)
+  return min(delay * backoff_factor, max_delay)
+
+
+def retry_on_transient_errors(
+    max_retries: int = 3,
+    initial_delay: float = 1.0,
+    backoff_factor: float = 2.0,
+    max_delay: float = 60.0,
+    jitter: bool = True,
+) -> Callable[[Callable[..., T]], Callable[..., T]]:
+  """Decorator to retry functions on transient errors with exponential backoff.
+
+  Args:
+    max_retries: Maximum number of retry attempts (default: 3)
+    initial_delay: Initial delay in seconds (default: 1.0)
+    backoff_factor: Multiplier for exponential backoff (default: 2.0)
+    max_delay: Maximum delay between retries in seconds (default: 60.0)
+    jitter: Whether to add random jitter to prevent thundering herd (default: True)
+
+  Returns:
+    Decorated function with retry logic
+  """
+
+  def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> T:
+      last_exception = None
+      delay = initial_delay
+
+      for attempt in range(max_retries + 1):
+        try:
+          return func(*args, **kwargs)
+        except Exception as e:
+          last_exception = e
+
+          # Don't retry non-transient errors.
+          if not is_transient_error(e):
+            logging.debug(
+                "Non-transient error encountered, not retrying: %s", str(e)
+            )
+            raise
+
+          # Give up once retries are exhausted.
+          if attempt >= max_retries:
+            logging.warning(
+                "Max retries (%d) exceeded for transient error: %s",
+                max_retries,
+                str(e),
+            )
+            raise
+
+          # Calculate delay with exponential backoff.
+          current_delay = min(delay, max_delay)
+
+          # Add jitter to prevent thundering herd.
+          if jitter:
+            jitter_amount = current_delay * 0.1 * random.random()
+            current_delay += jitter_amount
+
+          logging.info(
+              "Transient error on attempt %d/%d: %s. Retrying in %.2f"
+              " seconds...",
+              attempt + 1,
+              max_retries + 1,
+              str(e),
+              current_delay,
+          )
+
+          time.sleep(current_delay)
+          delay = min(delay * backoff_factor, max_delay)
+
+      # This should never be reached, but just in case.
+      if last_exception:
+        raise last_exception
+      raise RuntimeError("Retry logic failed unexpectedly")
+
+    return wrapper
+
+  return decorator
+
+
+def retry_chunk_processing(
+    max_retries: int = 3,
+    initial_delay: float = 1.0,
+    backoff_factor: float = 2.0,
+    max_delay: float = 60.0,
+    enabled: bool = True,
+) -> Callable[[Callable[..., T]], Callable[..., T]]:
+  """Specialized retry decorator for chunk processing with chunk-specific logging.
+
+  This is optimized for the annotation process, where individual chunks may
+  fail due to transient errors while other chunks in the same batch succeed.
+
+  Args:
+    max_retries: Maximum number of retry attempts (default: 3)
+    initial_delay: Initial delay in seconds (default: 1.0)
+    backoff_factor: Multiplier for exponential backoff (default: 2.0)
+    max_delay: Maximum delay between retries in seconds (default: 60.0)
+    enabled: Whether retry is enabled (default: True)
+
+  Returns:
+    Decorated function with chunk-specific retry logic
+  """
+
+  def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> T:
+      # Check if retry is disabled.
+      if not enabled:
+        return func(*args, **kwargs)
+
+      last_exception = None
+      delay = initial_delay
+
+      for attempt in range(max_retries + 1):
+        try:
+          return func(*args, **kwargs)
+        except Exception as e:
+          last_exception = e
+
+          # Check if this is a transient error
+          if not is_transient_error(e):
+            logging.debug(
+                "Non-transient error in chunk processing, not retrying: %s",
+                str(e),
+            )
+            raise
+
+          # Execute retry logic with backoff
+          delay = execute_retry_with_backoff(
+              attempt=attempt,
+              max_retries=max_retries,
+              delay=delay,
+              max_delay=max_delay,
+              backoff_factor=backoff_factor,
+              error=e,
+              operation_name="Chunk processing",
+          )
+
+      # This should never be reached, but just in case
+      if last_exception:
+        raise last_exception
+      raise RuntimeError("Chunk retry logic failed unexpectedly")
+
+    return wrapper
+
+  return decorator
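The module surface is one predicate plus two decorators. A short sketch of the intended use, exercising the same transient/non-transient split the tests below assert on; fetch_completion is a hypothetical caller:

    from langextract import retry_utils

    assert retry_utils.is_transient_error(TimeoutError("deadline exceeded"))
    assert not retry_utils.is_transient_error(ValueError("bad response schema"))

    attempts = 0

    @retry_utils.retry_on_transient_errors(max_retries=3, initial_delay=0.5)
    def fetch_completion() -> str:
      # Hypothetical call that hits a rate limit once before succeeding.
      global attempts
      attempts += 1
      if attempts == 1:
        raise ConnectionError("429 Too Many Requests")
      return "done"

    assert fetch_completion() == "done"  # one ~0.5 s backoff after the 429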
diff --git a/tests/annotation_test.py b/tests/annotation_test.py
index 78df7dd6..e9fe0878 100644
--- a/tests/annotation_test.py
+++ b/tests/annotation_test.py
@@ -1118,5 +1118,169 @@ def test_extractions_overlap(self, ext1, ext2, expected):
     self.assertEqual(result, expected)
 
 
+class AnnotatorRetryPolicyTest(absltest.TestCase):
+  """Test retry policy functionality in annotation."""
+
+  def setUp(self):
+    super().setUp()
+    self.mock_language_model = self.enter_context(
+        mock.patch.object(gemini, "GeminiLanguageModel", autospec=True)
+    )
+    self.annotator = annotation.Annotator(
+        language_model=self.mock_language_model,
+        prompt_template=prompting.PromptTemplateStructured(description=""),
+    )
+
+  def test_retry_parameters_accepted(self):
+    """Test that retry parameters are accepted by annotate_documents."""
+    documents = [data.Document(text="Test document", document_id="test_doc")]
+
+    mock_result = types.ScoredOutput(score=1.0, output='{"extractions": []}')
+    self.mock_language_model.infer.return_value = iter([[mock_result]])
+
+    try:
+      list(
+          self.annotator.annotate_documents(
+              documents=documents,
+              retry_transient_errors=True,
+              max_retries=3,
+              retry_initial_delay=1.0,
+              retry_backoff_factor=2.0,
+              retry_max_delay=60.0,
+          )
+      )
+    except TypeError as e:
+      if "unexpected keyword argument" in str(e):
+        self.fail(f"Retry parameters not accepted: {e}")
+      else:
+        raise
+
+  def test_retry_parameters_accepted_annotate_text(self):
+    """Test that retry parameters are accepted by annotate_text."""
+    mock_result = types.ScoredOutput(score=1.0, output='{"extractions": []}')
+    self.mock_language_model.infer.return_value = iter([[mock_result]])
+
+    try:
+      self.annotator.annotate_text(
+          text="Test text",
+          retry_transient_errors=True,
+          max_retries=3,
+          retry_initial_delay=1.0,
+          retry_backoff_factor=2.0,
+          retry_max_delay=60.0,
+      )
+    except TypeError as e:
+      if "unexpected keyword argument" in str(e):
+        self.fail(f"Retry parameters not accepted: {e}")
+      else:
+        raise
+
+  def test_retry_parameters_default_values(self):
+    """Test that retry parameters have correct default values."""
+    documents = [data.Document(text="Test document", document_id="test_doc")]
+
+    mock_result = types.ScoredOutput(score=1.0, output='{"extractions": []}')
+    self.mock_language_model.infer.return_value = iter([[mock_result]])
+
+    try:
+      list(self.annotator.annotate_documents(documents=documents))
+    except TypeError as e:
+      if "unexpected keyword argument" in str(e):
+        self.fail(f"Default retry parameters not working: {e}")
+      else:
+        raise
+
+  def test_retry_parameters_sequential_passes(self):
+    """Test that retry parameters work with sequential passes."""
+    documents = [data.Document(text="Test document", document_id="test_doc")]
+
+    mock_result = types.ScoredOutput(score=1.0, output='{"extractions": []}')
+    self.mock_language_model.infer.return_value = iter([[mock_result]])
+
+    try:
+      list(
+          self.annotator.annotate_documents(
+              documents=documents,
+              extraction_passes=2,
+              retry_transient_errors=True,
+              max_retries=3,
+              retry_initial_delay=1.0,
+              retry_backoff_factor=2.0,
+              retry_max_delay=60.0,
+          )
+      )
+    except TypeError as e:
+      if "unexpected keyword argument" in str(e):
+        self.fail(f"Retry parameters not accepted in sequential passes: {e}")
+      else:
+        raise
+
+  def test_retry_parameters_passed_to_batch_processing(self):
+    """Test that retry parameters are passed to batch processing methods."""
+    documents = [data.Document(text="Test document", document_id="test_doc")]
+
+    mock_result = types.ScoredOutput(score=1.0, output='{"extractions": []}')
+    self.mock_language_model.infer.return_value = iter([[mock_result]])
+
+    with mock.patch.object(
+        self.annotator,
+        "_process_batch_with_retry",
+        return_value=iter([[mock_result]]),
+    ) as mock_batch_processing:
+      list(
+          self.annotator.annotate_documents(
+              documents=documents,
+              retry_transient_errors=True,
+              max_retries=5,
+              retry_initial_delay=2.0,
+              retry_backoff_factor=1.5,
+              retry_max_delay=120.0,
+          )
+      )
+
+      mock_batch_processing.assert_called_once()
+      call_kwargs = mock_batch_processing.call_args[1]
+
+      self.assertEqual(call_kwargs["retry_transient_errors"], True)
+      self.assertEqual(call_kwargs["max_retries"], 5)
+      self.assertEqual(call_kwargs["retry_initial_delay"], 2.0)
+      self.assertEqual(call_kwargs["retry_backoff_factor"], 1.5)
+      self.assertEqual(call_kwargs["retry_max_delay"], 120.0)
+
+  def test_retry_parameters_passed_to_single_chunk_processing(self):
+    """Test that retry parameters are passed to single chunk processing."""
+    documents = [data.Document(text="Test document", document_id="test_doc")]
+
+    mock_result = types.ScoredOutput(score=1.0, output='{"extractions": []}')
+
+    with mock.patch.object(
+        self.annotator,
+        "_process_single_chunk_with_retry",
+        return_value=[mock_result],
+    ) as mock_single_chunk:
+      self.mock_language_model.infer.side_effect = Exception(
+          "503 The model is overloaded"
+      )
+
+      list(
+          self.annotator.annotate_documents(
+              documents=documents,
+              retry_transient_errors=True,
+              max_retries=3,
+              retry_initial_delay=1.0,
+              retry_backoff_factor=2.0,
+              retry_max_delay=60.0,
+          )
+      )
+
+      self.assertTrue(mock_single_chunk.called)
+      call_kwargs = mock_single_chunk.call_args[1]
+      self.assertEqual(call_kwargs["retry_transient_errors"], True)
+      self.assertEqual(call_kwargs["max_retries"], 3)
+      self.assertEqual(call_kwargs["retry_initial_delay"], 1.0)
+      self.assertEqual(call_kwargs["retry_backoff_factor"], 2.0)
+      self.assertEqual(call_kwargs["retry_max_delay"], 60.0)
+
+
 if __name__ == "__main__":
   absltest.main()
diff --git a/tests/retry_utils_test.py b/tests/retry_utils_test.py
new file mode 100644
index 00000000..91aba85c
--- /dev/null
+++ b/tests/retry_utils_test.py
@@ -0,0 +1,300 @@
+# Copyright 2025 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import time
+
+from absl.testing import absltest
+
+from langextract import retry_utils
+from langextract.core import exceptions
+
+
+class RetryUtilsTest(absltest.TestCase):
+  """Test retry utility functions."""
+
+  def test_is_transient_error_503(self):
+    """Test that 503 errors are identified as transient."""
+    error = exceptions.InferenceRuntimeError("503 The model is overloaded")
+    self.assertTrue(retry_utils.is_transient_error(error))
+
+    error = exceptions.InferenceRuntimeError("Service temporarily unavailable")
+    self.assertTrue(retry_utils.is_transient_error(error))
+
+  def test_is_transient_error_429(self):
+    """Test that 429 rate limit errors are identified as transient."""
+    error = exceptions.InferenceRuntimeError("429 Too Many Requests")
+    self.assertTrue(retry_utils.is_transient_error(error))
+
+    error = exceptions.InferenceRuntimeError("Rate limit exceeded")
+    self.assertTrue(retry_utils.is_transient_error(error))
+
+  def test_is_transient_error_timeout(self):
+    """Test that timeout errors are identified as transient."""
+    error = TimeoutError("Request timed out")
+    self.assertTrue(retry_utils.is_transient_error(error))
+
+    error = exceptions.InferenceRuntimeError("Connection timeout")
+    self.assertTrue(retry_utils.is_transient_error(error))
+
+  def test_is_transient_error_non_transient(self):
+    """Test that non-transient errors are not retried."""
+    error = exceptions.InferenceConfigError("Invalid API key")
+    self.assertFalse(retry_utils.is_transient_error(error))
+
+    error = exceptions.InferenceRuntimeError("Invalid model ID")
+    self.assertFalse(retry_utils.is_transient_error(error))
+
+  def test_retry_decorator_success(self):
+    """Test that retry decorator works for successful calls."""
+
+    @retry_utils.retry_on_transient_errors(max_retries=2, initial_delay=0.01)
+    def successful_function():
+      return "success"
+
+    result = successful_function()
+    self.assertEqual(result, "success")
+
+  def test_retry_decorator_transient_error_success(self):
+    """Test that retry decorator retries on transient errors and succeeds."""
+    call_count = 0
+
+    @retry_utils.retry_on_transient_errors(max_retries=3, initial_delay=0.01)
+    def failing_then_successful_function():
+      nonlocal call_count
+      call_count += 1
+      if call_count < 3:
+        raise exceptions.InferenceRuntimeError("503 The model is overloaded")
+      return "success"
+
+    result = failing_then_successful_function()
+    self.assertEqual(result, "success")
+    self.assertEqual(call_count, 3)
+
+  def test_retry_decorator_non_transient_error(self):
+    """Test that retry decorator doesn't retry non-transient errors."""
+    call_count = 0
+
+    @retry_utils.retry_on_transient_errors(max_retries=3, initial_delay=0.01)
+    def non_transient_failing_function():
+      nonlocal call_count
+      call_count += 1
+      raise exceptions.InferenceConfigError("Invalid API key")
+
+    with self.assertRaises(exceptions.InferenceConfigError):
+      non_transient_failing_function()
+
+    self.assertEqual(call_count, 1)
+
+  def test_retry_decorator_max_retries_exceeded(self):
+    """Test that retry decorator gives up after max retries."""
+    call_count = 0
+
+    @retry_utils.retry_on_transient_errors(max_retries=2, initial_delay=0.01)
+    def always_failing_function():
+      nonlocal call_count
+      call_count += 1
+      raise exceptions.InferenceRuntimeError("503 The model is overloaded")
+
+    with self.assertRaises(exceptions.InferenceRuntimeError):
+      always_failing_function()
+
+    self.assertEqual(call_count, 3)
+
+  def test_retry_chunk_processing_disabled(self):
+    """Test that retry can be disabled."""
+    call_count = 0
+
+    @retry_utils.retry_chunk_processing(enabled=False)
+    def failing_function():
+      nonlocal call_count
+      call_count += 1
+      raise exceptions.InferenceRuntimeError("503 The model is overloaded")
+
+    with self.assertRaises(exceptions.InferenceRuntimeError):
+      failing_function()
+
+    self.assertEqual(call_count, 1)
+
+  def test_retry_chunk_processing_enabled(self):
+    """Test that retry works when enabled."""
+    call_count = 0
+
+    @retry_utils.retry_chunk_processing(max_retries=2, initial_delay=0.01)
+    def failing_then_successful_function():
+      nonlocal call_count
+      call_count += 1
+      if call_count < 3:
+        raise exceptions.InferenceRuntimeError("503 The model is overloaded")
+      return "success"
+
+    result = failing_then_successful_function()
+    self.assertEqual(result, "success")
+    self.assertEqual(call_count, 3)
+
+  def test_retry_backoff_timing(self):
+    """Test that retry uses exponential backoff."""
+    call_times = []
+
+    @retry_utils.retry_on_transient_errors(
+        max_retries=2, initial_delay=0.1, backoff_factor=2.0, max_delay=1.0
+    )
+    def timing_test_function():
+      call_times.append(time.time())
+      if len(call_times) < 3:
+        raise exceptions.InferenceRuntimeError("503 The model is overloaded")
+      return "success"
+
+    result = timing_test_function()
+
+    self.assertEqual(result, "success")
+    self.assertEqual(len(call_times), 3)
+
+    if len(call_times) >= 3:
+      delay1 = call_times[1] - call_times[0]
+      delay2 = call_times[2] - call_times[1]
+
+      self.assertGreater(delay1, 0.05)
+      self.assertGreater(delay2, 0.15)
+      self.assertGreater(delay2, delay1)
+
+  def test_retry_with_jitter(self):
+    """Test that retry adds jitter to prevent thundering herd."""
+    call_times = []
+
+    @retry_utils.retry_on_transient_errors(
+        max_retries=2, initial_delay=0.1, jitter=True
+    )
+    def jitter_test_function():
+      call_times.append(time.time())
+      if len(call_times) < 3:
+        raise exceptions.InferenceRuntimeError("503 The model is overloaded")
+      return "success"
+
+    result = jitter_test_function()
+    self.assertEqual(result, "success")
+    self.assertEqual(len(call_times), 3)
+
+    if len(call_times) >= 2:
+      delay1 = call_times[1] - call_times[0]
+      self.assertGreater(delay1, 0.05)
+      self.assertLess(delay1, 0.15)
+
+  def test_retry_max_delay_cap(self):
+    """Test that retry respects max_delay cap."""
+    call_times = []
+
+    @retry_utils.retry_on_transient_errors(
+        max_retries=2, initial_delay=0.1, backoff_factor=10.0, max_delay=0.2
+    )
+    def max_delay_test_function():
+      call_times.append(time.time())
+      if len(call_times) < 3:
+        raise exceptions.InferenceRuntimeError("503 The model is overloaded")
+      return "success"
+
+    result = max_delay_test_function()
+    self.assertEqual(result, "success")
+    self.assertEqual(len(call_times), 3)
+
+    if len(call_times) >= 2:
+      delay1 = call_times[1] - call_times[0]
+      delay2 = call_times[2] - call_times[1]
+
+      self.assertLess(delay1, 0.3)
+      self.assertLess(delay2, 0.3)
+
+  def test_error_message_detection(self):
+    """Test that various error messages are properly detected."""
+    error_messages_503 = [
+        "503 The model is overloaded",
+        "503 Service Unavailable",
+        "The model is overloaded",
+        "Service temporarily unavailable",
+    ]
+
+    for msg in error_messages_503:
+      error = exceptions.InferenceRuntimeError(msg)
+      self.assertTrue(
+          retry_utils.is_transient_error(error),
+          f"Error message '{msg}' should be transient",
+      )
+
+    error_messages_429 = [
+        "429 Too Many Requests",
+        "Rate limit exceeded",
+        "Too many requests",
+    ]
+
+    for msg in error_messages_429:
+      error = exceptions.InferenceRuntimeError(msg)
+      self.assertTrue(
+          retry_utils.is_transient_error(error),
+          f"Error message '{msg}' should be transient",
+      )
+
+    error_messages_timeout = [
+        "Request timed out",
+        "Connection timeout",
+        "Timeout",
+        "Deadline exceeded",
+    ]
+
+    for msg in error_messages_timeout:
+      error = exceptions.InferenceRuntimeError(msg)
+      self.assertTrue(
+          retry_utils.is_transient_error(error),
+          f"Error message '{msg}' should be transient",
+      )
+
+    error_messages_non_transient = [
+        "Invalid API key",
+        "Invalid model ID",
+        "Authentication failed",
+    ]
+
+    for msg in error_messages_non_transient:
+      error = exceptions.InferenceRuntimeError(msg)
+      self.assertFalse(
+          retry_utils.is_transient_error(error),
+          f"Error message '{msg}' should not be transient",
+      )
+
+  def test_retry_decorator_preserves_function_metadata(self):
+    """Test that retry decorator preserves function metadata."""
+
+    @retry_utils.retry_on_transient_errors(max_retries=2)
+    def test_function():
+      """Test function docstring."""
+      return "test"
+
+    self.assertEqual(test_function.__name__, "test_function")
+    self.assertEqual(test_function.__doc__, "Test function docstring.")
+
+  def test_retry_chunk_processing_preserves_function_metadata(self):
+    """Test that chunk retry decorator preserves function metadata."""
+
+    @retry_utils.retry_chunk_processing(max_retries=2)
+    def test_chunk_function():
+      """Test chunk function docstring."""
+      return "test"
+
+    self.assertEqual(test_chunk_function.__name__, "test_chunk_function")
+    self.assertEqual(
+        test_chunk_function.__doc__, "Test chunk function docstring."
+    )
+
+
+if __name__ == "__main__":
+  absltest.main()
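End to end, the new parameters surface on lx.extract. A hedged sketch of a caller opting into a more patient retry policy; the prompt, examples, and model id are illustrative only:

    import langextract as lx

    result = lx.extract(
        text_or_documents="Patient was given 250 mg amoxicillin.",
        prompt_description="Extract medication mentions.",
        examples=[...],  # few-shot examples, unchanged from the existing API
        model_id="gemini-2.5-flash",  # illustrative model id
        # New in this change: transient 503/429/timeout failures back off and
        # retry instead of failing the whole run.
        retry_transient_errors=True,
        max_retries=5,
        retry_initial_delay=1.0,
        retry_backoff_factor=2.0,
        retry_max_delay=60.0,
    )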