|
14 | 14 |
|
15 | 15 | from collections.abc import Sequence |
16 | 16 | import dataclasses |
| 17 | +import inspect |
17 | 18 | import textwrap |
18 | 19 | from typing import Type |
19 | 20 | from unittest import mock |
@@ -1118,5 +1119,89 @@ def test_extractions_overlap(self, ext1, ext2, expected): |
1118 | 1119 | self.assertEqual(result, expected) |
1119 | 1120 |
|
1120 | 1121 |
|
| 1122 | +class AnnotateDocumentsGeneratorTest(absltest.TestCase): |
| 1123 | + """Tests that annotate_documents uses 'yield from' for proper delegation.""" |
| 1124 | + |
| 1125 | + def setUp(self): |
| 1126 | + super().setUp() |
| 1127 | + self.mock_language_model = self.enter_context( |
| 1128 | + mock.patch.object(gemini, "GeminiLanguageModel", autospec=True) |
| 1129 | + ) |
| 1130 | + |
| 1131 | + def mock_infer(batch_prompts, **_): |
| 1132 | + """Return medication extractions based on prompt content.""" |
| 1133 | + for prompt in batch_prompts: |
| 1134 | + if "Ibuprofen" in prompt: |
| 1135 | + text = textwrap.dedent(f"""\ |
| 1136 | + ```yaml |
| 1137 | + {data.EXTRACTIONS_KEY}: |
| 1138 | + - medication: "Ibuprofen" |
| 1139 | + medication_index: 4 |
| 1140 | + ```""") |
| 1141 | + elif "Cefazolin" in prompt: |
| 1142 | + text = textwrap.dedent(f"""\ |
| 1143 | + ```yaml |
| 1144 | + {data.EXTRACTIONS_KEY}: |
| 1145 | + - medication: "Cefazolin" |
| 1146 | + medication_index: 4 |
| 1147 | + ```""") |
| 1148 | + else: |
| 1149 | + text = f"```yaml\n{data.EXTRACTIONS_KEY}: []\n```" |
| 1150 | + yield [types.ScoredOutput(score=1.0, output=text)] |
| 1151 | + |
| 1152 | + self.mock_language_model.infer.side_effect = mock_infer |
| 1153 | + |
| 1154 | + self.annotator = annotation.Annotator( |
| 1155 | + language_model=self.mock_language_model, |
| 1156 | + prompt_template=prompting.PromptTemplateStructured(description=""), |
| 1157 | + ) |
| 1158 | + |
| 1159 | + def test_yields_documents_not_generators(self): |
| 1160 | + """Verifies annotate_documents yields AnnotatedDocument, not generators.""" |
| 1161 | + docs = [ |
| 1162 | + data.Document( |
| 1163 | + text="Patient took 400 mg PO Ibuprofen q4h for two days.", |
| 1164 | + document_id="doc1", |
| 1165 | + ), |
| 1166 | + data.Document( |
| 1167 | + text="Patient was given 250 mg IV Cefazolin TID for one week.", |
| 1168 | + document_id="doc2", |
| 1169 | + ), |
| 1170 | + ] |
| 1171 | + |
| 1172 | + results = list( |
| 1173 | + self.annotator.annotate_documents( |
| 1174 | + docs, |
| 1175 | + resolver=resolver_lib.Resolver( |
| 1176 | + fence_output=True, |
| 1177 | + format_type=data.FormatType.YAML, |
| 1178 | + extraction_index_suffix=resolver_lib.DEFAULT_INDEX_SUFFIX, |
| 1179 | + ), |
| 1180 | + show_progress=False, |
| 1181 | + debug=False, |
| 1182 | + ) |
| 1183 | + ) |
| 1184 | + |
| 1185 | + self.assertLen(results, 2) |
| 1186 | + self.assertFalse( |
| 1187 | + any(inspect.isgenerator(item) for item in results), |
| 1188 | + msg="Must use 'yield from' to delegate, not 'yield'", |
| 1189 | + ) |
| 1190 | + meds_doc1 = { |
| 1191 | + e.extraction_text |
| 1192 | + for e in results[0].extractions |
| 1193 | + if e.extraction_class == "medication" |
| 1194 | + } |
| 1195 | + meds_doc2 = { |
| 1196 | + e.extraction_text |
| 1197 | + for e in results[1].extractions |
| 1198 | + if e.extraction_class == "medication" |
| 1199 | + } |
| 1200 | + self.assertIn("Ibuprofen", meds_doc1) |
| 1201 | + self.assertNotIn("Cefazolin", meds_doc1) |
| 1202 | + self.assertIn("Cefazolin", meds_doc2) |
| 1203 | + self.assertNotIn("Ibuprofen", meds_doc2) |
| 1204 | + |
| 1205 | + |
1121 | 1206 | if __name__ == "__main__": |
1122 | 1207 | absltest.main() |
0 commit comments