File size: 5,479 Bytes
91f974c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from typing import Optional, List, Union, Dict

from datasets import load_dataset
from haystack import Document, Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils import Secret
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator

from .config import DatasetConfig, DATASET_CONFIGS, MODEL_CONFIG
class RAGPipeline:
    """Retrieval-augmented generation pipeline built on Haystack.

    Embeds a corpus into an in-memory document store, retrieves the most
    relevant documents for a question via embedding similarity, and asks a
    Google Gemini chat model to answer using the retrieved context.
    """

    def __init__(
        self,
        google_api_key: str,
        dataset_config: Union[str, DatasetConfig],
        documents: Optional[List[Union[str, Document]]] = None,
        embedding_model: Optional[str] = None,
        llm_model: Optional[str] = None,
        verbose: bool = False,
    ):
        """
        Initialize the RAG Pipeline.

        Args:
            google_api_key: API key for Google AI services.
            dataset_config: Either a string key from DATASET_CONFIGS or a
                DatasetConfig object.
            documents: Optional list of documents (Document objects or raw
                strings) to use instead of loading from a dataset.
            embedding_model: Optional override for the embedding model.
            llm_model: Optional override for the LLM model.
            verbose: When True, print a sample of the loaded documents
                (up to 10) for debugging.

        Raises:
            ValueError: If ``dataset_config`` is a string that is not a key
                of DATASET_CONFIGS.
        """
        # Resolve configuration: a string is looked up in the preset
        # registry; anything else is assumed to be a DatasetConfig object.
        if isinstance(dataset_config, str):
            if dataset_config not in DATASET_CONFIGS:
                raise ValueError(
                    f"Dataset config '{dataset_config}' not found. "
                    f"Available configs: {list(DATASET_CONFIGS.keys())}"
                )
            self.config = DATASET_CONFIGS[dataset_config]
        else:
            self.config = dataset_config

        # Load documents either from the provided list or from the dataset.
        if documents is not None:
            # Bug fix: the signature accepts raw strings, but the embedder
            # and document store require Document objects — wrap them here.
            self.documents = [
                doc if isinstance(doc, Document) else Document(content=doc)
                for doc in documents
            ]
        else:
            dataset = load_dataset(self.config.name, split=self.config.split)
            self.documents = []
            for row in dataset:
                # Build the metadata dict from the configured field mapping,
                # silently skipping fields absent from this dataset row.
                meta = {}
                if self.config.fields:
                    for meta_key, dataset_field in self.config.fields.items():
                        if dataset_field in row:
                            meta[meta_key] = row[dataset_field]
                self.documents.append(
                    Document(content=row[self.config.content_field], meta=meta)
                )

        # Optionally show a sample of the corpus for debugging
        # (previously printed unconditionally on every construction).
        if verbose:
            for doc in self.documents[:10]:
                print(f"Content: {doc.content}")
                print(f"Metadata: {doc.meta}")
                print("-" * 100)

        # Initialize components; both embedders must share the same model so
        # query and document vectors live in the same embedding space.
        self.document_store = InMemoryDocumentStore()
        embed_model = embedding_model or MODEL_CONFIG["embedding_model"]
        self.doc_embedder = SentenceTransformersDocumentEmbedder(model=embed_model)
        self.text_embedder = SentenceTransformersTextEmbedder(model=embed_model)
        self.retriever = InMemoryEmbeddingRetriever(self.document_store)

        # Warm up the document embedder so indexing below does not pay the
        # model-loading cost lazily on the first run() call.
        self.doc_embedder.warm_up()

        # Prompt template: a single user message rendered by the builder.
        template = [ChatMessage.from_user(self.config.prompt_template)]
        self.prompt_builder = ChatPromptBuilder(template=template)

        # Bug fix: `google_api_key` was accepted but never used, so the
        # generator implicitly relied on the GOOGLE_API_KEY env var. Pass it
        # through explicitly so the constructor argument takes effect.
        self.generator = GoogleAIGeminiChatGenerator(
            api_key=Secret.from_token(google_api_key),
            model=llm_model or MODEL_CONFIG["llm_model"],
        )

        # Embed + index the corpus, then wire the query pipeline.
        self._index_documents(self.documents)
        self.pipeline = self._build_pipeline()

    @classmethod
    def from_preset(cls, google_api_key: str, preset_name: str):
        """
        Create a pipeline from a preset configuration.

        Args:
            google_api_key: API key for Google AI services.
            preset_name: Name of the preset configuration to use
                (a key of DATASET_CONFIGS).
        """
        return cls(google_api_key=google_api_key, dataset_config=preset_name)

    def _index_documents(self, documents: List[Document]) -> None:
        """Embed the given documents and write them to the document store."""
        docs_with_embeddings = self.doc_embedder.run(documents)
        self.document_store.write_documents(docs_with_embeddings["documents"])

    def _build_pipeline(self) -> Pipeline:
        """Wire embedder -> retriever -> prompt builder -> LLM and return it."""
        pipeline = Pipeline()
        pipeline.add_component("text_embedder", self.text_embedder)
        pipeline.add_component("retriever", self.retriever)
        pipeline.add_component("prompt_builder", self.prompt_builder)
        pipeline.add_component("llm", self.generator)

        # Connect components: query embedding feeds retrieval, retrieved
        # documents feed the prompt, the rendered prompt feeds the LLM.
        pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        pipeline.connect("retriever", "prompt_builder")
        pipeline.connect("prompt_builder.prompt", "llm.messages")
        return pipeline

    def answer_question(self, question: str) -> str:
        """Run the RAG pipeline and return the model's answer text."""
        result = self.pipeline.run({
            "text_embedder": {"text": question},
            # The question is also injected into the prompt template itself.
            "prompt_builder": {"question": question},
        })
        # The generator returns a list of ChatMessage replies; take the first.
        return result["llm"]["replies"][0].text