
Commit fe18abe

aksg87 and vayoa authored
fix: streamline annotation layer with lazy streaming (#276)
* test: add regression test for generator delegation bug

  Verify annotate_documents uses 'yield from' to properly delegate to
  generators, ensuring correct document attribution across batches.

  Co-authored-by: Vayoa <[email protected]>

* fix: streamline annotation layer with lazy streaming

  Stream documents lazily and emit incrementally to reduce memory from
  O(documents) to O(batch_size). Improve code clarity with better naming
  (keep_last_doc, _emit_docs_iter) and removed verbose comments.

* test: fix progress bar mock to pass through batches

  The test was not passing through the real batches, which prevented the
  lazy document capture from running. Updated mock to use side_effect to
  pass through the iterable while still allowing inspection of call args.

---------

Co-authored-by: Vayoa <[email protected]>
1 parent bd1e3d2 commit fe18abe
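
The first message above describes the regression being guarded against: a generator-based pipeline that calls a helper generator without 'yield from' never actually consumes it, so results are silently dropped or mis-attributed across batches. A minimal illustrative sketch of that failure mode (hypothetical names, not LangExtract code):

from collections.abc import Iterator


def emit_completed(buffer: list[str]) -> Iterator[str]:
  """Inner generator that drains whatever has been buffered so far."""
  while buffer:
    yield buffer.pop(0)


def annotate_buggy(docs: list[str]) -> Iterator[str]:
  """Bug: the inner generator is created but never consumed."""
  buffer = list(docs)
  emit_completed(buffer)  # No 'yield from', so nothing buffered is emitted.
  yield "done"


def annotate_fixed(docs: list[str]) -> Iterator[str]:
  """Fix: 'yield from' delegates, so buffered items flow to the caller."""
  buffer = list(docs)
  yield from emit_completed(buffer)
  yield "done"


assert list(annotate_buggy(["doc_a", "doc_b"])) == ["done"]
assert list(annotate_fixed(["doc_a", "doc_b"])) == ["doc_a", "doc_b", "done"]

The lazy-streaming fix in annotation.py below relies on exactly this delegation: completed documents are yielded from a helper generator after each batch instead of being accumulated for the whole input.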

File tree

4 files changed: +229, -122 lines changed


langextract/annotation.py

Lines changed: 123 additions & 114 deletions
@@ -25,9 +25,10 @@
 
 from __future__ import annotations
 
+import collections
 from collections.abc import Iterable, Iterator
-import itertools
 import time
+from typing import DefaultDict
 
 from absl import logging
 
@@ -41,10 +42,6 @@
 from langextract.core import format_handler as fh
 
 
-class DocumentRepeatError(exceptions.LangExtractError):
-  """Exception raised when identical document ids are present."""
-
-
 def _merge_non_overlapping_extractions(
     all_extractions: list[Iterable[data.Extraction]],
 ) -> list[data.Extraction]:
@@ -134,7 +131,7 @@ def _document_chunk_iterator(
     TextChunk containing document ID for a corresponding document.
 
   Raises:
-    DocumentRepeatError: If restrict_repeats is True and the same document ID
+    InvalidDocumentError: If restrict_repeats is True and the same document ID
       is visited more than once. Valid documents prior to the error will be
       returned.
   """
@@ -143,7 +140,7 @@ def _document_chunk_iterator(
     tokenized_text = document.tokenized_text
     document_id = document.document_id
     if restrict_repeats and document_id in visited_ids:
-      raise DocumentRepeatError(
+      raise exceptions.InvalidDocumentError(
          f"Document id {document_id} is already visited."
      )
    chunk_iter = chunking.ChunkIterator(
@@ -277,138 +274,143 @@ def _annotate_documents_single_pass(
       show_progress: bool = True,
       **kwargs,
   ) -> Iterator[data.AnnotatedDocument]:
-    """Single-pass annotation logic (original implementation)."""
+    """Single-pass annotation with stable ordering and streaming emission.
 
-    logging.info("Starting document annotation.")
-    doc_iter, doc_iter_for_chunks = itertools.tee(documents, 2)
-    curr_document = next(doc_iter, None)
-    if curr_document is None:
-      logging.warning("No documents to process.")
-      return
-
-    annotated_extractions: list[data.Extraction] = []
-    chunk_iter = _document_chunk_iterator(doc_iter_for_chunks, max_char_buffer)
+    Streams input without full materialization, maintains correct attribution
+    across batches, and emits completed documents immediately to minimize
+    peak memory usage. Handles generators from both infer() and align().
+    """
+    doc_order: list[str] = []
+    doc_text_by_id: dict[str, str] = {}
+    per_doc: DefaultDict[str, list[data.Extraction]] = collections.defaultdict(
+        list
+    )
+    next_emit_idx = 0
+
+    def _capture_docs(src: Iterable[data.Document]) -> Iterator[data.Document]:
+      """Captures document order and text lazily as chunks are produced."""
+      for document in src:
+        document_id = document.document_id
+        if document_id in doc_text_by_id:
+          raise exceptions.InvalidDocumentError(
+              f"Duplicate document_id: {document_id}"
+          )
+        doc_order.append(document_id)
+        doc_text_by_id[document_id] = document.text or ""
+        yield document
+
+    def _emit_docs_iter(
+        keep_last_doc: bool,
+    ) -> Iterator[data.AnnotatedDocument]:
+      """Yields documents that are guaranteed complete.
+
+      Args:
+        keep_last_doc: If True, retains the most recently started document
+          for additional extractions. If False, emits all remaining documents.
+      """
+      nonlocal next_emit_idx
+      limit = max(0, len(doc_order) - 1) if keep_last_doc else len(doc_order)
+      while next_emit_idx < limit:
+        document_id = doc_order[next_emit_idx]
+        yield data.AnnotatedDocument(
+            document_id=document_id,
+            extractions=per_doc.get(document_id, []),
+            text=doc_text_by_id.get(document_id, ""),
+        )
+        per_doc.pop(document_id, None)
+        doc_text_by_id.pop(document_id, None)
+        next_emit_idx += 1
 
+    chunk_iter = _document_chunk_iterator(
+        _capture_docs(documents), max_char_buffer
+    )
     batches = chunking.make_batches_of_textchunk(chunk_iter, batch_length)
 
     model_info = progress.get_model_info(self._language_model)
-
-    progress_bar = progress.create_extraction_progress_bar(
+    batch_iter = progress.create_extraction_progress_bar(
         batches, model_info=model_info, disable=not show_progress
     )
 
     chars_processed = 0
 
-    for index, batch in enumerate(progress_bar):
-      logging.info("Processing batch %d with length %d", index, len(batch))
+    try:
+      for batch in batch_iter:
+        if not batch:
+          continue
 
-      batch_prompts: list[str] = []
-      for text_chunk in batch:
-        batch_prompts.append(
+        prompts = [
            self._prompt_generator.render(
                question=text_chunk.chunk_text,
                additional_context=text_chunk.additional_context,
            )
-        )
-
-      # Show what we're currently processing
-      if debug and progress_bar:
-        batch_size = sum(len(chunk.chunk_text) for chunk in batch)
-        desc = progress.format_extraction_progress(
-            model_info,
-            current_chars=batch_size,
-            processed_chars=chars_processed,
-        )
-        progress_bar.set_description(desc)
-
-      batch_scored_outputs = self._language_model.infer(
-          batch_prompts=batch_prompts,
-          **kwargs,
-      )
+            for text_chunk in batch
+        ]
 
-      # Update total processed
-      if debug:
-        for chunk in batch:
-          if chunk.document_text:
-            char_interval = chunk.char_interval
-            chars_processed += char_interval.end_pos - char_interval.start_pos
-
-      # Update progress bar with final processed count
-      if progress_bar:
-        batch_size = sum(len(chunk.chunk_text) for chunk in batch)
-        desc = progress.format_extraction_progress(
-            model_info,
-            current_chars=batch_size,
-            processed_chars=chars_processed,
+        if show_progress:
+          current_chars = sum(
+              len(text_chunk.chunk_text) for text_chunk in batch
          )
-        progress_bar.set_description(desc)
+          try:
+            batch_iter.set_description(
+                progress.format_extraction_progress(
+                    model_info,
+                    current_chars=current_chars,
+                    processed_chars=chars_processed,
+                )
+            )
+          except AttributeError:
+            pass
+
+        outputs = self._language_model.infer(batch_prompts=prompts, **kwargs)
+        if not isinstance(outputs, list):
+          outputs = list(outputs)
+
+        for text_chunk, scored_outputs in zip(batch, outputs):
+          if not isinstance(scored_outputs, list):
+            scored_outputs = list(scored_outputs)
+          if not scored_outputs:
+            raise exceptions.InferenceOutputError(
+                "No scored outputs from language model."
+            )
 
-      for text_chunk, scored_outputs in zip(batch, batch_scored_outputs):
-        logging.debug("Processing chunk: %s", text_chunk)
-        if not scored_outputs:
-          logging.error(
-              "No scored outputs for chunk with ID %s.", text_chunk.document_id
-          )
-          raise exceptions.InferenceOutputError(
-              "No scored outputs from language model."
+          resolved_extractions = resolver.resolve(
+              scored_outputs[0].output, debug=debug, **kwargs
          )
-        while curr_document.document_id != text_chunk.document_id:
-          logging.info(
-              "Completing annotation for document ID %s.",
-              curr_document.document_id,
+
+          token_offset = (
+              text_chunk.token_interval.start_index
+              if text_chunk.token_interval
+              else 0
          )
-          annotated_doc = data.AnnotatedDocument(
-              document_id=curr_document.document_id,
-              extractions=annotated_extractions,
-              text=curr_document.text,
+          char_offset = (
+              text_chunk.char_interval.start_pos
+              if text_chunk.char_interval
+              else 0
          )
-          yield annotated_doc
-          annotated_extractions.clear()
 
-          curr_document = next(doc_iter, None)
-          assert curr_document is not None, (
-              f"Document should be defined for {text_chunk} per"
-              " _document_chunk_iterator(...) specifications."
+          aligned_extractions = resolver.align(
+              resolved_extractions,
+              text_chunk.chunk_text,
+              token_offset,
+              char_offset,
+              **kwargs,
          )
 
-        top_inference_result = scored_outputs[0].output
-        logging.debug("Top inference result: %s", top_inference_result)
+          for extraction in aligned_extractions:
+            per_doc[text_chunk.document_id].append(extraction)
 
-        annotated_chunk_extractions = resolver.resolve(
-            top_inference_result, debug=debug, **kwargs
-        )
-        chunk_text = text_chunk.chunk_text
-        token_offset = text_chunk.token_interval.start_index
-        char_offset = text_chunk.char_interval.start_pos
-
-        aligned_extractions = resolver.align(
-            annotated_chunk_extractions,
-            chunk_text,
-            token_offset,
-            char_offset,
-            **kwargs,
-        )
-
-        annotated_extractions.extend(aligned_extractions)
-
-    progress_bar.close()
-
-    if debug:
-      progress.print_extraction_complete()
+          if show_progress and text_chunk.char_interval is not None:
+            chars_processed += (
+                text_chunk.char_interval.end_pos
+                - text_chunk.char_interval.start_pos
+            )
 
-    if curr_document is not None:
-      logging.info(
-          "Finalizing annotation for document ID %s.", curr_document.document_id
-      )
-      annotated_doc = data.AnnotatedDocument(
-          document_id=curr_document.document_id,
-          extractions=annotated_extractions,
-          text=curr_document.text,
-      )
+        yield from _emit_docs_iter(keep_last_doc=True)
 
-      yield annotated_doc
+    finally:
+      batch_iter.close()
 
-    logging.info("Document annotation completed.")
+    yield from _emit_docs_iter(keep_last_doc=False)
 
   def _annotate_documents_sequential_passes(
       self,
@@ -433,6 +435,10 @@ def _annotate_documents_sequential_passes(
 
     document_extractions_by_pass: dict[str, list[list[data.Extraction]]] = {}
     document_texts: dict[str, str] = {}
+    # Preserve text up-front so we can emit documents even if later passes
+    # produce no extractions.
+    for _doc in document_list:
+      document_texts[_doc.document_id] = _doc.text or ""
 
     for pass_num in range(extraction_passes):
       logging.info(
@@ -452,13 +458,16 @@ def _annotate_documents_sequential_passes(
 
         if doc_id not in document_extractions_by_pass:
           document_extractions_by_pass[doc_id] = []
-          document_texts[doc_id] = annotated_doc.text or ""
+          # Keep first-seen text (already pre-filled above).
 
         document_extractions_by_pass[doc_id].append(
            annotated_doc.extractions or []
        )
 
-    for doc_id, all_pass_extractions in document_extractions_by_pass.items():
+    # Emit results strictly in original input order.
+    for doc in document_list:
+      doc_id = doc.document_id
+      all_pass_extractions = document_extractions_by_pass.get(doc_id, [])
       merged_extractions = _merge_non_overlapping_extractions(
          all_pass_extractions
      )
@@ -479,7 +488,7 @@ def _annotate_documents_sequential_passes(
       yield data.AnnotatedDocument(
          document_id=doc_id,
          extractions=merged_extractions,
-          text=document_texts[doc_id],
+          text=document_texts.get(doc_id, doc.text or ""),
      )
 
    logging.info("Sequential extraction passes completed.")

langextract/core/exceptions.py

Lines changed: 16 additions & 0 deletions
@@ -26,6 +26,8 @@
     "InferenceConfigError",
     "InferenceRuntimeError",
     "InferenceOutputError",
+    "InternalError",
+    "InvalidDocumentError",
     "ProviderError",
     "SchemaError",
     "FormatError",
@@ -88,6 +90,20 @@ def __init__(self, message: str):
     super().__init__(self.message)
 
 
+class InvalidDocumentError(LangExtractError):
+  """Exception raised when document input is invalid.
+
+  This includes cases like duplicate document IDs or malformed documents.
+  """
+
+
+class InternalError(LangExtractError):
+  """Exception raised for internal invariant violations.
+
+  This indicates a bug in LangExtract itself rather than user error.
+  """
+
+
 class ProviderError(LangExtractError):
   """Provider/backend specific error."""
 
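Since DocumentRepeatError is removed in this commit, callers that handled duplicate-document input should catch the new InvalidDocumentError (or the shared LangExtractError base) instead. A small hedged usage sketch; the check_unique_ids helper below is a stand-in for illustration, not part of the library:

from langextract.core import exceptions


def check_unique_ids(doc_ids: list[str]) -> None:
  """Raises InvalidDocumentError on duplicate IDs, mirroring the new check."""
  seen: set[str] = set()
  for doc_id in doc_ids:
    if doc_id in seen:
      raise exceptions.InvalidDocumentError(f"Duplicate document_id: {doc_id}")
    seen.add(doc_id)


try:
  check_unique_ids(["doc_1", "doc_2", "doc_1"])
except exceptions.LangExtractError as err:
  print(f"Rejected input: {err}")  # Message from the duplicated "doc_1".

InternalError, by contrast, is reserved for invariant violations inside LangExtract itself rather than errors in user-supplied input.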