File size: 3,496 Bytes
df56dc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from haystack.nodes import WebRetriever
from haystack.schema import Document
from typing import List 
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import AnswerParser, PromptNode, PromptTemplate
from haystack import Pipeline
from haystack.nodes import  DensePassageRetriever
import os
from dotenv import load_dotenv

def initialize_documents(serp_key, nl_query, top_k=100):
    """
    Retrieve web documents for a natural language query via WebRetriever.

    Args:
        serp_key (str): API key for the SERP API.
        nl_query (str): Natural language query to retrieve documents for.
        top_k (int): Value passed to the retriever as ``top_k``. Defaults to
            100, the value that was previously hard-coded in this function.

    Returns:
        List[Document]: Documents returned by the retriever for ``nl_query``.
    """
    # The retriever is configured in "preprocessed_documents" mode; only the
    # result count is adjustable by the caller.
    retriever = WebRetriever(
        api_key=serp_key,
        mode="preprocessed_documents",
        top_k=top_k,
    )

    # Retrieve documents based on the natural language query.
    return retriever.retrieve(query=nl_query)

def initialize_faiss_document_store(documents, use_gpu=True, sql_url=None):
    """
    Build a FAISS document store holding `documents` plus a matching retriever.

    The store is created with a "Flat" FAISS index, emptied, filled with
    `documents`, and then its embeddings are computed with the DPR retriever
    that is returned alongside it.

    Args:
        documents (List[Document]): List of documents to be stored in the document store.
        use_gpu (bool): Forwarded to DensePassageRetriever. Defaults to True,
            the value that was previously hard-coded in this function.
        sql_url (str or None): Optional SQL database URL forwarded to
            FAISSDocumentStore. When None (default) the argument is omitted and
            the library's own default is used, as in the original code.

    Returns:
        document_store (FAISSDocumentStore): FAISS document store.
        retriever (DensePassageRetriever): Dense passage retriever.
    """
    store_kwargs = {"faiss_index_factory_str": "Flat", "return_embedding": True}
    if sql_url is not None:
        store_kwargs["sql_url"] = sql_url

    # NOTE(review): the library default for sql_url is presumably an on-disk
    # SQLite file that survives between calls. Some Haystack 1.x releases appear
    # to check at construction time that the SQL document count matches the
    # (fresh, empty) FAISS index, in which case a leftover file would raise
    # before delete_documents() below can run. Passing sql_url="sqlite://"
    # (in-memory) should avoid that -- confirm against the installed version.
    document_store = FAISSDocumentStore(**store_kwargs)

    retriever = DensePassageRetriever(
        document_store=document_store,
        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
        use_gpu=use_gpu,
        embed_title=True,
    )

    # Start from an empty store so only the supplied documents are indexed.
    document_store.delete_documents()
    document_store.write_documents(documents)

    # Compute embeddings for the freshly written documents and add them to the index.
    document_store.update_embeddings(retriever=retriever)

    return document_store, retriever

def initialize_rag_pipeline(retriever, openai_key, model_name="gpt-4", max_length=500):
    """
    Initialize a pipeline for RAG-based question answering.

    The pipeline has two nodes: the supplied retriever, fed by the query, and a
    PromptNode that receives the retriever's output and answers with the
    template defined below.

    Args:
        retriever (DensePassageRetriever): Dense passage retriever.
        openai_key (str): API key for OpenAI.
        model_name (str): Model identifier passed to PromptNode as
            ``model_name_or_path``. Defaults to "gpt-4", the value that was
            previously hard-coded in this function.
        max_length (int): Value passed to PromptNode as ``max_length``.
            Defaults to 500, the previously hard-coded value.

    Returns:
        query_pipeline (Pipeline): Pipeline for RAG-based question answering.
    """
    # The original literal opened with four double quotes, which put a stray
    # '"' at the very start of the prompt text; it is removed here. The rest of
    # the template text is kept as it was.
    prompt_template = PromptTemplate(prompt = """Answer the following query based on the provided context. If the context does
                                                not include an answer, reply with 'The data does not contain information related to the question'.\n
                                                Query: {query}\n
                                                Documents: {join(documents)}
                                                Answer: 
                                            """,
                                            output_parser=AnswerParser())

    # The "stream" flag is handed to the model through model_kwargs.
    prompt_node = PromptNode(
        model_name_or_path=model_name,
        api_key=openai_key,
        default_prompt_template=prompt_template,
        max_length=max_length,
        model_kwargs={"stream": True},
    )

    # Query -> Retriever -> PromptNode
    query_pipeline = Pipeline()
    query_pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
    query_pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["Retriever"])

    return query_pipeline