# File size: 5,966 Bytes
# aa2bec3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import streamlit as st
import os
import zipfile
import shutil
from io import BytesIO
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import faiss
import uuid
from dotenv import load_dotenv

# Pull configuration from a local .env file (if any) into the process
# environment, then read the two secrets this app needs.
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
RAG_ACCESS_KEY = os.environ.get("RAG_ACCESS_KEY")

# Seed session state: each key is only written when it is not already
# present, so values set on an earlier run of the script are kept.
for _state_key, _make_default in (
    ("vectorstore", lambda: None),
    ("history", list),
    ("authenticated", lambda: False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Sidebar
def _render_sidebar_controls():
    """Render the sidebar's authentication and upload/processing widgets.

    This lives in a function (called from ``main()``) rather than at module
    level because it calls ``process_input``, which is defined further down
    in this file. Executed at module level, the "Process Files" handler
    would run before that ``def`` statement and fail with ``NameError``.
    """
    import hmac  # local import: only needed for the constant-time key check

    with st.sidebar:
        st.header("RAG Control Panel")
        api_key_input = st.text_input("Enter RAG Access Key", type="password")

        # Authentication
        if st.button("Authenticate"):
            expected_key = RAG_ACCESS_KEY or ""
            supplied_key = api_key_input or ""
            # An unset/empty RAG_ACCESS_KEY must never authenticate anyone,
            # and the comparison is constant-time to avoid leaking the key
            # through timing differences.
            key_matches = bool(expected_key) and hmac.compare_digest(
                supplied_key.encode("utf-8"), expected_key.encode("utf-8")
            )
            if key_matches:
                st.session_state.authenticated = True
                st.success("Authentication successful!")
            else:
                st.error("Invalid API key.")

        # File uploader
        if st.session_state.authenticated:
            input_type = st.selectbox("Select Input Type", ["Single PDF", "Folder/Zip of PDFs"])
            if input_type == "Single PDF":
                input_data = st.file_uploader("Upload a PDF file", type=["pdf"])
            else:
                input_data = st.file_uploader("Upload a folder or zip of PDFs", type=["zip"])

            if st.button("Process Files") and input_data is not None:
                with st.spinner("Processing files..."):
                    try:
                        vector_store = process_input(input_type, input_data)
                    except (ValueError, zipfile.BadZipFile) as exc:
                        # Show bad archives / unusable input as a sidebar
                        # message and keep any previously built store.
                        st.error(f"Could not process the upload: {exc}")
                    else:
                        st.session_state.vectorstore = vector_store
                        st.success("Files processed successfully. You can now ask questions.")


def _render_chat_history():
    """Render every (question, answer) pair from session state in the sidebar."""
    with st.sidebar:
        # Display chat history
        st.subheader("Chat History")
        for i, (q, a) in enumerate(st.session_state.history):
            st.write(f"**Q{i+1}:** {q}")
            st.write(f"**A{i+1}:** {a}")
            st.markdown("---")


def _render_question_panel():
    """Render the main-area question box and answer the submitted query.

    Returns early with a hint when the user is not authenticated or no
    vector store has been built yet.
    """
    if not st.session_state.authenticated:
        st.warning("Please authenticate with your API key in the sidebar.")
        return

    if st.session_state.vectorstore is None:
        st.info("Please upload and process a PDF or folder/zip of PDFs in the sidebar.")
        return

    query = st.text_input("Enter your question:")
    if st.button("Submit") and query:
        with st.spinner("Generating answer..."):
            answer = answer_question(st.session_state.vectorstore, query)
            st.session_state.history.append((query, answer))
            st.write("**Answer:**", answer)


# Main app
def main():
    """Build the whole page: sidebar controls, Q&A panel, then chat history.

    The history is rendered last so that a question answered during this
    run already appears in the sidebar list.
    """
    st.title("RAG Q&A App with Mistral AI")
    _render_sidebar_controls()
    _render_question_panel()
    _render_chat_history()
def process_input(input_type, input_data):
    """Build a FAISS vector store from an uploaded PDF or a zip of PDFs.

    The upload is read entirely in memory. Nothing is written to a shared
    ``uploads/`` directory, so concurrent sessions cannot see each other's
    files and no temporary files are left behind when parsing fails.

    Args:
        input_type: ``"Single PDF"`` treats ``input_data`` as one PDF; any
            other value treats it as a zip archive of PDFs.
        input_data: The uploaded file object. A single PDF is passed
            directly to ``PdfReader``; a zip upload must provide
            ``getvalue()`` returning the archive bytes.

    Returns:
        The populated ``FAISS`` vector store. It is also saved to
        ``vectorstore/faiss_index``.

    Raises:
        ValueError: If no text could be extracted from the upload.
        zipfile.BadZipFile: If a zip upload is not a valid archive.
    """
    if input_type == "Single PDF":
        page_texts = _extract_pdf_pages(input_data)
    else:
        page_texts = _extract_zip_pages(input_data.getvalue())

    # Join once (no quadratic ``+=``) and keep a newline between pages so the
    # last word of one page does not fuse with the first word of the next.
    documents = "\n".join(page_texts)
    if not documents.strip():
        # Fail with a clear message instead of building an empty index.
        raise ValueError("No extractable text was found in the uploaded file(s).")

    # Split text
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = text_splitter.split_text(documents)

    # Create embeddings
    hf_embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={'device': 'cpu'}
    )

    # Initialize FAISS; the index dimension is taken from a probe embedding
    # so it always matches the chosen model.
    dimension = len(hf_embeddings.embed_query("sample text"))
    index = faiss.IndexFlatL2(dimension)
    vector_store = FAISS(
        embedding_function=hf_embeddings,
        index=index,
        docstore=InMemoryDocstore({}),
        index_to_docstore_id={}
    )

    # Add texts to vector store, one fresh UUID per chunk
    uuids = [str(uuid.uuid4()) for _ in texts]
    vector_store.add_texts(texts, ids=uuids)

    # Save vector store locally
    vector_store.save_local("vectorstore/faiss_index")

    return vector_store


def _extract_pdf_pages(source):
    """Return the extracted text of every page of the PDF at *source*."""
    return [page.extract_text() or "" for page in PdfReader(source).pages]


def _extract_zip_pages(archive_bytes):
    """Return the page texts of every PDF member of an in-memory zip archive."""
    page_texts = []
    with zipfile.ZipFile(BytesIO(archive_bytes)) as archive:
        # Sorted so the concatenation order is deterministic.
        for member in sorted(archive.namelist()):
            base_name = member.rsplit("/", 1)[-1]
            # Skip macOS resource-fork entries (``__MACOSX/._name.pdf``):
            # they carry a .pdf suffix but are not PDF documents.
            if member.startswith("__MACOSX/") or base_name.startswith("._"):
                continue
            # Case-insensitive so ``REPORT.PDF`` is picked up as well.
            if member.lower().endswith(".pdf"):
                page_texts.extend(_extract_pdf_pages(BytesIO(archive.read(member))))
    return page_texts

def answer_question(vectorstore, query):
    """Answer *query* with a retrieval-augmented call to the hosted Mistral model.

    Args:
        vectorstore: Vector store providing ``as_retriever``; the 3 nearest
            chunks are stuffed into the prompt as context.
        query: The user's question.

    Returns:
        The text following the last ``"Answer:"`` marker in the chain output,
        stripped of surrounding whitespace.
    """
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
        # Text-generation endpoints bound the output with ``max_new_tokens``;
        # ``max_length`` is not a generation parameter there, so the intended
        # 512-token budget was presumably not being applied.
        # NOTE(review): confirm against the Inference API parameter list.
        model_kwargs={"temperature": 0.7, "max_new_tokens": 512},
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN
    )

    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    prompt_template = PromptTemplate(
        template="Use the provided context to answer the question concisely:\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:",
        input_variables=["context", "question"]
    )

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": prompt_template}
    )

    # ``invoke`` replaces the deprecated ``chain(inputs)`` call style.
    result = qa_chain.invoke({"query": query})
    # The output may echo the prompt; keep only what follows the final
    # "Answer:" marker.
    return result["result"].rpartition("Answer:")[2].strip()

# Entry point: build the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()