import streamlit as st
import os
import zipfile
import shutil
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import faiss
import uuid
from dotenv import load_dotenv
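# The imports above imply roughly these pip packages (names inferred from the
# imports, not taken from the original Space's requirements file): streamlit,
# PyPDF2, langchain, langchain-community, sentence-transformers,
# faiss-cpu (or faiss-gpu), huggingface_hub, python-dotenv.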
# Load environment variables
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
RAG_ACCESS_KEY = os.getenv("RAG_ACCESS_KEY")
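# A minimal example .env file for local runs (both values below are
# placeholders, not real credentials):
#
#   HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
#   RAG_ACCESS_KEY=choose-a-secret-key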
# Initialize session state
if "vectorstore" not in st.session_state:
    st.session_state.vectorstore = None
if "history" not in st.session_state:
    st.session_state.history = []
if "authenticated" not in st.session_state:
    st.session_state.authenticated = False
# Sidebar (wrapped in a function and called from main() so that
# process_input, defined further down, is already bound when the
# "Process Files" button triggers a rerun)
def render_sidebar():
    with st.sidebar:
        st.header("RAG Control Panel")
        api_key_input = st.text_input("Enter RAG Access Key", type="password")

        # Authentication
        if st.button("Authenticate"):
            if api_key_input == RAG_ACCESS_KEY:
                st.session_state.authenticated = True
                st.success("Authentication successful!")
            else:
                st.error("Invalid access key.")

        # File uploader (only zip archives are accepted for multi-PDF input)
        if st.session_state.authenticated:
            input_type = st.selectbox("Select Input Type", ["Single PDF", "Zip of PDFs"])
            if input_type == "Single PDF":
                input_data = st.file_uploader("Upload a PDF file", type=["pdf"])
            else:
                input_data = st.file_uploader("Upload a zip archive of PDFs", type=["zip"])

            if st.button("Process Files") and input_data is not None:
                with st.spinner("Processing files..."):
                    st.session_state.vectorstore = process_input(input_type, input_data)
                st.success("Files processed successfully. You can now ask questions.")

        # Display chat history
        st.subheader("Chat History")
        for i, (q, a) in enumerate(st.session_state.history):
            st.write(f"**Q{i + 1}:** {q}")
            st.write(f"**A{i + 1}:** {a}")
            st.markdown("---")
# Main app
def main():
    st.title("RAG Q&A App with Mistral AI")
    render_sidebar()

    if not st.session_state.authenticated:
        st.warning("Please authenticate with your access key in the sidebar.")
        return

    if st.session_state.vectorstore is None:
        st.info("Please upload and process a PDF or a zip of PDFs in the sidebar.")
        return

    query = st.text_input("Enter your question:")
    if st.button("Submit") and query:
        with st.spinner("Generating answer..."):
            answer = answer_question(st.session_state.vectorstore, query)
        st.session_state.history.append((query, answer))
        st.write("**Answer:**", answer)
def process_input(input_type, input_data):
    """Extract text from the uploaded PDF(s), chunk it, embed the chunks,
    and return a FAISS vector store."""
    # Create uploads directory
    os.makedirs("uploads", exist_ok=True)

    documents = ""
    if input_type == "Single PDF":
        pdf_reader = PdfReader(input_data)
        for page in pdf_reader.pages:
            documents += (page.extract_text() or "") + "\n"
    else:
        # Save the uploaded zip to disk, then extract it
        zip_path = "uploads/uploaded.zip"
        with open(zip_path, "wb") as f:
            f.write(input_data.getvalue())
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall("uploads/extracted")

        # Process all PDFs in the extracted folder
        for root, _, files in os.walk("uploads/extracted"):
            for file in files:
                if file.lower().endswith(".pdf"):
                    pdf_path = os.path.join(root, file)
                    pdf_reader = PdfReader(pdf_path)
                    for page in pdf_reader.pages:
                        documents += (page.extract_text() or "") + "\n"

        # Clean up extracted files
        shutil.rmtree("uploads/extracted", ignore_errors=True)
        os.remove(zip_path)

    # Split the raw text into overlapping chunks for retrieval
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = text_splitter.split_text(documents)

    # Create embeddings (CPU is assumed; pass {'device': 'cuda'} if a GPU is available)
    hf_embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={"device": "cpu"},
    )

    # Initialize an empty FAISS index sized to the embedding dimension
    dimension = len(hf_embeddings.embed_query("sample text"))
    index = faiss.IndexFlatL2(dimension)
    vector_store = FAISS(
        embedding_function=hf_embeddings,
        index=index,
        docstore=InMemoryDocstore({}),
        index_to_docstore_id={},
    )

    # Add the chunks to the vector store under generated UUID keys
    uuids = [str(uuid.uuid4()) for _ in range(len(texts))]
    vector_store.add_texts(texts, ids=uuids)

    # Persist the index so it can be reloaded without re-embedding
    vector_store.save_local("vectorstore/faiss_index")
    return vector_store
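# Sketch of how the persisted index could be reloaded on a later run, skipping
# the re-embedding step. load_vectorstore is not part of the original app; it
# is an illustrative helper. Depending on the installed langchain_community
# version, allow_dangerous_deserialization=True may be required because the
# FAISS docstore is deserialized via pickle.
def load_vectorstore(path="vectorstore/faiss_index"):
    hf_embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={"device": "cpu"},
    )
    return FAISS.load_local(path, hf_embeddings, allow_dangerous_deserialization=True)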
def answer_question(vectorstore, query):
    """Run a retrieval-augmented QA chain over the vector store for one query."""
    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
        model_kwargs={"temperature": 0.7, "max_length": 512},
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
    )

    # Retrieve the 3 chunks most similar to the query
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    prompt_template = PromptTemplate(
        template=(
            "Use the provided context to answer the question concisely:\n\n"
            "Context: {context}\n\nQuestion: {question}\n\nAnswer:"
        ),
        input_variables=["context", "question"],
    )

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # stuff all retrieved chunks into a single prompt
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": prompt_template},
    )

    result = qa_chain.invoke({"query": query})
    # HuggingFaceHub may echo the prompt; keep only the text after "Answer:"
    return result["result"].split("Answer:")[-1].strip()
if __name__ == "__main__":
    main()
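# To run locally (assuming this file is saved as app.py and the dependencies
# listed near the imports are installed):
#   streamlit run app.py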