diff --git a/langextract/annotation.py b/langextract/annotation.py
index 4f0ec7bd..cdd149df 100644
--- a/langextract/annotation.py
+++ b/langextract/annotation.py
@@ -25,9 +25,10 @@
 from __future__ import annotations
 
+import collections
 from collections.abc import Iterable, Iterator
-import itertools
 import time
+from typing import DefaultDict
 
 from absl import logging
 
@@ -41,10 +42,6 @@
 from langextract.core import format_handler as fh
 
 
-class DocumentRepeatError(exceptions.LangExtractError):
-  """Exception raised when identical document ids are present."""
-
-
 def _merge_non_overlapping_extractions(
     all_extractions: list[Iterable[data.Extraction]],
 ) -> list[data.Extraction]:
@@ -134,7 +131,7 @@
     TextChunk containing document ID for a corresponding document.
 
   Raises:
-    DocumentRepeatError: If restrict_repeats is True and the same document ID
+    InvalidDocumentError: If restrict_repeats is True and the same document ID
      is visited more than once. Valid documents prior to the error will be
      returned.
  """
@@ -143,7 +140,7 @@
    tokenized_text = document.tokenized_text
    document_id = document.document_id
    if restrict_repeats and document_id in visited_ids:
-      raise DocumentRepeatError(
+      raise exceptions.InvalidDocumentError(
          f"Document id {document_id} is already visited."
      )
    chunk_iter = chunking.ChunkIterator(
@@ -277,138 +274,143 @@
      show_progress: bool = True,
      **kwargs,
  ) -> Iterator[data.AnnotatedDocument]:
-    """Single-pass annotation logic (original implementation)."""
-
-    logging.info("Starting document annotation.")
-    doc_iter, doc_iter_for_chunks = itertools.tee(documents, 2)
-    curr_document = next(doc_iter, None)
-    if curr_document is None:
-      logging.warning("No documents to process.")
-      return
-
-    annotated_extractions: list[data.Extraction] = []
-    chunk_iter = _document_chunk_iterator(doc_iter_for_chunks, max_char_buffer)
+    """Single-pass annotation with stable ordering and streaming emission.
+
+    Streams input without full materialization, maintains correct attribution
+    across batches, and emits completed documents immediately to minimize
+    peak memory usage. Handles generators from both infer() and align().
+    """
+    doc_order: list[str] = []
+    doc_text_by_id: dict[str, str] = {}
+    per_doc: DefaultDict[str, list[data.Extraction]] = collections.defaultdict(
+        list
+    )
+    next_emit_idx = 0
+
+    def _capture_docs(src: Iterable[data.Document]) -> Iterator[data.Document]:
+      """Captures document order and text lazily as chunks are produced."""
+      for document in src:
+        document_id = document.document_id
+        if document_id in doc_text_by_id:
+          raise exceptions.InvalidDocumentError(
+              f"Duplicate document_id: {document_id}"
+          )
+        doc_order.append(document_id)
+        doc_text_by_id[document_id] = document.text or ""
+        yield document
+
+    def _emit_docs_iter(
+        keep_last_doc: bool,
+    ) -> Iterator[data.AnnotatedDocument]:
+      """Yields documents that are guaranteed complete.
+
+      Args:
+        keep_last_doc: If True, retains the most recently started document
+          for additional extractions. If False, emits all remaining documents.
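+
+      Yields:
+        data.AnnotatedDocument for each document whose extractions are
+        complete, in input order.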
+      """
+      nonlocal next_emit_idx
+      limit = max(0, len(doc_order) - 1) if keep_last_doc else len(doc_order)
+      while next_emit_idx < limit:
+        document_id = doc_order[next_emit_idx]
+        yield data.AnnotatedDocument(
+            document_id=document_id,
+            extractions=per_doc.get(document_id, []),
+            text=doc_text_by_id.get(document_id, ""),
+        )
+        per_doc.pop(document_id, None)
+        doc_text_by_id.pop(document_id, None)
+        next_emit_idx += 1
+
+    chunk_iter = _document_chunk_iterator(
+        _capture_docs(documents), max_char_buffer
+    )
    batches = chunking.make_batches_of_textchunk(chunk_iter, batch_length)
 
    model_info = progress.get_model_info(self._language_model)
-
-    progress_bar = progress.create_extraction_progress_bar(
+    batch_iter = progress.create_extraction_progress_bar(
        batches, model_info=model_info, disable=not show_progress
    )
 
    chars_processed = 0
-    for index, batch in enumerate(progress_bar):
-      logging.info("Processing batch %d with length %d", index, len(batch))
+    try:
+      for batch in batch_iter:
+        if not batch:
+          continue
 
-      batch_prompts: list[str] = []
-      for text_chunk in batch:
-        batch_prompts.append(
+        prompts = [
            self._prompt_generator.render(
                question=text_chunk.chunk_text,
                additional_context=text_chunk.additional_context,
            )
-        )
-
-      # Show what we're currently processing
-      if debug and progress_bar:
-        batch_size = sum(len(chunk.chunk_text) for chunk in batch)
-        desc = progress.format_extraction_progress(
-            model_info,
-            current_chars=batch_size,
-            processed_chars=chars_processed,
-        )
-        progress_bar.set_description(desc)
-
-      batch_scored_outputs = self._language_model.infer(
-          batch_prompts=batch_prompts,
-          **kwargs,
-      )
+            for text_chunk in batch
+        ]
 
-      # Update total processed
-      if debug:
-        for chunk in batch:
-          if chunk.document_text:
-            char_interval = chunk.char_interval
-            chars_processed += char_interval.end_pos - char_interval.start_pos
-
-      # Update progress bar with final processed count
-      if progress_bar:
-        batch_size = sum(len(chunk.chunk_text) for chunk in batch)
-        desc = progress.format_extraction_progress(
-            model_info,
-            current_chars=batch_size,
-            processed_chars=chars_processed,
-        )
-        progress_bar.set_description(desc)
+        if show_progress:
+          current_chars = sum(
+              len(text_chunk.chunk_text) for text_chunk in batch
+          )
+          try:
+            batch_iter.set_description(
+                progress.format_extraction_progress(
+                    model_info,
+                    current_chars=current_chars,
+                    processed_chars=chars_processed,
+                )
+            )
+          except AttributeError:
+            pass
+
+        outputs = self._language_model.infer(batch_prompts=prompts, **kwargs)
+        if not isinstance(outputs, list):
+          outputs = list(outputs)
 
-      for text_chunk, scored_outputs in zip(batch, batch_scored_outputs):
-        logging.debug("Processing chunk: %s", text_chunk)
-        if not scored_outputs:
-          logging.error(
-              "No scored outputs for chunk with ID %s.", text_chunk.document_id
-          )
-          raise exceptions.InferenceOutputError(
-              "No scored outputs from language model."
-          )
+        for text_chunk, scored_outputs in zip(batch, outputs):
+          if not isinstance(scored_outputs, list):
+            scored_outputs = list(scored_outputs)
+          if not scored_outputs:
+            raise exceptions.InferenceOutputError(
+                "No scored outputs from language model."
+            )
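+          # Resolve the top-scored model output into extraction objects.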
+          resolved_extractions = resolver.resolve(
+              scored_outputs[0].output, debug=debug, **kwargs
+          )
+
+          token_offset = (
+              text_chunk.token_interval.start_index
+              if text_chunk.token_interval
+              else 0
+          )
+          char_offset = (
+              text_chunk.char_interval.start_pos
+              if text_chunk.char_interval
+              else 0
+          )
+          aligned_extractions = resolver.align(
+              resolved_extractions,
+              text_chunk.chunk_text,
+              token_offset,
+              char_offset,
+              **kwargs,
+          )
+
+          for extraction in aligned_extractions:
+            per_doc[text_chunk.document_id].append(extraction)
+
+          if show_progress and text_chunk.char_interval is not None:
+            chars_processed += (
+                text_chunk.char_interval.end_pos
+                - text_chunk.char_interval.start_pos
+            )
+
+        yield from _emit_docs_iter(keep_last_doc=True)
+
-        while curr_document.document_id != text_chunk.document_id:
-          logging.info(
-              "Completing annotation for document ID %s.",
-              curr_document.document_id,
-          )
-          annotated_doc = data.AnnotatedDocument(
-              document_id=curr_document.document_id,
-              extractions=annotated_extractions,
-              text=curr_document.text,
-          )
-          yield annotated_doc
-          annotated_extractions.clear()
-          curr_document = next(doc_iter, None)
-          assert curr_document is not None, (
-              f"Document should be defined for {text_chunk} per"
-              " _document_chunk_iterator(...) specifications."
-          )
-
-        top_inference_result = scored_outputs[0].output
-        logging.debug("Top inference result: %s", top_inference_result)
-
-        annotated_chunk_extractions = resolver.resolve(
-            top_inference_result, debug=debug, **kwargs
-        )
-        chunk_text = text_chunk.chunk_text
-        token_offset = text_chunk.token_interval.start_index
-        char_offset = text_chunk.char_interval.start_pos
-
-        aligned_extractions = resolver.align(
-            annotated_chunk_extractions,
-            chunk_text,
-            token_offset,
-            char_offset,
-            **kwargs,
-        )
-
-        annotated_extractions.extend(aligned_extractions)
-
-    progress_bar.close()
-
-    if debug:
-      progress.print_extraction_complete()
-
-    if curr_document is not None:
-      logging.info(
-          "Finalizing annotation for document ID %s.", curr_document.document_id
-      )
-      annotated_doc = data.AnnotatedDocument(
-          document_id=curr_document.document_id,
-          extractions=annotated_extractions,
-          text=curr_document.text,
-      )
-
-      yield annotated_doc
-
-    logging.info("Document annotation completed.")
+    finally:
+      batch_iter.close()
+
+    yield from _emit_docs_iter(keep_last_doc=False)
 
  def _annotate_documents_sequential_passes(
      self,
@@ -433,6 +435,10 @@
    document_extractions_by_pass: dict[str, list[list[data.Extraction]]] = {}
    document_texts: dict[str, str] = {}
 
+    # Preserve text up-front so we can emit documents even if later passes
+    # produce no extractions.
+    for _doc in document_list:
+      document_texts[_doc.document_id] = _doc.text or ""
 
    for pass_num in range(extraction_passes):
      logging.info(
@@ -452,13 +458,16 @@
        if doc_id not in document_extractions_by_pass:
          document_extractions_by_pass[doc_id] = []
-          document_texts[doc_id] = annotated_doc.text or ""
+          # Keep first-seen text (already pre-filled above).
 
        document_extractions_by_pass[doc_id].append(
            annotated_doc.extractions or []
        )
 
-    for doc_id, all_pass_extractions in document_extractions_by_pass.items():
+    # Emit results strictly in original input order.
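+    # Documents with no extractions in any pass still yield an (empty)
+    # AnnotatedDocument.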
+    for doc in document_list:
+      doc_id = doc.document_id
+      all_pass_extractions = document_extractions_by_pass.get(doc_id, [])
      merged_extractions = _merge_non_overlapping_extractions(
          all_pass_extractions
      )
@@ -479,7 +488,7 @@
      yield data.AnnotatedDocument(
          document_id=doc_id,
          extractions=merged_extractions,
-          text=document_texts[doc_id],
+          text=document_texts.get(doc_id, doc.text or ""),
      )
 
    logging.info("Sequential extraction passes completed.")
diff --git a/langextract/core/exceptions.py b/langextract/core/exceptions.py
index 5f639e99..422d62b1 100644
--- a/langextract/core/exceptions.py
+++ b/langextract/core/exceptions.py
@@ -26,6 +26,8 @@
    "InferenceConfigError",
    "InferenceRuntimeError",
    "InferenceOutputError",
+    "InternalError",
+    "InvalidDocumentError",
    "ProviderError",
    "SchemaError",
    "FormatError",
@@ -88,6 +90,20 @@ def __init__(self, message: str):
    super().__init__(self.message)
 
 
+class InvalidDocumentError(LangExtractError):
+  """Exception raised when document input is invalid.
+
+  This includes cases like duplicate document IDs or malformed documents.
+  """
+
+
+class InternalError(LangExtractError):
+  """Exception raised for internal invariant violations.
+
+  This indicates a bug in LangExtract itself rather than user error.
+  """
+
+
 class ProviderError(LangExtractError):
  """Provider/backend specific error."""
 
diff --git a/tests/annotation_test.py b/tests/annotation_test.py
index 78df7dd6..6cca1e7f 100644
--- a/tests/annotation_test.py
+++ b/tests/annotation_test.py
@@ -14,6 +14,7 @@
 from collections.abc import Sequence
 import dataclasses
+import inspect
 import textwrap
 from typing import Type
 from unittest import mock
 
@@ -25,6 +26,7 @@
 from langextract import prompting
 from langextract import resolver as resolver_lib
 from langextract.core import data
+from langextract.core import exceptions
 from langextract.core import tokenizer
 from langextract.core import types
 from langextract.providers import gemini
@@ -749,7 +751,7 @@ def mock_infer_side_effect(batch_prompts, **kwargs):
              {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc1"},
              {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc1"},
          ],
-          expected_exception=annotation.DocumentRepeatError,
+          expected_exception=exceptions.InvalidDocumentError,
      ),
      dict(
          testcase_name="same_document_id_separated",
          documents=[
              {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc1"},
              {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc2"},
              {"text": _FIXED_DOCUMENT_CONTENT, "document_id": "doc1"},
          ],
-          expected_exception=annotation.DocumentRepeatError,
+          expected_exception=exceptions.InvalidDocumentError,
      ),
  )
  def test_annotate_documents_exceptions(
      self,
      documents: Sequence[dict[str, str]],
-      expected_exception: Type[annotation.DocumentRepeatError],
+      expected_exception: Type[exceptions.InvalidDocumentError],
      batch_length: int = 1,
  ):
    mock_language_model = self.enter_context(
@@ -1118,5 +1120,89 @@
    self.assertEqual(result, expected)
 
 
+class AnnotateDocumentsGeneratorTest(absltest.TestCase):
+  """Tests that annotate_documents uses 'yield from' for proper delegation."""
+
+  def setUp(self):
+    super().setUp()
+    self.mock_language_model = self.enter_context(
+        mock.patch.object(gemini, "GeminiLanguageModel", autospec=True)
+    )
+
+    def mock_infer(batch_prompts, **_):
+      """Return medication extractions based on prompt content."""
+      for prompt in batch_prompts:
+        if "Ibuprofen" in prompt:
+          text = textwrap.dedent(f"""\
+              ```yaml
+              {data.EXTRACTIONS_KEY}:
+              - medication: "Ibuprofen"
+                medication_index: 4
+              ```""")
+        elif "Cefazolin" in prompt:
+          text = textwrap.dedent(f"""\
+              ```yaml
+              {data.EXTRACTIONS_KEY}:
+              - medication: "Cefazolin"
+                medication_index: 4
+              ```""")
+        else:
+          text = f"```yaml\n{data.EXTRACTIONS_KEY}: []\n```"
+        yield [types.ScoredOutput(score=1.0, output=text)]
+
+    self.mock_language_model.infer.side_effect = mock_infer
+
+    self.annotator = annotation.Annotator(
+        language_model=self.mock_language_model,
+        prompt_template=prompting.PromptTemplateStructured(description=""),
+    )
+
+  def test_yields_documents_not_generators(self):
+    """Verifies annotate_documents yields AnnotatedDocument, not generators."""
+    docs = [
+        data.Document(
+            text="Patient took 400 mg PO Ibuprofen q4h for two days.",
+            document_id="doc1",
+        ),
+        data.Document(
+            text="Patient was given 250 mg IV Cefazolin TID for one week.",
+            document_id="doc2",
+        ),
+    ]
+
+    results = list(
+        self.annotator.annotate_documents(
+            docs,
+            resolver=resolver_lib.Resolver(
+                fence_output=True,
+                format_type=data.FormatType.YAML,
+                extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,
+            ),
+            show_progress=False,
+            debug=False,
+        )
+    )
+
+    self.assertLen(results, 2)
+    self.assertFalse(
+        any(inspect.isgenerator(item) for item in results),
+        msg="Must use 'yield from' to delegate, not 'yield'",
+    )
+    meds_doc1 = {
+        e.extraction_text
+        for e in results[0].extractions
+        if e.extraction_class == "medication"
+    }
+    meds_doc2 = {
+        e.extraction_text
+        for e in results[1].extractions
+        if e.extraction_class == "medication"
+    }
+    self.assertIn("Ibuprofen", meds_doc1)
+    self.assertNotIn("Cefazolin", meds_doc1)
+    self.assertIn("Cefazolin", meds_doc2)
+    self.assertNotIn("Ibuprofen", meds_doc2)
+
+
 if __name__ == "__main__":
  absltest.main()
diff --git a/tests/init_test.py b/tests/init_test.py
index 98827b7f..d8cbdd5a 100644
--- a/tests/init_test.py
+++ b/tests/init_test.py
@@ -555,11 +555,7 @@
    mock_model.requires_fence_output = False
    mock_create_model.return_value = mock_model
 
-    mock_progress_bar = mock.MagicMock()
-    mock_progress_bar.__iter__ = mock.MagicMock(
-        return_value=iter([mock.MagicMock()])
-    )
-    mock_progress.return_value = mock_progress_bar
+    mock_progress.side_effect = lambda iterable, **kwargs: iter(iterable)
 
    mock_examples = [
        lx.data.ExampleData(