from haystack import Document, Pipeline
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.components.builders import ChatPromptBuilder
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
from datasets import load_dataset
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret
from typing import Optional, List, Union
from .config import DatasetConfig, DATASET_CONFIGS, MODEL_CONFIG
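
# The DatasetConfig attributes used below (name, split, content_field, fields,
# prompt_template) are defined in .config. An illustrative preset, with
# hypothetical values, might look like:
#
#   DatasetConfig(
#       name="bilgeyucel/seven-wonders",      # Hugging Face dataset id
#       split="train",
#       content_field="content",              # column used as Document.content
#       fields={"title": "title"},            # meta key -> dataset column
#       prompt_template=(
#           "Given the following context:\n"
#           "{% for doc in documents %}{{ doc.content }}\n{% endfor %}"
#           "Question: {{ question }}"
#       ),
#   )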

class RAGPipeline:
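    """Retrieval-augmented generation (RAG) pipeline built on Haystack.

    Loads documents from a Hugging Face dataset (or an explicit list), indexes
    them with sentence-transformers embeddings in an in-memory store, and
    answers questions with a Gemini chat model over the retrieved context.
    """
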
    def __init__(
        self,
        google_api_key: str,
        dataset_config: Union[str, DatasetConfig],
        documents: Optional[List[Union[str, Document]]] = None,
        embedding_model: Optional[str] = None,
        llm_model: Optional[str] = None
    ):
        """
        Initialize the RAG Pipeline.
        
        Args:
            google_api_key: API key for Google AI services
            dataset_config: Either a string key from DATASET_CONFIGS or a DatasetConfig object
            documents: Optional list of documents to use instead of loading from a dataset
            embedding_model: Optional override for embedding model
            llm_model: Optional override for LLM model
        """
        # Load configuration
        if isinstance(dataset_config, str):
            if dataset_config not in DATASET_CONFIGS:
                raise ValueError(f"Dataset config '{dataset_config}' not found. Available configs: {list(DATASET_CONFIGS.keys())}")
            self.config = DATASET_CONFIGS[dataset_config]
        else:
            self.config = dataset_config

        # Load documents either from the provided list or from the dataset
        if documents is not None:
            # Accept raw strings as well as Document objects, as the type hint allows
            self.documents = [
                doc if isinstance(doc, Document) else Document(content=doc)
                for doc in documents
            ]
        else:
            dataset = load_dataset(self.config.name, split=self.config.split)
            # Create documents with metadata based on configuration
            self.documents = []
            for doc in dataset:
                # Create metadata dictionary from configured fields
                meta = {}
                if self.config.fields:
                    for meta_key, dataset_field in self.config.fields.items():
                        if dataset_field in doc:
                            meta[meta_key] = doc[dataset_field]
                
                # Create document with content and metadata
                document = Document(
                    content=doc[self.config.content_field],
                    meta=meta
                )
                self.documents.append(document)

        # Preview the first 10 documents to verify content and metadata mapping
        for doc in self.documents[:10]:
            print(f"Content: {doc.content}")
            print(f"Metadata: {doc.meta}")
            print("-"*100)
        
        # Initialize components
        self.document_store = InMemoryDocumentStore()
        self.doc_embedder = SentenceTransformersDocumentEmbedder(
            model=embedding_model or MODEL_CONFIG["embedding_model"]
        )
        self.text_embedder = SentenceTransformersTextEmbedder(
            model=embedding_model or MODEL_CONFIG["embedding_model"]
        )
        self.retriever = InMemoryEmbeddingRetriever(self.document_store)
        
        # Warm up the document embedder before indexing (the text embedder is
        # warmed up by the pipeline at run time)
        self.doc_embedder.warm_up()
        
        # Initialize prompt template
        template = [
            ChatMessage.from_user(self.config.prompt_template)
        ]
        self.prompt_builder = ChatPromptBuilder(template=template)

        # Initialize the generator; pass the API key explicitly so the
        # google_api_key argument is actually used (otherwise the component
        # falls back to the GOOGLE_API_KEY environment variable)
        self.generator = GoogleAIGeminiChatGenerator(
            api_key=Secret.from_token(google_api_key),
            model=llm_model or MODEL_CONFIG["llm_model"]
        )
        
        # Index documents
        self._index_documents(self.documents)
        
        # Build pipeline
        self.pipeline = self._build_pipeline()

    @classmethod
    def from_preset(cls, google_api_key: str, preset_name: str):
        """
        Create a pipeline from a preset configuration.
        
        Args:
            google_api_key: API key for Google AI services
            preset_name: Name of the preset configuration to use
        """
        return cls(google_api_key=google_api_key, dataset_config=preset_name)

    def _index_documents(self, documents):
        # Embed and index documents
        docs_with_embeddings = self.doc_embedder.run(documents)
        self.document_store.write_documents(docs_with_embeddings["documents"])
    
    def _build_pipeline(self):
        pipeline = Pipeline()
        pipeline.add_component("text_embedder", self.text_embedder)
        pipeline.add_component("retriever", self.retriever)
        pipeline.add_component("prompt_builder", self.prompt_builder)
        pipeline.add_component("llm", self.generator)
        
        # Connect components: embed the query, retrieve documents, fill the
        # prompt template with them, then generate the answer
        pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
        pipeline.connect("retriever.documents", "prompt_builder.documents")
        pipeline.connect("prompt_builder.prompt", "llm.messages")
        
        return pipeline
    
    def answer_question(self, question: str) -> str:
        """Run the RAG pipeline to answer a question"""
        result = self.pipeline.run({
            "text_embedder": {"text": question},
            "prompt_builder": {"question": question}
        })
        return result["llm"]["replies"][0].text
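

if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline itself. Assumes a preset
    # named "seven_wonders" exists in DATASET_CONFIGS (hypothetical name) and
    # that GOOGLE_API_KEY is set in the environment. Because of the relative
    # import above, run this as a module, e.g. `python -m <package>.pipeline`.
    import os

    rag = RAGPipeline.from_preset(
        google_api_key=os.environ["GOOGLE_API_KEY"],
        preset_name="seven_wonders",
    )
    print(rag.answer_question("Your question about the dataset goes here"))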