# app.py
import gradio as gr
import os
import re
import shutil
import torch
import pickle # For saving/loading Python objects
# LangChain imports
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
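# Assumed runtime dependencies (not pinned by this file): gradio, langchain,
# langchain-community, faiss-cpu, sentence-transformers, transformers, arxiv, pypdf, torch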
# --- Configuration ---
ARXIV_DIR = "./arxiv_papers" # Directory to save downloaded papers
KB_STORAGE_DIR = "./knowledge_base_storage" # Directory to save/load KB
FAISS_INDEX_PATH = os.path.join(KB_STORAGE_DIR, "faiss_index.faiss")  # FAISS.save_local writes <index_name>.faiss and .pkl
CHUNKS_PATH = os.path.join(KB_STORAGE_DIR, "knowledge_base_chunks.pkl")
CHUNK_SIZE = 500 # Characters per chunk
CHUNK_OVERLAP = 50 # Overlap between chunks
EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
LLM_MODEL_NAME = "google/flan-t5-small"
# Ensure KB storage directory exists
os.makedirs(KB_STORAGE_DIR, exist_ok=True)
# --- Helper Functions for arXiv and PDF Processing ---
def clean_text(text: str) -> str:
"""Basic text cleaning: replaces multiple spaces/newlines with single space and strips whitespace."""
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
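# Illustrative behavior of clean_text (not wired into the pipeline below):
#   clean_text("Attention  is\n all you   need ")  ->  "Attention is all you need"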
def get_arxiv_papers(query: str, max_papers: int = 5) -> list[str]:
"""
Searches arXiv for papers, downloads their PDFs, and returns a list of file paths.
Clears the ARXIV_DIR before downloading new papers.
"""
# Clear existing papers before downloading new ones
if os.path.exists(ARXIV_DIR):
shutil.rmtree(ARXIV_DIR)
os.makedirs(ARXIV_DIR, exist_ok=True)
print(f"Searching arXiv for '{query}' and downloading up to {max_papers} papers...")
    import arxiv  # Imported lazily so the module loads even if arxiv is only needed here
    client = arxiv.Client()
    search = arxiv.Search(
        query=query,
        max_results=max_papers,
        sort_by=arxiv.SortCriterion.Relevance,
        sort_order=arxiv.SortOrder.Descending
    )
    downloaded_files = []
    # Search.results() is deprecated in recent arxiv releases; fetch through a Client instead
    for i, result in enumerate(client.results(search)):
try:
            # Build a filesystem-safe filename from the title and arXiv short id
            safe_title = re.sub(r'[\\/:*?"<>|]', '', result.title)  # Remove invalid characters
            short_id = result.get_short_id().replace('/', '_')  # Result has no .arxiv_id attribute
            filename = f"{ARXIV_DIR}/{safe_title[:100]}_{short_id}.pdf"  # Limit title length
print(f"Downloading paper {i+1}/{max_papers}: {result.title}")
result.download_pdf(filename=filename)
downloaded_files.append(filename)
except Exception as e:
print(f"Could not download {result.title}: {e}")
return downloaded_files
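# Hypothetical usage from a Python shell (paths will vary by query):
#   paths = get_arxiv_papers("retrieval augmented generation", max_papers=2)
#   # -> ['./arxiv_papers/Some Paper Title_2401.12345v1.pdf', ...]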
# --- RAGAgent Class ---
class RAGAgent:
def __init__(self):
        self.embedding_model = None
        self.llm = None
        self.vectorstore = None
        self.qa_chain = None
        self.knowledge_base_chunks = []  # Populated when a KB is built or loaded
        self.is_initialized = False
def _load_models(self):
"""Loads the embedding and generation models if not already loaded."""
if self.embedding_model is None:
print(f"Loading Embedding Model: {EMBEDDING_MODEL_NAME}...")
self.embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
if self.llm is None:
print(f"Loading LLM Model: {LLM_MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME)
            # device=0 targets the first CUDA GPU; device=-1 keeps the pipeline on CPU
            device = 0 if torch.cuda.is_available() else -1
# Create a Hugging Face pipeline for text generation
text_generation_pipeline = pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=150, # Set a default max_new_tokens for the pipeline
min_length=20,
num_beams=5,
early_stopping=True,
device=device
)
self.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
self.is_initialized = True
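    # Minimal smoke test sketch (hypothetical; not called anywhere in the app):
    #   agent = RAGAgent(); agent._load_models()
    #   print(agent.llm.invoke("Summarize: retrieval augmented generation"))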
def initialize_knowledge_base(self, arxiv_query: str, max_papers: int = 5) -> str:
"""
Initializes the knowledge base by downloading, extracting, and chunking
arXiv papers using LangChain components, then building a FAISS vectorstore.
"""
self._load_models() # Ensure models are loaded first
        # Reset any previously loaded knowledge base; get_arxiv_papers clears ARXIV_DIR itself
        self.vectorstore = None
        self.qa_chain = None
        self.knowledge_base_chunks = []
try:
            # Download PDFs via the arxiv helper above, which controls the filenames
pdf_paths = get_arxiv_papers(arxiv_query, max_papers) # Call the helper function
if not pdf_paths:
return "No papers found or downloaded for the given query. Please try a different query."
# Load documents from downloaded PDFs using PyPDFLoader
all_documents = []
for pdf_path in pdf_paths:
try:
loader = PyPDFLoader(pdf_path)
all_documents.extend(loader.load())
except Exception as e:
print(f"Error loading PDF {pdf_path}: {e}")
if not all_documents:
return "Could not load any documents from downloaded PDFs. Please try a different query or fewer papers."
print(f"Loaded {len(all_documents)} raw documents from PDFs.")
# Split documents into chunks using RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
length_function=len,
is_separator_regex=False,
)
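            # With chunk_size=500 and chunk_overlap=50, consecutive chunks share ~50
            # characters, so sentences cut at a chunk boundary stay retrievable with context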
self.knowledge_base_chunks = text_splitter.split_documents(all_documents)
if not self.knowledge_base_chunks:
return "No meaningful text chunks could be created from the papers after splitting."
print(f"Total chunks created: {len(self.knowledge_base_chunks)}")
# Create FAISS vectorstore from chunks and embeddings
print("Creating FAISS vectorstore from chunks...")
self.vectorstore = FAISS.from_documents(self.knowledge_base_chunks, self.embedding_model)
print(f"FAISS vectorstore created with {len(self.knowledge_base_chunks)} documents.")
# Create RetrievalQA chain
self.qa_chain = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff", # "stuff" puts all retrieved docs into one prompt
retriever=self.vectorstore.as_retriever(search_kwargs={"k": 3}), # Retrieve top 3 docs
return_source_documents=False # Set to True if you want to return source docs
)
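            # Roughly what the "stuff" chain does per query (illustrative sketch):
            #   docs = retriever.get_relevant_documents(query)   # top-3 chunks
            #   prompt = "\n\n".join(d.page_content for d in docs) + "\n\n" + query
            #   answer = llm(prompt)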
return f"Knowledge base loaded with {len(self.knowledge_base_chunks)} chunks from {len(pdf_paths)} arXiv papers on '{arxiv_query}'."
except Exception as e:
print(f"Error during knowledge base initialization: {e}")
return f"An error occurred during knowledge base initialization: {e}"
def save_knowledge_base(self) -> str:
"""Saves the current FAISS vectorstore and knowledge base chunks to disk."""
if not self.vectorstore or not self.knowledge_base_chunks:
return "No knowledge base to save. Please load one first."
try:
            # Save the FAISS index (save_local writes faiss_index.faiss and faiss_index.pkl)
            self.vectorstore.save_local(KB_STORAGE_DIR, index_name="faiss_index")
            # Save the raw chunks alongside the index so the KB can be inspected or rebuilt
with open(CHUNKS_PATH, 'wb') as f:
pickle.dump(self.knowledge_base_chunks, f)
print(f"Knowledge base saved to {KB_STORAGE_DIR}")
return f"Knowledge base saved successfully to {KB_STORAGE_DIR}."
except Exception as e:
print(f"Error saving knowledge base: {e}")
return f"Error saving knowledge base: {e}"
def load_knowledge_base(self) -> str:
"""Loads the FAISS vectorstore and knowledge base chunks from disk."""
self._load_models() # Ensure models are loaded before loading KB
if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(CHUNKS_PATH):
return "Saved knowledge base not found. Please load or create one first."
try:
            # Load the FAISS index; allow_dangerous_deserialization is required because
            # load_local unpickles faiss_index.pkl, so only load indexes you created yourself
            self.vectorstore = FAISS.load_local(
                KB_STORAGE_DIR,
                self.embedding_model,
                index_name="faiss_index",
                allow_dangerous_deserialization=True,
            )
# Load chunks
with open(CHUNKS_PATH, 'rb') as f:
self.knowledge_base_chunks = pickle.load(f)
# Re-create RetrievalQA chain after loading vectorstore
self.qa_chain = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff",
retriever=self.vectorstore.as_retriever(search_kwargs={"k": 3}),
return_source_documents=False
)
print(f"Knowledge base loaded from {KB_STORAGE_DIR}")
return f"Knowledge base loaded successfully from {KB_STORAGE_DIR} with {len(self.knowledge_base_chunks)} chunks."
except Exception as e:
print(f"Error loading knowledge base: {e}")
self.vectorstore = None
self.qa_chain = None
self.knowledge_base_chunks = []
return f"Error loading knowledge base: {e}"
def query_agent(self, query: str) -> str:
"""
Retrieves relevant information from the knowledge base and generates an answer
using the LangChain RetrievalQA chain.
"""
if not query.strip():
return "Please enter a question."
if not self.is_initialized or self.qa_chain is None:
return "Knowledge base not loaded. Please initialize it by providing an arXiv query or loading from disk."
print(f"\n--- Querying LLM with LangChain QA Chain ---\nQuestion: {query}\n----------------------")
try:
# Use the RetrievalQA chain to get the answer
result = self.qa_chain.invoke({"query": query})
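            # result is a dict with the answer under "result"; with
            # return_source_documents=True it would also carry "source_documents"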
answer = result["result"].strip()
except Exception as e:
print(f"Error during generation: {e}")
answer = "I apologize, but I encountered an error while generating the answer. Please try again or rephrase your question."
return answer
# --- Gradio Interface ---
# Instantiate the RAGAgent
rag_agent_instance = RAGAgent()
print("Setting up Gradio interface...")
with gr.Blocks() as demo:
gr.Markdown("# 📚 Educational RAG Agent with Persistent Knowledge Base")
gr.Markdown("First, load a knowledge base from arXiv, then you can save it or load a previously saved one. Finally, ask questions!")
with gr.Row():
arxiv_input = gr.Textbox(
label="arXiv Search Query (e.g., 'Large Language Models', 'Reinforcement Learning')",
placeholder="Enter a topic to search for papers on arXiv...",
lines=1
)
max_papers_slider = gr.Slider(
minimum=1,
maximum=10,
step=1,
value=3,
label="Max Papers to Download"
)
load_kb_from_arxiv_button = gr.Button("Load KB from arXiv")
kb_status_output = gr.Textbox(label="Knowledge Base Status", interactive=False)
with gr.Row():
save_kb_button = gr.Button("Save Knowledge Base to Disk")
load_kb_from_disk_button = gr.Button("Load Knowledge Base from Disk")
with gr.Row():
question_input = gr.Textbox(
lines=3,
placeholder="Ask a question based on the loaded knowledge base...",
label="Your Question"
)
answer_output = gr.Textbox(label="Answer", lines=7, interactive=False)
submit_button = gr.Button("Get Answer")
load_kb_from_arxiv_button.click(
fn=rag_agent_instance.initialize_knowledge_base,
inputs=[arxiv_input, max_papers_slider],
outputs=kb_status_output
)
save_kb_button.click(
fn=rag_agent_instance.save_knowledge_base,
inputs=[],
outputs=kb_status_output
)
load_kb_from_disk_button.click(
fn=rag_agent_instance.load_knowledge_base,
inputs=[],
outputs=kb_status_output
)
submit_button.click(
fn=rag_agent_instance.query_agent,
inputs=question_input,
outputs=answer_output
)
gr.Examples(
examples=[
["What is the transformer architecture?"],
["Explain attention mechanisms in deep learning."],
["What are the challenges in reinforcement learning?"],
],
inputs=question_input
)
# Launch the Gradio app
if __name__ == "__main__":
print("Launching Gradio app...")
demo.launch(share=False)