Skip to content

Commit c598ca9

Browse files
aksg87 and vayoa committed
test: add regression test for generator delegation bug
Verify annotate_documents uses 'yield from' to properly delegate to generators, ensuring correct document attribution across batches. Co-authored-by: Vayoa <[email protected]>
1 parent bd1e3d2 commit c598ca9

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

tests/annotation_test.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from collections.abc import Sequence
1616
import dataclasses
17+
import inspect
1718
import textwrap
1819
from typing import Type
1920
from unittest import mock
@@ -1118,5 +1119,89 @@ def test_extractions_overlap(self, ext1, ext2, expected):
11181119
self.assertEqual(result, expected)
11191120

11201121

1122+
class AnnotateDocumentsGeneratorTest(absltest.TestCase):
  """Regression test: annotate_documents must delegate with 'yield from'."""

  def setUp(self):
    """Patches the Gemini model with a prompt-aware fake and builds an Annotator."""
    super().setUp()
    self.mock_language_model = self.enter_context(
        mock.patch.object(gemini, "GeminiLanguageModel", autospec=True)
    )

    def fenced_medication_yaml(medication):
      """Builds the fenced YAML reply that names a single medication."""
      return textwrap.dedent(f"""\
          ```yaml
          {data.EXTRACTIONS_KEY}:
          - medication: "{medication}"
            medication_index: 4
          ```""")

    def mock_infer(batch_prompts, **_):
      """Yields one scored output per prompt, keyed on the drug it mentions."""
      for prompt in batch_prompts:
        # Ibuprofen takes precedence when a prompt mentions both drugs.
        matched = next(
            (name for name in ("Ibuprofen", "Cefazolin") if name in prompt),
            None,
        )
        if matched is None:
          # No known medication in the prompt: reply with an empty list.
          text = f"```yaml\n{data.EXTRACTIONS_KEY}: []\n```"
        else:
          text = fenced_medication_yaml(matched)
        yield [types.ScoredOutput(score=1.0, output=text)]

    self.mock_language_model.infer.side_effect = mock_infer

    self.annotator = annotation.Annotator(
        language_model=self.mock_language_model,
        prompt_template=prompting.PromptTemplateStructured(description=""),
    )

  def test_yields_documents_not_generators(self):
    """Checks that results are documents, each holding only its own medication."""
    ibuprofen_doc = data.Document(
        text="Patient took 400 mg PO Ibuprofen q4h for two days.",
        document_id="doc1",
    )
    cefazolin_doc = data.Document(
        text="Patient was given 250 mg IV Cefazolin TID for one week.",
        document_id="doc2",
    )
    yaml_resolver = resolver_lib.Resolver(
        fence_output=True,
        format_type=data.FormatType.YAML,
        extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX,
    )

    annotated = self.annotator.annotate_documents(
        [ibuprofen_doc, cefazolin_doc],
        resolver=yaml_resolver,
        show_progress=False,
        debug=False,
    )
    results = list(annotated)

    self.assertLen(results, 2)
    # A plain 'yield' of the inner generator would surface generator objects.
    leaked_generator = any(inspect.isgenerator(item) for item in results)
    self.assertFalse(
        leaked_generator,
        msg="Must use 'yield from' to delegate, not 'yield'",
    )

    def medication_texts(document):
      """Collects the extraction texts whose class is 'medication'."""
      return {
          extraction.extraction_text
          for extraction in document.extractions
          if extraction.extraction_class == "medication"
      }

    first_meds = medication_texts(results[0])
    second_meds = medication_texts(results[1])

    self.assertIn("Ibuprofen", first_meds)
    self.assertNotIn("Cefazolin", first_meds)
    self.assertIn("Cefazolin", second_meds)
    self.assertNotIn("Ibuprofen", second_meds)
1204+
1205+
11211206
# Run every test case in this module through absltest when executed directly.
if __name__ == "__main__":
  absltest.main()

0 commit comments

Comments
 (0)