samim2024 committed on
Commit
aa2bec3
·
verified ·
1 Parent(s): 32ec859

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import zipfile
4
+ import shutil
5
+ from io import BytesIO
6
+ from PyPDF2 import PdfReader
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_community.embeddings import HuggingFaceEmbeddings
9
+ from langchain_community.vectorstores import FAISS
10
+ from langchain_community.docstore.in_memory import InMemoryDocstore
11
+ from langchain_community.llms import HuggingFaceHub
12
+ from langchain.chains import RetrievalQA
13
+ from langchain.prompts import PromptTemplate
14
+ import faiss
15
+ import uuid
16
+ from dotenv import load_dotenv
17
+
18
# Read configuration from the environment (after load_dotenv() has run).
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
RAG_ACCESS_KEY = os.environ.get("RAG_ACCESS_KEY")

# Seed each session-state slot only when it is missing, so values that are
# already stored in st.session_state are never overwritten.
for _state_key, _state_default in (
    ("vectorstore", None),
    ("history", []),
    ("authenticated", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
30
+
31
# Sidebar
def _is_valid_access_key(candidate):
    """Return True when *candidate* matches the configured RAG access key.

    Fails closed: if ``RAG_ACCESS_KEY`` is unset or empty, nothing
    authenticates. The comparison uses ``hmac.compare_digest`` so the time
    taken does not depend on how many leading characters match.
    """
    import hmac  # local import keeps this block self-contained

    if not RAG_ACCESS_KEY or candidate is None:
        return False
    return hmac.compare_digest(
        candidate.encode("utf-8"), RAG_ACCESS_KEY.encode("utf-8")
    )


def _render_sidebar():
    """Draw the sidebar: authentication, file upload/processing, chat history.

    This is a function (called from ``main``) rather than module-level code
    because it calls ``process_input``, which is defined further down the
    file; running it at import time would raise ``NameError`` as soon as the
    "Process Files" button is pressed.
    """
    with st.sidebar:
        st.header("RAG Control Panel")
        api_key_input = st.text_input("Enter RAG Access Key", type="password")

        # Authentication
        if st.button("Authenticate"):
            if _is_valid_access_key(api_key_input):
                st.session_state.authenticated = True
                st.success("Authentication successful!")
            else:
                st.error("Invalid API key.")

        # File uploader (only offered once the session is authenticated)
        if st.session_state.authenticated:
            input_type = st.selectbox(
                "Select Input Type", ["Single PDF", "Folder/Zip of PDFs"]
            )
            if input_type == "Single PDF":
                input_data = st.file_uploader("Upload a PDF file", type=["pdf"])
            else:
                input_data = st.file_uploader(
                    "Upload a folder or zip of PDFs", type=["zip"]
                )

            if st.button("Process Files") and input_data is not None:
                with st.spinner("Processing files..."):
                    vector_store = process_input(input_type, input_data)
                    st.session_state.vectorstore = vector_store
                    st.success(
                        "Files processed successfully. You can now ask questions."
                    )

        # Display chat history
        # NOTE(review): indentation was lost in the source listing; the
        # history is assumed to belong inside the sidebar -- confirm.
        st.subheader("Chat History")
        for i, (q, a) in enumerate(st.session_state.history):
            st.write(f"**Q{i+1}:** {q}")
            st.write(f"**A{i+1}:** {a}")
            st.markdown("---")


# Main app
def main():
    """Render the whole page: sidebar first, then the question/answer area.

    The Q&A area is shown only after the session is authenticated and a
    vector store has been built; otherwise a hint is displayed and the
    function returns early.
    """
    # Sidebar first, so an "Authenticate" / "Process Files" click in this run
    # updates session state before the checks below read it.
    _render_sidebar()

    st.title("RAG Q&A App with Mistral AI")

    if not st.session_state.authenticated:
        st.warning("Please authenticate with your API key in the sidebar.")
        return

    if st.session_state.vectorstore is None:
        st.info("Please upload and process a PDF or folder/zip of PDFs in the sidebar.")
        return

    query = st.text_input("Enter your question:")
    if st.button("Submit") and query:
        with st.spinner("Generating answer..."):
            answer = answer_question(st.session_state.vectorstore, query)
            st.session_state.history.append((query, answer))
            st.write("**Answer:**", answer)
84
+
85
def _read_pdf_text(source):
    """Return the text of every page of the PDF *source*, joined by newlines.

    *source* is passed straight to ``PdfReader``. Pages whose
    ``extract_text()`` yields nothing contribute an empty string.
    """
    return "\n".join(page.extract_text() or "" for page in PdfReader(source).pages)


def process_input(input_type, input_data):
    """Build a FAISS vector store from an uploaded PDF or a zip of PDFs.

    Args:
        input_type: ``"Single PDF"`` to treat *input_data* as one PDF; any
            other value treats it as a zip archive of PDFs.
        input_data: The uploaded file object. For a single PDF it is handed
            to ``PdfReader``; for a zip its ``getvalue()`` bytes are opened
            as an archive.

    Returns:
        The populated ``FAISS`` vector store (also saved to
        ``vectorstore/faiss_index``).

    Raises:
        ValueError: If no text could be extracted from the upload, so there
            is nothing to index.
    """
    if input_type == "Single PDF":
        parts = [_read_pdf_text(input_data)]
    else:
        # Read PDFs directly from the in-memory archive. Nothing is written
        # to a shared on-disk path, so concurrent sessions cannot clobber
        # each other and no cleanup is needed if a PDF fails to parse.
        parts = []
        with zipfile.ZipFile(BytesIO(input_data.getvalue())) as archive:
            for info in archive.infolist():
                name = info.filename
                if info.is_dir() or not name.lower().endswith(".pdf"):
                    continue
                # macOS archives carry "__MACOSX/._foo.pdf" resource-fork
                # entries that end in .pdf but are not PDF documents.
                if name.startswith("__MACOSX/") or name.rsplit("/", 1)[-1].startswith("._"):
                    continue
                parts.append(_read_pdf_text(BytesIO(archive.read(info))))

    # Newline between pages/files so words at a boundary are not fused.
    documents = "\n".join(parts)

    # Split text
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = text_splitter.split_text(documents)
    if not texts:
        raise ValueError("No extractable text was found in the uploaded file(s).")

    # Create embeddings
    hf_embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={'device': 'cpu'},
    )

    # Initialize FAISS; the index dimension is taken from a probe embedding.
    dimension = len(hf_embeddings.embed_query("sample text"))
    index = faiss.IndexFlatL2(dimension)
    vector_store = FAISS(
        embedding_function=hf_embeddings,
        index=index,
        docstore=InMemoryDocstore({}),
        index_to_docstore_id={},
    )

    # Add texts to vector store, one fresh UUID per chunk
    uuids = [str(uuid.uuid4()) for _ in texts]
    vector_store.add_texts(texts, ids=uuids)

    # Save vector store locally
    vector_store.save_local("vectorstore/faiss_index")

    return vector_store
143
+
144
def answer_question(vectorstore, query):
    """Answer *query* using the top-3 chunks retrieved from *vectorstore*.

    Args:
        vectorstore: A vector store exposing ``as_retriever``; used to fetch
            the three chunks passed to the model as context.
        query: The user's question.

    Returns:
        The model output after the last ``"Answer:"`` marker, stripped of
        surrounding whitespace.
    """
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
        model_kwargs={"temperature": 0.7, "max_length": 512},
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
    )

    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    prompt_template = PromptTemplate(
        template="Use the provided context to answer the question concisely:\n\nContext: {context}\n\nQuestion: {question}\n\nAnswer:",
        input_variables=["context", "question"],
    )

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": prompt_template},
    )

    # invoke() replaces the deprecated direct call ``qa_chain({...})``.
    result = qa_chain.invoke({"query": query})
    # Keep only what follows the final "Answer:" in case the prompt is echoed.
    return result["result"].split("Answer:")[-1].strip()
168
+
169
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()