Skip to content

Commit c689c05

Browse files
refactor: changing LLMMetadataExtractor to use chat generators (#188)
* changing to ChatGenerators * linting * updating tests
1 parent 4aa3bf7 commit c689c05

File tree

2 files changed

+55
-47
lines changed

2 files changed

+55
-47
lines changed

haystack_experimental/components/extractors/llm_metadata_extractor.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010

1111
from haystack import Document, component, default_from_dict, default_to_dict, logging
1212
from haystack.components.builders import PromptBuilder
13-
from haystack.components.generators import AzureOpenAIGenerator, OpenAIGenerator
13+
from haystack.components.generators.chat import AzureOpenAIChatGenerator, OpenAIChatGenerator
1414
from haystack.components.preprocessors import DocumentSplitter
15+
from haystack.dataclasses import ChatMessage
1516
from haystack.lazy_imports import LazyImport
1617
from haystack.utils import deserialize_callable, deserialize_secrets_inplace
1718
from jinja2 import meta
@@ -20,10 +21,10 @@
2021
from haystack_experimental.util.utils import expand_page_range
2122

2223
with LazyImport(message="Run 'pip install \"amazon-bedrock-haystack>=1.0.2\"'") as amazon_bedrock_generator:
23-
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
24+
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
2425

2526
with LazyImport(message="Run 'pip install \"google-vertex-haystack>=2.0.0\"'") as vertex_ai_gemini_generator:
26-
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator
27+
from haystack_integrations.components.generators.google_vertex.chat.gemini import VertexAIGeminiChatGenerator
2728
from vertexai.generative_models import GenerationConfig
2829

2930

@@ -192,7 +193,6 @@ def __init__( # pylint: disable=R0917
192193
f"Prompt must have exactly one variable called 'document'. Found {','.join(variables)} in the prompt."
193194
)
194195
self.builder = PromptBuilder(prompt, required_variables=variables)
195-
196196
self.raise_on_failure = raise_on_failure
197197
self.expected_keys = expected_keys or []
198198
self.generator_api = generator_api if isinstance(generator_api, LLMProvider) \
@@ -207,20 +207,22 @@ def __init__( # pylint: disable=R0917
207207
def _init_generator(
208208
generator_api: LLMProvider,
209209
generator_api_params: Optional[Dict[str, Any]]
210-
) -> Union[OpenAIGenerator, AzureOpenAIGenerator, "AmazonBedrockGenerator", "VertexAIGeminiGenerator"]:
210+
) -> Union[
211+
OpenAIChatGenerator, AzureOpenAIChatGenerator, "AmazonBedrockChatGenerator", "VertexAIGeminiChatGenerator"
212+
]:
211213
"""
212214
Initialize the chat generator based on the specified API provider and parameters.
213215
"""
214216
if generator_api == LLMProvider.OPENAI:
215-
return OpenAIGenerator(**generator_api_params)
217+
return OpenAIChatGenerator(**generator_api_params)
216218
elif generator_api == LLMProvider.OPENAI_AZURE:
217-
return AzureOpenAIGenerator(**generator_api_params)
219+
return AzureOpenAIChatGenerator(**generator_api_params)
218220
elif generator_api == LLMProvider.AWS_BEDROCK:
219221
amazon_bedrock_generator.check()
220-
return AmazonBedrockGenerator(**generator_api_params)
222+
return AmazonBedrockChatGenerator(**generator_api_params)
221223
elif generator_api == LLMProvider.GOOGLE_VERTEX:
222224
vertex_ai_gemini_generator.check()
223-
return VertexAIGeminiGenerator(**generator_api_params)
225+
return VertexAIGeminiChatGenerator(**generator_api_params)
224226
else:
225227
raise ValueError(f"Unsupported generator API: {generator_api}")
226228

@@ -318,8 +320,8 @@ def _prepare_prompts(
318320
self,
319321
documents: List[Document],
320322
expanded_range: Optional[List[int]] = None
321-
) -> List[Union[str, None]]:
322-
all_prompts: List[Union[str, None]] = []
323+
) -> List[Union[ChatMessage, None]]:
324+
all_prompts: List[Union[ChatMessage, None]] = []
323325
for document in documents:
324326
if not document.content:
325327
logger.warning(
@@ -341,19 +343,23 @@ def _prepare_prompts(
341343
doc_copy = document
342344

343345
prompt_with_doc = self.builder.run(
344-
template=self.prompt,
345-
template_variables={"document": doc_copy}
346-
)
347-
all_prompts.append(prompt_with_doc["prompt"])
346+
template=self.prompt,
347+
template_variables={"document": doc_copy}
348+
)
349+
350+
# build a ChatMessage with the prompt
351+
message = ChatMessage.from_user(prompt_with_doc["prompt"])
352+
all_prompts.append(message)
353+
348354
return all_prompts
349355

350-
def _run_on_thread(self, prompt: Optional[str]) -> Dict[str, Any]:
356+
def _run_on_thread(self, prompt: Optional[ChatMessage]) -> Dict[str, Any]:
351357
# If prompt is None, return an empty dictionary
352358
if prompt is None:
353359
return {"replies": ["{}"]}
354360

355361
try:
356-
result = self.llm_provider.run(prompt=prompt)
362+
result = self.llm_provider.run(messages=[prompt])
357363
except Exception as e:
358364
logger.error(
359365
"LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
@@ -398,7 +404,7 @@ def run(self, documents: List[Document], page_range: Optional[List[Union[str, in
398404
if page_range:
399405
expanded_range = expand_page_range(page_range)
400406

401-
# Create prompts for each document
407+
# Create ChatMessage prompts for each document
402408
all_prompts = self._prepare_prompts(documents=documents, expanded_range=expanded_range)
403409

404410
# Run the LLM on each prompt
@@ -414,7 +420,7 @@ def run(self, documents: List[Document], page_range: Optional[List[Union[str, in
414420
failed_documents.append(document)
415421
continue
416422

417-
parsed_metadata = self._extract_metadata(result["replies"][0])
423+
parsed_metadata = self._extract_metadata(result["replies"][0].text)
418424
if "error" in parsed_metadata:
419425
document.meta["metadata_extraction_error"] = parsed_metadata["error"]
420426
document.meta["metadata_extraction_response"] = result["replies"][0]

test/components/extractors/test_llm_metadata_extractor.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
import boto3
21
import os
3-
import pytest
42
from unittest.mock import MagicMock
53

6-
from haystack import Pipeline, Document
4+
import boto3
5+
import pytest
6+
from haystack import Document, Pipeline
77
from haystack.components.builders import PromptBuilder
88
from haystack.components.writers import DocumentWriter
9+
from haystack.dataclasses import ChatMessage
910
from haystack.document_stores.in_memory import InMemoryDocumentStore
10-
from haystack_experimental.components.extractors import LLMMetadataExtractor
11-
from haystack_experimental.components.extractors import LLMProvider
11+
12+
from haystack_experimental.components.extractors import LLMMetadataExtractor, LLMProvider
1213

1314

1415
class TestLLMMetadataExtractor:
@@ -95,7 +96,10 @@ def test_to_dict_openai(self, monkeypatch):
9596
"model": "gpt-4o-mini",
9697
"organization": None,
9798
"streaming_callback": None,
98-
"system_prompt": None,
99+
"max_retries": None,
100+
"timeout": None,
101+
"tools": None,
102+
"tools_strict": False,
99103
},
100104
"max_workers": 3,
101105
},
@@ -106,11 +110,7 @@ def test_to_dict_aws_bedrock(self, boto3_session_mock):
106110
prompt="some prompt that was used with the LLM {{document.content}}",
107111
expected_keys=["key1", "key2"],
108112
generator_api=LLMProvider.AWS_BEDROCK,
109-
generator_api_params={
110-
"model": "meta.llama.test",
111-
"max_length": 100,
112-
"truncate": False,
113-
},
113+
generator_api_params={"model": "meta.llama.test"},
114114
raise_on_failure=True,
115115
)
116116
extractor_dict = extractor.to_dict()
@@ -146,11 +146,11 @@ def test_to_dict_aws_bedrock(self, boto3_session_mock):
146146
"strict": False,
147147
},
148148
"model": "meta.llama.test",
149-
"model_family": None,
150-
"max_length": 100,
151-
"truncate": False,
149+
"stop_words": [],
150+
"generation_kwargs": {},
152151
"streaming_callback": None,
153152
"boto3_config": None,
153+
"tools": None,
154154
},
155155
"expected_keys": ["key1", "key2"],
156156
"page_range": None,
@@ -179,7 +179,6 @@ def test_from_dict_openai(self, monkeypatch):
179179
"model": "gpt-4o-mini",
180180
"organization": None,
181181
"streaming_callback": None,
182-
"system_prompt": None,
183182
},
184183
},
185184
}
@@ -225,10 +224,11 @@ def test_from_dict_aws_bedrock(self, boto3_session_mock):
225224
"strict": False,
226225
},
227226
"model": "meta.llama.test",
228-
"max_length": 200,
229-
"truncate": False,
227+
"stop_words": [],
228+
"generation_kwargs": {},
230229
"streaming_callback": None,
231230
"boto3_config": None,
231+
"tools": None,
232232
},
233233
"expected_keys": ["key1", "key2"],
234234
"page_range": None,
@@ -244,8 +244,6 @@ def test_from_dict_aws_bedrock(self, boto3_session_mock):
244244
== "some prompt that was used with the LLM {{document.content}}"
245245
)
246246
assert extractor.generator_api == LLMProvider.AWS_BEDROCK
247-
assert extractor.llm_provider.max_length == 200
248-
assert extractor.llm_provider.truncate is False
249247
assert extractor.llm_provider.model == "meta.llama.test"
250248

251249
def test_warm_up(self, monkeypatch):
@@ -288,7 +286,7 @@ def test_extract_metadata_missing_key(self, monkeypatch, caplog):
288286
def test_prepare_prompts(self, monkeypatch):
289287
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
290288
extractor = LLMMetadataExtractor(
291-
prompt="prompt {{document.content}}",
289+
prompt="some_user_definer_prompt {{document.content}}",
292290
generator_api=LLMProvider.OPENAI,
293291
)
294292
docs = [
@@ -300,15 +298,16 @@ def test_prepare_prompts(self, monkeypatch):
300298
),
301299
]
302300
prompts = extractor._prepare_prompts(docs)
301+
303302
assert prompts == [
304-
"prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework",
305-
"prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library",
303+
ChatMessage.from_dict({"_role": "user", "_meta": {}, "_name": None, "_content": [{"text": "some_user_definer_prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework"}]}),
304+
ChatMessage.from_dict({"_role": "user", "_meta": {}, "_name": None, "_content": [{"text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"}]})
306305
]
307306

308307
def test_prepare_prompts_empty_document(self, monkeypatch):
309308
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
310309
extractor = LLMMetadataExtractor(
311-
prompt="prompt {{document.content}}",
310+
prompt="some_user_definer_prompt {{document.content}}",
312311
generator_api=LLMProvider.OPENAI,
313312
)
314313
docs = [
@@ -320,13 +319,14 @@ def test_prepare_prompts_empty_document(self, monkeypatch):
320319
prompts = extractor._prepare_prompts(docs)
321320
assert prompts == [
322321
None,
323-
"prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library",
322+
ChatMessage.from_dict(
323+
{"_role": "user", "_meta": {}, "_name": None, "_content": [{"text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library"}]})
324324
]
325325

326326
def test_prepare_prompts_expanded_range(self, monkeypatch):
327327
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
328328
extractor = LLMMetadataExtractor(
329-
prompt="prompt {{document.content}}",
329+
prompt="some_user_definer_prompt {{document.content}}",
330330
generator_api=LLMProvider.OPENAI,
331331
page_range=["1-2"],
332332
)
@@ -336,9 +336,11 @@ def test_prepare_prompts_expanded_range(self, monkeypatch):
336336
)
337337
]
338338
prompts = extractor._prepare_prompts(docs, expanded_range=[1, 2])
339-
assert prompts == [
340-
"prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library\fPage 2\f",
341-
]
339+
340+
assert prompts == [ChatMessage.from_dict({"_role": "user",
341+
"_meta": {},
342+
"_name": None,
343+
"_content": [{"text": "some_user_definer_prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library\x0cPage 2\x0c"}]})]
342344

343345
def test_run_no_documents(self, monkeypatch):
344346
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")

0 commit comments

Comments
 (0)