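"""RAG pipeline built on Haystack 2.x: embeds documents with
sentence-transformers, retrieves from an in-memory store, and answers
questions with a Google Gemini chat generator. Documents come either from
a caller-supplied list or from a Hugging Face dataset described by a
DatasetConfig."""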
from haystack import Document, Pipeline
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.components.builders import ChatPromptBuilder
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
from datasets import load_dataset
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret
from typing import Optional, List, Union
from .config import DatasetConfig, DATASET_CONFIGS, MODEL_CONFIG


class RAGPipeline:
def __init__(
self,
google_api_key: str,
dataset_config: Union[str, DatasetConfig],
documents: Optional[List[Union[str, Document]]] = None,
embedding_model: Optional[str] = None,
llm_model: Optional[str] = None
):
"""
Initialize the RAG Pipeline.
Args:
google_api_key: API key for Google AI services
dataset_config: Either a string key from DATASET_CONFIGS or a DatasetConfig object
documents: Optional list of documents to use instead of loading from a dataset
embedding_model: Optional override for embedding model
llm_model: Optional override for LLM model
"""
# Load configuration
if isinstance(dataset_config, str):
if dataset_config not in DATASET_CONFIGS:
raise ValueError(f"Dataset config '{dataset_config}' not found. Available configs: {list(DATASET_CONFIGS.keys())}")
self.config = DATASET_CONFIGS[dataset_config]
else:
self.config = dataset_config
# Load documents either from provided list or dataset
        if documents is not None:
            # Accept plain strings as well as Document objects
            self.documents = [
                doc if isinstance(doc, Document) else Document(content=doc)
                for doc in documents
            ]
else:
dataset = load_dataset(self.config.name, split=self.config.split)
# Create documents with metadata based on configuration
self.documents = []
for doc in dataset:
# Create metadata dictionary from configured fields
meta = {}
if self.config.fields:
for meta_key, dataset_field in self.config.fields.items():
if dataset_field in doc:
meta[meta_key] = doc[dataset_field]
# Create document with content and metadata
document = Document(
content=doc[self.config.content_field],
meta=meta
)
self.documents.append(document)
        # Preview the first 10 documents to verify content and metadata
for doc in self.documents[:10]:
print(f"Content: {doc.content}")
print(f"Metadata: {doc.meta}")
print("-"*100)
# Initialize components
self.document_store = InMemoryDocumentStore()
self.doc_embedder = SentenceTransformersDocumentEmbedder(
model=embedding_model or MODEL_CONFIG["embedding_model"]
)
self.text_embedder = SentenceTransformersTextEmbedder(
model=embedding_model or MODEL_CONFIG["embedding_model"]
)
self.retriever = InMemoryEmbeddingRetriever(self.document_store)
        # Warm up the document embedder now; the text embedder is warmed up
        # automatically when the pipeline first runs
        self.doc_embedder.warm_up()
# Initialize prompt template
template = [
ChatMessage.from_user(self.config.prompt_template)
]
self.prompt_builder = ChatPromptBuilder(template=template)
        # Initialize the generator, passing the API key explicitly instead of
        # relying on the GOOGLE_API_KEY environment variable
        self.generator = GoogleAIGeminiChatGenerator(
            api_key=Secret.from_token(google_api_key),
            model=llm_model or MODEL_CONFIG["llm_model"]
        )
# Index documents
self._index_documents(self.documents)
# Build pipeline
self.pipeline = self._build_pipeline()
@classmethod
def from_preset(cls, google_api_key: str, preset_name: str):
"""
Create a pipeline from a preset configuration.
Args:
google_api_key: API key for Google AI services
preset_name: Name of the preset configuration to use
"""
return cls(google_api_key=google_api_key, dataset_config=preset_name)
    def _index_documents(self, documents):
        """Embed the documents and write them to the document store."""
        docs_with_embeddings = self.doc_embedder.run(documents)
self.document_store.write_documents(docs_with_embeddings["documents"])
    def _build_pipeline(self):
        """Wire the query pipeline: embed the question, retrieve documents, build the prompt, generate a reply."""
pipeline = Pipeline()
pipeline.add_component("text_embedder", self.text_embedder)
pipeline.add_component("retriever", self.retriever)
pipeline.add_component("prompt_builder", self.prompt_builder)
pipeline.add_component("llm", self.generator)
# Connect components
pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
pipeline.connect("retriever", "prompt_builder")
pipeline.connect("prompt_builder.prompt", "llm.messages")
return pipeline
def answer_question(self, question: str) -> str:
"""Run the RAG pipeline to answer a question"""
result = self.pipeline.run({
"text_embedder": {"text": question},
"prompt_builder": {"question": question}
})
return result["llm"]["replies"][0].text