agentic-RAG / app.py
APrmn8's picture
remove error
8a7cf31 verified
raw
history blame
9.75 kB
# app.py
import gradio as gr
import os
import re
import shutil
import torch
from langchain_community.document_loaders import ArxivLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
# --- Configuration ---
# Where downloaded arXiv PDFs are stored.
ARXIV_DIR = "./arxiv_papers"

# Text-splitter parameters, both counted in characters: the target chunk
# length and the overlap carried between consecutive chunks.
CHUNK_SIZE, CHUNK_OVERLAP = 500, 50

# Hugging Face model identifiers: the embedding model used to index chunks
# in FAISS, and the seq2seq model that generates answers.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
LLM_MODEL_NAME = "google/flan-t5-small"
# --- RAGAgent Class ---
class RAGAgent:
    """Retrieval-augmented QA agent over a FAISS index of downloaded arXiv papers.

    Call ``initialize_knowledge_base`` to download papers and build the
    retrieval chain, then ``query_agent`` to ask questions against it. Both
    methods return user-facing strings instead of raising, so they can be
    wired directly to UI callbacks.
    """

    # Characters stripped/replaced before a title or id is used in a file name.
    _UNSAFE_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|]')

    def __init__(self):
        self.embedding_model = None  # created lazily by _load_models()
        self.llm = None  # HuggingFacePipeline wrapper, created lazily by _load_models()
        self.vectorstore = None  # FAISS index for the current knowledge base
        self.qa_chain = None  # RetrievalQA chain; None until a knowledge base is built
        self.knowledge_base_chunks = []  # chunks indexed in the current knowledge base
        self.is_initialized = False  # True once _load_models() has run

    def _load_models(self):
        """Load the embedding model and the generation pipeline if not already loaded."""
        if self.embedding_model is None:
            print(f"Loading Embedding Model: {EMBEDDING_MODEL_NAME}...")
            self.embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
        if self.llm is None:
            print(f"Loading LLM Model: {LLM_MODEL_NAME}...")
            tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
            model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME)
            # First CUDA device when available, otherwise -1 (CPU).
            device = 0 if torch.cuda.is_available() else -1
            text_generation_pipeline = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=150,  # default generation budget for answers
                min_length=20,
                num_beams=5,
                early_stopping=True,
                device=device,
            )
            self.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
        self.is_initialized = True

    @classmethod
    def _pdf_filename(cls, title: str, short_id: str) -> str:
        """Return a filesystem-safe ``<title>_<id>.pdf`` name for an arXiv paper."""
        # Strip unsafe characters, then collapse whitespace runs (titles may
        # contain line breaks) into single spaces.
        clean_title = " ".join(cls._UNSAFE_FILENAME_CHARS.sub("", title).split())
        # Old-style ids such as "hep-th/9901001v1" contain a slash.
        clean_id = cls._UNSAFE_FILENAME_CHARS.sub("_", short_id)
        return f"{clean_title[:100]}_{clean_id}.pdf"

    def _download_papers(self, arxiv_query: str, max_papers: int) -> list:
        """Download up to ``max_papers`` PDFs matching ``arxiv_query`` into ARXIV_DIR.

        Returns the paths of the PDFs that were downloaded. A paper whose
        download fails is logged and skipped; search/API errors propagate.
        """
        # `arxiv` was previously used here without ever being imported
        # (NameError on every call). Presumably already installed because
        # ArxivLoader depends on it — verify against requirements.txt.
        import arxiv

        search = arxiv.Search(
            query=arxiv_query,
            max_results=max_papers,
            sort_by=arxiv.SortCriterion.Relevance,
            sort_order=arxiv.SortOrder.Descending,
        )
        pdf_paths = []
        # Client.results() replaces the deprecated Search.results().
        for i, result in enumerate(arxiv.Client().results(search)):
            try:
                # arxiv.Result has no `arxiv_id` attribute; get_short_id()
                # returns the id (e.g. "2107.05580v1").
                filename = self._pdf_filename(result.title, result.get_short_id())
                print(f"Downloading paper {i+1}/{max_papers}: {result.title}")
                result.download_pdf(dirpath=ARXIV_DIR, filename=filename)
                pdf_paths.append(os.path.join(ARXIV_DIR, filename))
            except Exception as e:
                print(f"Could not download {result.title}: {e}")
        return pdf_paths

    @staticmethod
    def _load_documents(pdf_paths) -> list:
        """Load each PDF with PyPDFLoader, logging and skipping unreadable files."""
        documents = []
        for pdf_path in pdf_paths:
            try:
                documents.extend(PyPDFLoader(pdf_path).load())
            except Exception as e:
                print(f"Error loading PDF {pdf_path}: {e}")
        return documents

    def initialize_knowledge_base(self, arxiv_query: str, max_papers: int = 5) -> str:
        """Build a fresh knowledge base from arXiv papers matching ``arxiv_query``.

        Downloads up to ``max_papers`` PDFs into ARXIV_DIR (deleting previous
        downloads), loads them with PyPDFLoader, splits the text into chunks,
        indexes the chunks in FAISS and builds the RetrievalQA chain used by
        ``query_agent``.

        Args:
            arxiv_query: Free-text arXiv search query. A blank query is
                rejected without touching the current knowledge base.
            max_papers: Maximum number of papers to download; float values
                (e.g. from a UI slider) are truncated to int.

        Returns:
            A human-readable status message. Failures are reported in the
            message rather than raised.
        """
        if not arxiv_query or not arxiv_query.strip():
            return "Please enter an arXiv search query."
        max_papers = int(max_papers)

        self._load_models()  # Ensure models are loaded first

        # Start from a clean slate so papers from a previous query never leak in.
        if os.path.exists(ARXIV_DIR):
            shutil.rmtree(ARXIV_DIR)
        os.makedirs(ARXIV_DIR, exist_ok=True)
        self.vectorstore = None
        self.qa_chain = None
        self.knowledge_base_chunks = []

        print(f"Searching arXiv for '{arxiv_query}' and downloading up to {max_papers} papers...")
        try:
            pdf_paths = self._download_papers(arxiv_query, max_papers)
            if not pdf_paths:
                return "No papers found or downloaded for the given query. Please try a different query."

            all_documents = self._load_documents(pdf_paths)
            if not all_documents:
                return "Could not load any documents from downloaded PDFs. Please try a different query or fewer papers."
            print(f"Loaded {len(all_documents)} raw documents from PDFs.")

            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=CHUNK_SIZE,
                chunk_overlap=CHUNK_OVERLAP,
                length_function=len,
                is_separator_regex=False,
            )
            chunks = text_splitter.split_documents(all_documents)
            if not chunks:
                return "No meaningful text chunks could be created from the papers after splitting."
            self.knowledge_base_chunks = chunks
            print(f"Total chunks created: {len(chunks)}")

            print("Creating FAISS vectorstore from chunks...")
            self.vectorstore = FAISS.from_documents(chunks, self.embedding_model)
            print(f"FAISS vectorstore created with {len(chunks)} documents.")

            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",  # put all retrieved chunks into one prompt
                retriever=self.vectorstore.as_retriever(search_kwargs={"k": 3}),  # top-3 chunks
                return_source_documents=False,
            )
            return f"Knowledge base loaded with {len(chunks)} chunks from {len(pdf_paths)} arXiv papers on '{arxiv_query}'."
        except Exception as e:
            # UI boundary: report the failure instead of crashing the callback.
            print(f"Error during knowledge base initialization: {e}")
            return f"An error occurred during knowledge base initialization: {e}"

    def query_agent(self, query: str) -> str:
        """Answer ``query`` with the RetrievalQA chain over the loaded papers.

        Returns the generated answer (whitespace-stripped), or a user-facing
        message when the query is empty/None, no knowledge base is loaded, or
        generation fails.
        """
        if not query or not query.strip():
            return "Please enter a question."
        if not self.is_initialized or self.qa_chain is None:
            return "Knowledge base not loaded. Please initialize it by providing an arXiv query."
        print(f"\n--- Querying LLM with LangChain QA Chain ---\nQuestion: {query}\n----------------------")
        try:
            result = self.qa_chain.invoke({"query": query})
            return result["result"].strip()
        except Exception as e:
            print(f"Error during generation: {e}")
            return "I apologize, but I encountered an error while generating the answer. Please try again or rephrase your question."
# --- Gradio Interface ---
# Instantiate the RAGAgent. This single module-level instance backs every
# callback below, so loading a knowledge base replaces the one all sessions query.
# NOTE(review): concurrent "Load" clicks from different users share this
# instance and ARXIV_DIR — confirm that is acceptable for this deployment.
rag_agent_instance = RAGAgent()
print("Setting up Gradio interface...")
# NOTE(review): the original indentation of this section was lost; which
# components sit inside each gr.Row() was reconstructed — confirm the layout.
with gr.Blocks() as demo:
    gr.Markdown("# 📚 Educational RAG Agent with arXiv Knowledge Base (LangChain)")
    gr.Markdown("First, load a knowledge base by specifying an arXiv search query. Then, ask questions!")
    # Knowledge-base controls: search query and how many papers to fetch.
    with gr.Row():
        arxiv_input = gr.Textbox(
            label="arXiv Search Query (e.g., 'Large Language Models', 'Reinforcement Learning')",
            placeholder="Enter a topic to search for papers on arXiv...",
            lines=1
        )
        max_papers_slider = gr.Slider(
            minimum=1,
            maximum=10,
            step=1,
            value=3,
            label="Max Papers to Download"
        )
    load_kb_button = gr.Button("Load Knowledge Base from arXiv")
    # Read-only box that shows the status string returned by initialize_knowledge_base.
    kb_status_output = gr.Textbox(label="Knowledge Base Status", interactive=False)
    # Question/answer area.
    with gr.Row():
        question_input = gr.Textbox(
            lines=3,
            placeholder="Ask a question based on the loaded arXiv papers...",
            label="Your Question"
        )
        answer_output = gr.Textbox(label="Answer", lines=7, interactive=False)
    submit_button = gr.Button("Get Answer")
    # Wire buttons to the shared agent's bound methods; each returns one string.
    load_kb_button.click(
        fn=rag_agent_instance.initialize_knowledge_base, # Call method of instance
        inputs=[arxiv_input, max_papers_slider],
        outputs=kb_status_output
    )
    submit_button.click(
        fn=rag_agent_instance.query_agent, # Call method of instance
        inputs=question_input,
        outputs=answer_output
    )
    # Sample questions bound to question_input.
    gr.Examples(
        examples=[
            ["What is the transformer architecture?"],
            ["Explain attention mechanisms in deep learning."],
            ["What are the challenges in reinforcement learning?"],
        ],
        inputs=question_input
    )
# Launch the Gradio app only when this file is run as a script (not on import).
if __name__ == "__main__":
    print("Launching Gradio app...")
    # share=False: serve locally without requesting a public share link.
    demo.launch(share=False)