237 changes: 123 additions & 114 deletions langextract/annotation.py
@@ -25,9 +25,10 @@

from __future__ import annotations

import collections
from collections.abc import Iterable, Iterator
import itertools
import time
from typing import DefaultDict

from absl import logging

Expand All @@ -41,10 +42,6 @@
from langextract.core import format_handler as fh


class DocumentRepeatError(exceptions.LangExtractError):
"""Exception raised when identical document ids are present."""


def _merge_non_overlapping_extractions(
all_extractions: list[Iterable[data.Extraction]],
) -> list[data.Extraction]:
@@ -134,7 +131,7 @@ def _document_chunk_iterator(
TextChunk containing document ID for a corresponding document.

Raises:
DocumentRepeatError: If restrict_repeats is True and the same document ID
InvalidDocumentError: If restrict_repeats is True and the same document ID
is visited more than once. Valid documents prior to the error will be
returned.
"""
@@ -143,7 +140,7 @@
tokenized_text = document.tokenized_text
document_id = document.document_id
if restrict_repeats and document_id in visited_ids:
raise DocumentRepeatError(
raise exceptions.InvalidDocumentError(
f"Document id {document_id} is already visited."
)
chunk_iter = chunking.ChunkIterator(
@@ -277,138 +274,143 @@ def _annotate_documents_single_pass(
show_progress: bool = True,
**kwargs,
) -> Iterator[data.AnnotatedDocument]:
"""Single-pass annotation logic (original implementation)."""
"""Single-pass annotation with stable ordering and streaming emission.

logging.info("Starting document annotation.")
doc_iter, doc_iter_for_chunks = itertools.tee(documents, 2)
curr_document = next(doc_iter, None)
if curr_document is None:
logging.warning("No documents to process.")
return

annotated_extractions: list[data.Extraction] = []
chunk_iter = _document_chunk_iterator(doc_iter_for_chunks, max_char_buffer)
Streams input without full materialization, maintains correct attribution
across batches, and emits completed documents immediately to minimize
peak memory usage. Handles generators from both infer() and align().
"""
doc_order: list[str] = []
doc_text_by_id: dict[str, str] = {}
per_doc: DefaultDict[str, list[data.Extraction]] = collections.defaultdict(
list
)
next_emit_idx = 0

def _capture_docs(src: Iterable[data.Document]) -> Iterator[data.Document]:
"""Captures document order and text lazily as chunks are produced."""
for document in src:
document_id = document.document_id
if document_id in doc_text_by_id:
raise exceptions.InvalidDocumentError(
f"Duplicate document_id: {document_id}"
)
doc_order.append(document_id)
doc_text_by_id[document_id] = document.text or ""
yield document

def _emit_docs_iter(
keep_last_doc: bool,
) -> Iterator[data.AnnotatedDocument]:
"""Yields documents that are guaranteed complete.

Args:
keep_last_doc: If True, retains the most recently started document
for additional extractions. If False, emits all remaining documents.
"""
nonlocal next_emit_idx
limit = max(0, len(doc_order) - 1) if keep_last_doc else len(doc_order)
while next_emit_idx < limit:
document_id = doc_order[next_emit_idx]
yield data.AnnotatedDocument(
document_id=document_id,
extractions=per_doc.get(document_id, []),
text=doc_text_by_id.get(document_id, ""),
)
per_doc.pop(document_id, None)
doc_text_by_id.pop(document_id, None)
next_emit_idx += 1

chunk_iter = _document_chunk_iterator(
_capture_docs(documents), max_char_buffer
)
batches = chunking.make_batches_of_textchunk(chunk_iter, batch_length)

model_info = progress.get_model_info(self._language_model)

progress_bar = progress.create_extraction_progress_bar(
batch_iter = progress.create_extraction_progress_bar(
batches, model_info=model_info, disable=not show_progress
)

chars_processed = 0

for index, batch in enumerate(progress_bar):
logging.info("Processing batch %d with length %d", index, len(batch))
try:
for batch in batch_iter:
if not batch:
continue

batch_prompts: list[str] = []
for text_chunk in batch:
batch_prompts.append(
prompts = [
self._prompt_generator.render(
question=text_chunk.chunk_text,
additional_context=text_chunk.additional_context,
)
)

# Show what we're currently processing
if debug and progress_bar:
batch_size = sum(len(chunk.chunk_text) for chunk in batch)
desc = progress.format_extraction_progress(
model_info,
current_chars=batch_size,
processed_chars=chars_processed,
)
progress_bar.set_description(desc)

batch_scored_outputs = self._language_model.infer(
batch_prompts=batch_prompts,
**kwargs,
)
for text_chunk in batch
]

# Update total processed
if debug:
for chunk in batch:
if chunk.document_text:
char_interval = chunk.char_interval
chars_processed += char_interval.end_pos - char_interval.start_pos

# Update progress bar with final processed count
if progress_bar:
batch_size = sum(len(chunk.chunk_text) for chunk in batch)
desc = progress.format_extraction_progress(
model_info,
current_chars=batch_size,
processed_chars=chars_processed,
if show_progress:
current_chars = sum(
len(text_chunk.chunk_text) for text_chunk in batch
)
progress_bar.set_description(desc)
try:
batch_iter.set_description(
progress.format_extraction_progress(
model_info,
current_chars=current_chars,
processed_chars=chars_processed,
)
)
except AttributeError:
pass

outputs = self._language_model.infer(batch_prompts=prompts, **kwargs)
if not isinstance(outputs, list):
outputs = list(outputs)

for text_chunk, scored_outputs in zip(batch, outputs):
if not isinstance(scored_outputs, list):
scored_outputs = list(scored_outputs)
if not scored_outputs:
raise exceptions.InferenceOutputError(
"No scored outputs from language model."
)

for text_chunk, scored_outputs in zip(batch, batch_scored_outputs):
logging.debug("Processing chunk: %s", text_chunk)
if not scored_outputs:
logging.error(
"No scored outputs for chunk with ID %s.", text_chunk.document_id
)
raise exceptions.InferenceOutputError(
"No scored outputs from language model."
resolved_extractions = resolver.resolve(
scored_outputs[0].output, debug=debug, **kwargs
)
while curr_document.document_id != text_chunk.document_id:
logging.info(
"Completing annotation for document ID %s.",
curr_document.document_id,

token_offset = (
text_chunk.token_interval.start_index
if text_chunk.token_interval
else 0
)
annotated_doc = data.AnnotatedDocument(
document_id=curr_document.document_id,
extractions=annotated_extractions,
text=curr_document.text,
char_offset = (
text_chunk.char_interval.start_pos
if text_chunk.char_interval
else 0
)
yield annotated_doc
annotated_extractions.clear()

curr_document = next(doc_iter, None)
assert curr_document is not None, (
f"Document should be defined for {text_chunk} per"
" _document_chunk_iterator(...) specifications."
aligned_extractions = resolver.align(
resolved_extractions,
text_chunk.chunk_text,
token_offset,
char_offset,
**kwargs,
)

top_inference_result = scored_outputs[0].output
logging.debug("Top inference result: %s", top_inference_result)
for extraction in aligned_extractions:
per_doc[text_chunk.document_id].append(extraction)

annotated_chunk_extractions = resolver.resolve(
top_inference_result, debug=debug, **kwargs
)
chunk_text = text_chunk.chunk_text
token_offset = text_chunk.token_interval.start_index
char_offset = text_chunk.char_interval.start_pos

aligned_extractions = resolver.align(
annotated_chunk_extractions,
chunk_text,
token_offset,
char_offset,
**kwargs,
)

annotated_extractions.extend(aligned_extractions)

progress_bar.close()

if debug:
progress.print_extraction_complete()
if show_progress and text_chunk.char_interval is not None:
chars_processed += (
text_chunk.char_interval.end_pos
- text_chunk.char_interval.start_pos
)

if curr_document is not None:
logging.info(
"Finalizing annotation for document ID %s.", curr_document.document_id
)
annotated_doc = data.AnnotatedDocument(
document_id=curr_document.document_id,
extractions=annotated_extractions,
text=curr_document.text,
)
yield from _emit_docs_iter(keep_last_doc=True)

yield annotated_doc
finally:
batch_iter.close()

logging.info("Document annotation completed.")
yield from _emit_docs_iter(keep_last_doc=False)

def _annotate_documents_sequential_passes(
self,
@@ -433,6 +435,10 @@ def _annotate_documents_sequential_passes(

document_extractions_by_pass: dict[str, list[list[data.Extraction]]] = {}
document_texts: dict[str, str] = {}
# Preserve text up-front so we can emit documents even if later passes
# produce no extractions.
for _doc in document_list:
document_texts[_doc.document_id] = _doc.text or ""

for pass_num in range(extraction_passes):
logging.info(
@@ -452,13 +458,16 @@

if doc_id not in document_extractions_by_pass:
document_extractions_by_pass[doc_id] = []
document_texts[doc_id] = annotated_doc.text or ""
# Keep first-seen text (already pre-filled above).

document_extractions_by_pass[doc_id].append(
annotated_doc.extractions or []
)

for doc_id, all_pass_extractions in document_extractions_by_pass.items():
# Emit results strictly in original input order.
for doc in document_list:
doc_id = doc.document_id
all_pass_extractions = document_extractions_by_pass.get(doc_id, [])
merged_extractions = _merge_non_overlapping_extractions(
all_pass_extractions
)
@@ -479,7 +488,7 @@
yield data.AnnotatedDocument(
document_id=doc_id,
extractions=merged_extractions,
text=document_texts[doc_id],
text=document_texts.get(doc_id, doc.text or ""),
)

logging.info("Sequential extraction passes completed.")
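
The rewritten single-pass path above buffers extractions per document ID and emits every document except the most recently started one as soon as a new document appears in the chunk stream, so peak memory stays proportional to a single in-flight document. A minimal, self-contained sketch of that buffering pattern (hypothetical names, not the langextract API; it assumes each document's chunks arrive contiguously, which the duplicate-ID check enforces) could look like this:

from collections import defaultdict
from collections.abc import Iterator


def stream_grouped(
    chunk_results: Iterator[tuple[str, str]],
) -> Iterator[tuple[str, list[str]]]:
  """Yields (doc_id, results) as soon as the stream moves past a document."""
  order: list[str] = []
  per_doc: defaultdict[str, list[str]] = defaultdict(list)
  next_emit = 0
  for doc_id, result in chunk_results:
    if doc_id not in per_doc:
      order.append(doc_id)
    per_doc[doc_id].append(result)
    # Every document before the most recently started one is complete.
    while next_emit < len(order) - 1:
      done = order[next_emit]
      yield done, per_doc.pop(done)
      next_emit += 1
  # Flush the remaining documents once the chunk stream is exhausted.
  while next_emit < len(order):
    done = order[next_emit]
    yield done, per_doc.pop(done)
    next_emit += 1

Consumed as a generator, this lets a caller start writing out the first annotated document before later ones have even been chunked.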
16 changes: 16 additions & 0 deletions langextract/core/exceptions.py
@@ -26,6 +26,8 @@
"InferenceConfigError",
"InferenceRuntimeError",
"InferenceOutputError",
"InternalError",
"InvalidDocumentError",
"ProviderError",
"SchemaError",
"FormatError",
@@ -88,6 +90,20 @@ def __init__(self, message: str):
super().__init__(self.message)


class InvalidDocumentError(LangExtractError):
"""Exception raised when document input is invalid.

This includes cases like duplicate document IDs or malformed documents.
"""


class InternalError(LangExtractError):
"""Exception raised for internal invariant violations.

This indicates a bug in LangExtract itself rather than user error.
"""


class ProviderError(LangExtractError):
"""Provider/backend specific error."""

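
With InvalidDocumentError and InternalError now exported from langextract.core.exceptions, callers can separate user-fixable input problems from library-internal invariant violations. A hedged caller-side sketch (the annotate_documents call is illustrative, not the exact public entry point):

from langextract.core import exceptions

try:
  annotated = list(annotate_documents(docs))  # hypothetical entry point
except exceptions.InvalidDocumentError as err:
  # User-fixable input problem, e.g. two documents sharing a document_id.
  print(f"Bad input: {err}")
except exceptions.InternalError:
  # Invariant violation inside LangExtract itself; re-raise and report a bug.
  raise
except exceptions.LangExtractError as err:
  # Any other library-defined failure (provider, schema, inference, format).
  print(f"Extraction failed: {err}")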