 
 from __future__ import annotations
 
+import collections
 from collections.abc import Iterable, Iterator
-import itertools
 import time
+from typing import DefaultDict
 
 from absl import logging
 
 from langextract.core import format_handler as fh
 
 
-class DocumentRepeatError(exceptions.LangExtractError):
-  """Exception raised when identical document ids are present."""
-
-
 def _merge_non_overlapping_extractions(
     all_extractions: list[Iterable[data.Extraction]],
 ) -> list[data.Extraction]:
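For orientation, `_merge_non_overlapping_extractions` (its body is unchanged and outside this diff) merges per-pass extraction lists. A minimal sketch of the likely contract, assuming earlier passes win and overlap is judged on a populated `char_interval`; this is an illustration, not the function's actual body:

```python
def _merge_sketch(all_extractions):
  # Accept an extraction only if its character span does not overlap
  # any span accepted from an earlier (or the same) pass.
  merged, spans = [], []
  for pass_extractions in all_extractions:
    for extraction in pass_extractions:
      start = extraction.char_interval.start_pos
      end = extraction.char_interval.end_pos
      if all(end <= s or start >= e for s, e in spans):
        merged.append(extraction)
        spans.append((start, end))
  return merged
```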
@@ -134,7 +131,7 @@ def _document_chunk_iterator(
     TextChunk containing document ID for a corresponding document.
 
   Raises:
-    DocumentRepeatError: If restrict_repeats is True and the same document ID
+    InvalidDocumentError: If restrict_repeats is True and the same document ID
       is visited more than once. Valid documents prior to the error will be
       returned.
   """
@@ -143,7 +140,7 @@ def _document_chunk_iterator(
     tokenized_text = document.tokenized_text
     document_id = document.document_id
     if restrict_repeats and document_id in visited_ids:
-      raise DocumentRepeatError(
+      raise exceptions.InvalidDocumentError(
          f"Document id {document_id} is already visited."
      )
    chunk_iter = chunking.ChunkIterator(
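Callers that previously caught `DocumentRepeatError` should now catch the shared core exception. A hedged usage sketch; `annotator` and `consume` are hypothetical stand-ins:

```python
from absl import logging

from langextract import data
from langextract.core import exceptions

docs = [
    data.Document(text="first", document_id="doc-1"),
    data.Document(text="again", document_id="doc-1"),  # repeated id
]
try:
  for annotated in annotator.annotate_documents(docs):
    consume(annotated)  # documents seen before the repeat are still yielded
except exceptions.InvalidDocumentError as err:
  logging.error("Duplicate input document: %s", err)
```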
@@ -277,138 +274,149 @@ def _annotate_documents_single_pass(
       show_progress: bool = True,
       **kwargs,
   ) -> Iterator[data.AnnotatedDocument]:
-    """Single-pass annotation logic (original implementation)."""
+    """Single-pass annotation with stable ordering and streaming emission.
 
-    logging.info("Starting document annotation.")
-    doc_iter, doc_iter_for_chunks = itertools.tee(documents, 2)
-    curr_document = next(doc_iter, None)
-    if curr_document is None:
-      logging.warning("No documents to process.")
-      return
-
-    annotated_extractions: list[data.Extraction] = []
-    chunk_iter = _document_chunk_iterator(doc_iter_for_chunks, max_char_buffer)
+    Streams input without full materialization, maintains correct attribution
+    across batches, and emits completed documents immediately to minimize
+    peak memory usage. Handles generators from both infer() and align().
+    """
+    doc_order: list[str] = []
+    doc_text_by_id: dict[str, str] = {}
+    per_doc: DefaultDict[str, list[data.Extraction]] = collections.defaultdict(
+        list
+    )
+    next_emit_idx = 0
+
+    def _capture_docs(src: Iterable[data.Document]) -> Iterator[data.Document]:
+      """Captures document order and text lazily as chunks are produced."""
+      for document in src:
+        document_id = document.document_id
+        if document_id in doc_text_by_id:
+          raise exceptions.InvalidDocumentError(
+              f"Duplicate document_id: {document_id}"
+          )
+        doc_order.append(document_id)
+        doc_text_by_id[document_id] = document.text or ""
+        yield document
+
+    def _emit_docs_iter(
+        keep_last_doc: bool,
+    ) -> Iterator[data.AnnotatedDocument]:
+      """Yields documents that are guaranteed complete.
+
+      Args:
+        keep_last_doc: If True, retains the most recently started document
+          for additional extractions. If False, emits all remaining documents.
+      """
+      nonlocal next_emit_idx
+      limit = max(0, len(doc_order) - 1) if keep_last_doc else len(doc_order)
+      while next_emit_idx < limit:
+        document_id = doc_order[next_emit_idx]
+        yield data.AnnotatedDocument(
+            document_id=document_id,
+            extractions=per_doc.get(document_id, []),
+            text=doc_text_by_id.get(document_id, ""),
+        )
+        per_doc.pop(document_id, None)
+        doc_text_by_id.pop(document_id, None)
+        next_emit_idx += 1
 
+    chunk_iter = _document_chunk_iterator(
+        _capture_docs(documents), max_char_buffer
+    )
     batches = chunking.make_batches_of_textchunk(chunk_iter, batch_length)
 
     model_info = progress.get_model_info(self._language_model)
-
-    progress_bar = progress.create_extraction_progress_bar(
+    batch_iter = progress.create_extraction_progress_bar(
         batches, model_info=model_info, disable=not show_progress
     )
 
     chars_processed = 0
 
-    for index, batch in enumerate(progress_bar):
-      logging.info("Processing batch %d with length %d", index, len(batch))
+    try:
+      for batch in batch_iter:
+        if not batch:
+          continue
 
-      batch_prompts: list[str] = []
-      for text_chunk in batch:
-        batch_prompts.append(
+        prompts = [
             self._prompt_generator.render(
                 question=text_chunk.chunk_text,
                 additional_context=text_chunk.additional_context,
             )
-        )
+            for text_chunk in batch
+        ]
 
-      # Show what we're currently processing
-      if debug and progress_bar:
-        batch_size = sum(len(chunk.chunk_text) for chunk in batch)
-        desc = progress.format_extraction_progress(
-            model_info,
-            current_chars=batch_size,
-            processed_chars=chars_processed,
-        )
-        progress_bar.set_description(desc)
+        if show_progress:
+          current_chars = sum(
+              len(text_chunk.chunk_text) for text_chunk in batch
+          )
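+          # The progress wrapper may be a plain iterator when progress is
+          # disabled; tolerate objects that lack set_description.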
+          try:
+            batch_iter.set_description(
+                progress.format_extraction_progress(
+                    model_info,
+                    current_chars=current_chars,
+                    processed_chars=chars_processed,
+                )
+            )
+          except AttributeError:
+            pass
 
-      batch_scored_outputs = self._language_model.infer(
-          batch_prompts=batch_prompts,
-          **kwargs,
-      )
+        outputs = self._language_model.infer(batch_prompts=prompts, **kwargs)
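+        # infer() may return a lazy iterable rather than a list; materialize
+        # it so the length check and the zip below consume it exactly once.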
+        if not isinstance(outputs, list):
+          outputs = list(outputs)
 
-      # Update total processed
-      if debug:
-        for chunk in batch:
-          if chunk.document_text:
-            char_interval = chunk.char_interval
-            chars_processed += char_interval.end_pos - char_interval.start_pos
-
-      # Update progress bar with final processed count
-      if progress_bar:
-        batch_size = sum(len(chunk.chunk_text) for chunk in batch)
-        desc = progress.format_extraction_progress(
-            model_info,
-            current_chars=batch_size,
-            processed_chars=chars_processed,
+        if len(outputs) != len(batch):
+          raise exceptions.InferenceOutputError(
+              f"Language model returned {len(outputs)} outputs for"
+              f" {len(batch)} prompts."
           )
-        progress_bar.set_description(desc)
 
-      for text_chunk, scored_outputs in zip(batch, batch_scored_outputs):
-        logging.debug("Processing chunk: %s", text_chunk)
-        if not scored_outputs:
-          logging.error(
-              "No scored outputs for chunk with ID %s.", text_chunk.document_id
-          )
-          raise exceptions.InferenceOutputError(
-              "No scored outputs from language model."
+        for text_chunk, scored_outputs in zip(batch, outputs):
+          if not isinstance(scored_outputs, list):
+            scored_outputs = list(scored_outputs)
+          if not scored_outputs:
+            raise exceptions.InferenceOutputError(
+                "No scored outputs from language model."
+            )
+
+          resolved_extractions = resolver.resolve(
+              scored_outputs[0].output, debug=debug, **kwargs
           )
-        while curr_document.document_id != text_chunk.document_id:
-          logging.info(
-              "Completing annotation for document ID %s.",
-              curr_document.document_id,
+
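+          # Offsets translate chunk-local token/char positions back into
+          # whole-document coordinates for alignment.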
+          token_offset = (
+              text_chunk.token_interval.start_index
+              if text_chunk.token_interval
+              else 0
           )
-          annotated_doc = data.AnnotatedDocument(
-              document_id=curr_document.document_id,
-              extractions=annotated_extractions,
-              text=curr_document.text,
+          char_offset = (
+              text_chunk.char_interval.start_pos
+              if text_chunk.char_interval
+              else 0
           )
-          yield annotated_doc
-          annotated_extractions.clear()
 
-          curr_document = next(doc_iter, None)
-          assert curr_document is not None, (
-              f"Document should be defined for {text_chunk} per"
-              " _document_chunk_iterator(...) specifications."
+          aligned_extractions = resolver.align(
+              resolved_extractions,
+              text_chunk.chunk_text,
+              token_offset,
+              char_offset,
+              **kwargs,
           )
 
-        top_inference_result = scored_outputs[0].output
-        logging.debug("Top inference result: %s", top_inference_result)
-
-        annotated_chunk_extractions = resolver.resolve(
-            top_inference_result, debug=debug, **kwargs
-        )
-        chunk_text = text_chunk.chunk_text
-        token_offset = text_chunk.token_interval.start_index
-        char_offset = text_chunk.char_interval.start_pos
-
-        aligned_extractions = resolver.align(
-            annotated_chunk_extractions,
-            chunk_text,
-            token_offset,
-            char_offset,
-            **kwargs,
-        )
-
-        annotated_extractions.extend(aligned_extractions)
-
-    progress_bar.close()
+          for extraction in aligned_extractions:
+            per_doc[text_chunk.document_id].append(extraction)
 
-    if debug:
-      progress.print_extraction_complete()
+          if show_progress and text_chunk.char_interval is not None:
+            chars_processed += (
+                text_chunk.char_interval.end_pos
+                - text_chunk.char_interval.start_pos
+            )
 
-    if curr_document is not None:
-      logging.info(
-          "Finalizing annotation for document ID %s.", curr_document.document_id
-      )
-      annotated_doc = data.AnnotatedDocument(
-          document_id=curr_document.document_id,
-          extractions=annotated_extractions,
-          text=curr_document.text,
-      )
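+        # Every document except the most recently started one can no longer
+        # receive chunks, so stream those out to bound peak memory.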
+        yield from _emit_docs_iter(keep_last_doc=True)
 
-      yield annotated_doc
+    finally:
+      batch_iter.close()
 
-    logging.info("Document annotation completed.")
+    yield from _emit_docs_iter(keep_last_doc=False)
 
   def _annotate_documents_sequential_passes(
       self,
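The `keep_last_doc` toggle is the core of the streaming behavior: after each batch, every document except the newest is final and can be emitted. A self-contained sketch of the same pattern with hypothetical names, independent of langextract:

```python
from collections import defaultdict

def stream_complete(items):
  """Yields (key, grouped values) once a key can receive no more items.

  Assumes items arrive grouped by key, mirroring the guarantee that
  _document_chunk_iterator yields all chunks of a document contiguously.
  """
  order, buckets, emitted = [], defaultdict(list), 0
  for key, value in items:
    if not order or order[-1] != key:
      order.append(key)
    buckets[key].append(value)
    while emitted < len(order) - 1:  # keep_last_doc=True
      done = order[emitted]
      yield done, buckets.pop(done)
      emitted += 1
  while emitted < len(order):  # keep_last_doc=False after the stream ends
    done = order[emitted]
    yield done, buckets.pop(done)
    emitted += 1

chunks = [("a", 1), ("a", 2), ("b", 3), ("c", 4)]
assert list(stream_complete(chunks)) == [("a", [1, 2]), ("b", [3]), ("c", [4])]
```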
@@ -433,6 +441,10 @@ def _annotate_documents_sequential_passes(
 
     document_extractions_by_pass: dict[str, list[list[data.Extraction]]] = {}
     document_texts: dict[str, str] = {}
+    # Preserve text up-front so we can emit documents even if later passes
+    # produce no extractions.
+    for _doc in document_list:
+      document_texts[_doc.document_id] = _doc.text or ""
 
     for pass_num in range(extraction_passes):
       logging.info(
@@ -452,13 +464,16 @@ def _annotate_documents_sequential_passes(
 
         if doc_id not in document_extractions_by_pass:
           document_extractions_by_pass[doc_id] = []
-          document_texts[doc_id] = annotated_doc.text or ""
+          # Keep first-seen text (already pre-filled above).
 
         document_extractions_by_pass[doc_id].append(
             annotated_doc.extractions or []
         )
 
-    for doc_id, all_pass_extractions in document_extractions_by_pass.items():
+    # Emit results strictly in original input order.
+    for doc in document_list:
+      doc_id = doc.document_id
+      all_pass_extractions = document_extractions_by_pass.get(doc_id, [])
       merged_extractions = _merge_non_overlapping_extractions(
           all_pass_extractions
       )
@@ -479,7 +494,7 @@ def _annotate_documents_sequential_passes(
       yield data.AnnotatedDocument(
           document_id=doc_id,
           extractions=merged_extractions,
-          text=document_texts[doc_id],
+          text=document_texts.get(doc_id, doc.text or ""),
       )
 
     logging.info("Sequential extraction passes completed.")
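A consequence of iterating `document_list` instead of the dict: results come back strictly in input order, and a document that produced no extractions in any pass still appears. A hedged usage sketch; the `annotator` driver and its call signature are illustrative:

```python
from langextract import data

docs = [
    data.Document(text="alpha", document_id="d1"),
    data.Document(text="beta", document_id="d2"),
]
results = list(annotator.annotate_documents(docs, extraction_passes=2))
assert [r.document_id for r in results] == ["d1", "d2"]
```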