from typing import Optional, List, Union, Dict

from datasets import load_dataset
from haystack import Document, Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils import Secret
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator

from .config import DatasetConfig, DATASET_CONFIGS, MODEL_CONFIG


class RAGPipeline:
    """Retrieval-augmented question answering over an in-memory document store.

    Construction resolves the dataset configuration, obtains the documents
    (caller-supplied or loaded with ``load_dataset``), embeds them into an
    ``InMemoryDocumentStore`` and wires the query pipeline:
    text embedder -> retriever -> prompt builder -> Gemini chat generator.
    """

    def __init__(
        self,
        google_api_key: str,
        dataset_config: Union[str, DatasetConfig],
        documents: Optional[List[Union[str, Document]]] = None,
        embedding_model: Optional[str] = None,
        llm_model: Optional[str] = None,
    ):
        """
        Initialize the RAG Pipeline.

        Args:
            google_api_key: API key for Google AI services. When non-empty it
                is handed to the Gemini generator; when empty the generator's
                own default credential lookup is left in place.
            dataset_config: Either a string key from DATASET_CONFIGS or a
                DatasetConfig object.
            documents: Optional list of documents to use instead of loading
                from a dataset. Plain strings are wrapped in ``Document``.
            embedding_model: Optional override for the embedding model
                (defaults to ``MODEL_CONFIG["embedding_model"]``).
            llm_model: Optional override for the LLM model
                (defaults to ``MODEL_CONFIG["llm_model"]``).

        Raises:
            ValueError: If ``dataset_config`` is a string that is not a key
                of DATASET_CONFIGS.
        """
        self.config = self._resolve_config(dataset_config)

        if documents is not None:
            # The signature allows raw strings, but the embedder and the
            # preview below read ``.content`` / ``.meta``, so normalise here.
            self.documents = [self._as_document(item) for item in documents]
        else:
            self.documents = self._load_dataset_documents()

        # Preview of the first few documents (kept from the original code).
        for doc in self.documents[:10]:
            print(f"Content: {doc.content}")
            print(f"Metadata: {doc.meta}")
            print("-" * 100)

        # Documents and queries must be embedded with the same model.
        embedding_model_name = embedding_model or MODEL_CONFIG["embedding_model"]

        self.document_store = InMemoryDocumentStore()
        self.doc_embedder = SentenceTransformersDocumentEmbedder(model=embedding_model_name)
        self.text_embedder = SentenceTransformersTextEmbedder(model=embedding_model_name)
        self.retriever = InMemoryEmbeddingRetriever(self.document_store)

        # The document embedder is called directly in _index_documents (it is
        # not a pipeline component), so it is warmed up explicitly here.
        self.doc_embedder.warm_up()

        template = [ChatMessage.from_user(self.config.prompt_template)]
        self.prompt_builder = ChatPromptBuilder(template=template)

        generator_kwargs = {"model": llm_model or MODEL_CONFIG["llm_model"]}
        if google_api_key:
            # Previously the key was accepted but never forwarded.
            generator_kwargs["api_key"] = Secret.from_token(google_api_key)
        self.generator = GoogleAIGeminiChatGenerator(**generator_kwargs)

        self._index_documents(self.documents)

        self.pipeline = self._build_pipeline()

    @classmethod
    def from_preset(cls, google_api_key: str, preset_name: str) -> "RAGPipeline":
        """
        Create a pipeline from a preset configuration.

        Args:
            google_api_key: API key for Google AI services
            preset_name: Name of the preset configuration to use
                (a key of DATASET_CONFIGS)

        Returns:
            A fully initialised pipeline for the named preset.
        """
        return cls(google_api_key=google_api_key, dataset_config=preset_name)

    @staticmethod
    def _resolve_config(dataset_config: Union[str, DatasetConfig]) -> DatasetConfig:
        """Return the DatasetConfig for a preset name, or the object passed in."""
        if not isinstance(dataset_config, str):
            return dataset_config
        if dataset_config not in DATASET_CONFIGS:
            raise ValueError(
                f"Dataset config '{dataset_config}' not found. "
                f"Available configs: {list(DATASET_CONFIGS.keys())}"
            )
        return DATASET_CONFIGS[dataset_config]

    @staticmethod
    def _as_document(item):
        """Wrap a plain string in a Document; return any other value unchanged."""
        if isinstance(item, str):
            return Document(content=item)
        return item

    def _load_dataset_documents(self) -> List[Document]:
        """Load the configured dataset split and convert each row to a Document."""
        dataset = load_dataset(self.config.name, split=self.config.split)

        # ``fields`` maps metadata key -> dataset column; it may be empty/None.
        field_map = self.config.fields or {}
        content_field = self.config.content_field

        loaded = []
        for row in dataset:
            # Columns missing from a row are skipped rather than raising.
            meta = {
                meta_key: row[dataset_field]
                for meta_key, dataset_field in field_map.items()
                if dataset_field in row
            }
            loaded.append(Document(content=row[content_field], meta=meta))
        return loaded

    def _index_documents(self, documents):
        """Embed ``documents`` and write them to the document store."""
        if not documents:
            # Nothing to embed or store.
            return
        docs_with_embeddings = self.doc_embedder.run(documents)
        self.document_store.write_documents(docs_with_embeddings["documents"])

    def _build_pipeline(self):
        """Assemble and return the query pipeline from the prepared components."""
        pipeline = Pipeline()
        pipeline.add_component("text_embedder", self.text_embedder)
        pipeline.add_component("retriever", self.retriever)
        pipeline.add_component("prompt_builder", self.prompt_builder)
        pipeline.add_component("llm", self.generator)

        pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        # No socket names given here: the retriever output is matched to the
        # prompt builder input by Haystack (assumes the prompt template
        # declares a matching variable - TODO confirm against the configs).
        pipeline.connect("retriever", "prompt_builder")
        pipeline.connect("prompt_builder.prompt", "llm.messages")

        return pipeline

    def answer_question(self, question: str) -> str:
        """Run the RAG pipeline to answer a question.

        Args:
            question: The question text; it is used both as the retrieval
                query and as the ``question`` variable of the prompt.

        Returns:
            The text of the first reply produced by the LLM component.
        """
        result = self.pipeline.run({
            "text_embedder": {"text": question},
            "prompt_builder": {"question": question},
        })
        return result["llm"]["replies"][0].text