# app.py
import gradio as gr
import os
import re
import shutil
import torch
import pickle  # For saving/loading Python objects

# LangChain imports
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# --- Configuration ---
ARXIV_DIR = "./arxiv_papers"  # Directory to save downloaded papers
KB_STORAGE_DIR = "./knowledge_base_storage"  # Directory to save/load KB
# FAISS.save_local() writes "<index_name>.faiss" and "<index_name>.pkl",
# so the existence check must look for the .faiss file, not a .bin file.
FAISS_INDEX_PATH = os.path.join(KB_STORAGE_DIR, "faiss_index.faiss")
CHUNKS_PATH = os.path.join(KB_STORAGE_DIR, "knowledge_base_chunks.pkl")
CHUNK_SIZE = 500  # Characters per chunk
CHUNK_OVERLAP = 50  # Overlap between chunks
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
LLM_MODEL_NAME = "google/flan-t5-small"  # Small seq2seq model; runs on CPU-only hardware

# Ensure KB storage directory exists
os.makedirs(KB_STORAGE_DIR, exist_ok=True)
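
# Assumed dependencies, inferred from the imports above (pin versions as needed):
#   pip install gradio torch transformers langchain langchain-community \
#       sentence-transformers faiss-cpu arxiv pypdf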

# --- Helper Functions for arXiv and PDF Processing ---
def clean_text(text: str) -> str:
    """Basic text cleaning: collapses runs of whitespace into single spaces and strips the ends."""
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
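
# Example: clean_text("  Attention\n\nIs All You Need  ") -> "Attention Is All You Need"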

def get_arxiv_papers(query: str, max_papers: int = 5) -> list[str]:
    """
    Searches arXiv for papers, downloads their PDFs, and returns a list of file paths.
    Clears ARXIV_DIR before downloading new papers.
    """
    # Clear existing papers before downloading new ones
    if os.path.exists(ARXIV_DIR):
        shutil.rmtree(ARXIV_DIR)
    os.makedirs(ARXIV_DIR, exist_ok=True)
    print(f"Searching arXiv for '{query}' and downloading up to {max_papers} papers...")
    import arxiv  # Imported lazily so the rest of the app can load without it
    search_results = arxiv.Search(
        query=query,
        max_results=max_papers,
        sort_by=arxiv.SortCriterion.Relevance,
        sort_order=arxiv.SortOrder.Descending
    )
    downloaded_files = []
    for i, result in enumerate(search_results.results()):
        try:
            # Create a safe filename: drop characters that are invalid on common filesystems
            safe_title = re.sub(r'[\\/:*?"<>|]', '', result.title)
            # arxiv.Result has no .arxiv_id attribute; get_short_id() returns e.g. "2301.00001v1".
            # Old-style IDs can contain "/", so replace it to keep the filename valid.
            paper_id = result.get_short_id().replace('/', '_')
            filename = f"{ARXIV_DIR}/{safe_title[:100]}_{paper_id}.pdf"  # Limit title length
            print(f"Downloading paper {i+1}/{max_papers}: {result.title}")
            result.download_pdf(filename=filename)
            downloaded_files.append(filename)
        except Exception as e:
            print(f"Could not download {result.title}: {e}")
    return downloaded_files
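
# Example (hypothetical query): get_arxiv_papers("retrieval augmented generation", max_papers=2)
# returns paths like "./arxiv_papers/Some Paper Title_2301.00001v1.pdf".
# Note: arxiv >= 2.0 deprecates Search.results() in favor of
# arxiv.Client().results(search); both still work at the time of writing.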

# --- RAGAgent Class ---
class RAGAgent:
    def __init__(self):
        self.embedding_model = None
        self.llm = None
        self.vectorstore = None
        self.qa_chain = None
        self.knowledge_base_chunks = []  # Initialized here so save_knowledge_base() never hits an AttributeError
        self.is_initialized = False

    def _load_models(self):
        """Loads the embedding and generation models if not already loaded."""
        if self.embedding_model is None:
            print(f"Loading Embedding Model: {EMBEDDING_MODEL_NAME}...")
            self.embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
        if self.llm is None:
            print(f"Loading LLM Model: {LLM_MODEL_NAME}...")
            tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
            model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME)
            # Determine device for the pipeline: 0 = first GPU, -1 = CPU
            device = 0 if torch.cuda.is_available() else -1
            # Create a Hugging Face pipeline for text generation
            text_generation_pipeline = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=150,  # Default generation budget for the pipeline
                min_length=20,
                num_beams=5,
                early_stopping=True,
                device=device
            )
            self.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
        self.is_initialized = True

    def initialize_knowledge_base(self, arxiv_query: str, max_papers: int = 5) -> str:
        """
        Initializes the knowledge base by downloading, extracting, and chunking
        arXiv papers using LangChain components, then building a FAISS vectorstore.
        """
        self._load_models()  # Ensure models are loaded first
        # Reset any previously loaded knowledge base; get_arxiv_papers()
        # clears and recreates ARXIV_DIR itself before downloading.
        self.vectorstore = None
        self.qa_chain = None
        self.knowledge_base_chunks = []
        try:
            # Manual download using the arxiv library (it offers more control over filenames)
            pdf_paths = get_arxiv_papers(arxiv_query, max_papers)
            if not pdf_paths:
                return "No papers found or downloaded for the given query. Please try a different query."
            # Load documents from downloaded PDFs using PyPDFLoader
            all_documents = []
            for pdf_path in pdf_paths:
                try:
                    loader = PyPDFLoader(pdf_path)
                    all_documents.extend(loader.load())
                except Exception as e:
                    print(f"Error loading PDF {pdf_path}: {e}")
            if not all_documents:
                return "Could not load any documents from the downloaded PDFs. Please try a different query or fewer papers."
            print(f"Loaded {len(all_documents)} raw documents from PDFs.")
            # Split documents into chunks using RecursiveCharacterTextSplitter
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=CHUNK_SIZE,
                chunk_overlap=CHUNK_OVERLAP,
                length_function=len,
                is_separator_regex=False,
            )
            self.knowledge_base_chunks = text_splitter.split_documents(all_documents)
            if not self.knowledge_base_chunks:
                return "No meaningful text chunks could be created from the papers after splitting."
            print(f"Total chunks created: {len(self.knowledge_base_chunks)}")
            # Create FAISS vectorstore from chunks and embeddings
            print("Creating FAISS vectorstore from chunks...")
            self.vectorstore = FAISS.from_documents(self.knowledge_base_chunks, self.embedding_model)
            print(f"FAISS vectorstore created with {len(self.knowledge_base_chunks)} documents.")
            # Create RetrievalQA chain
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",  # "stuff" puts all retrieved docs into one prompt
                retriever=self.vectorstore.as_retriever(search_kwargs={"k": 3}),  # Retrieve top 3 docs
                return_source_documents=False  # Set to True to also return source docs
            )
            return f"Knowledge base loaded with {len(self.knowledge_base_chunks)} chunks from {len(pdf_paths)} arXiv papers on '{arxiv_query}'."
        except Exception as e:
            print(f"Error during knowledge base initialization: {e}")
            return f"An error occurred during knowledge base initialization: {e}"

    def save_knowledge_base(self) -> str:
        """Saves the current FAISS vectorstore and knowledge base chunks to disk."""
        if not self.vectorstore or not self.knowledge_base_chunks:
            return "No knowledge base to save. Please load one first."
        try:
            # Save the FAISS index (writes faiss_index.faiss and faiss_index.pkl)
            self.vectorstore.save_local(KB_STORAGE_DIR, index_name="faiss_index")
            # Save the chunks themselves (useful for inspection or rebuilding the index)
            with open(CHUNKS_PATH, 'wb') as f:
                pickle.dump(self.knowledge_base_chunks, f)
            print(f"Knowledge base saved to {KB_STORAGE_DIR}")
            return f"Knowledge base saved successfully to {KB_STORAGE_DIR}."
        except Exception as e:
            print(f"Error saving knowledge base: {e}")
            return f"Error saving knowledge base: {e}"

    def load_knowledge_base(self) -> str:
        """Loads the FAISS vectorstore and knowledge base chunks from disk."""
        self._load_models()  # Ensure models are loaded before loading the KB
        if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(CHUNKS_PATH):
            return "Saved knowledge base not found. Please create and save one first."
        try:
            # Load the FAISS index. allow_dangerous_deserialization is required by
            # recent LangChain versions because load_local unpickles data; only
            # enable it for index files you created yourself.
            self.vectorstore = FAISS.load_local(
                KB_STORAGE_DIR,
                self.embedding_model,
                index_name="faiss_index",
                allow_dangerous_deserialization=True
            )
            # Load the chunks
            with open(CHUNKS_PATH, 'rb') as f:
                self.knowledge_base_chunks = pickle.load(f)
            # Re-create the RetrievalQA chain on top of the loaded vectorstore
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",
                retriever=self.vectorstore.as_retriever(search_kwargs={"k": 3}),
                return_source_documents=False
            )
            print(f"Knowledge base loaded from {KB_STORAGE_DIR}")
            return f"Knowledge base loaded successfully from {KB_STORAGE_DIR} with {len(self.knowledge_base_chunks)} chunks."
        except Exception as e:
            print(f"Error loading knowledge base: {e}")
            self.vectorstore = None
            self.qa_chain = None
            self.knowledge_base_chunks = []
            return f"Error loading knowledge base: {e}"

    def query_agent(self, query: str) -> str:
        """
        Retrieves relevant information from the knowledge base and generates an answer
        using the LangChain RetrievalQA chain.
        """
        if not query.strip():
            return "Please enter a question."
        if not self.is_initialized or self.qa_chain is None:
            return "Knowledge base not loaded. Please initialize it by providing an arXiv query or loading from disk."
        print(f"\n--- Querying LLM with LangChain QA Chain ---\nQuestion: {query}\n----------------------")
        try:
            # Use the RetrievalQA chain to get the answer
            result = self.qa_chain.invoke({"query": query})
            answer = result["result"].strip()
        except Exception as e:
            print(f"Error during generation: {e}")
            answer = "I apologize, but I encountered an error while generating the answer. Please try again or rephrase your question."
        return answer
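
# Quick programmatic smoke test (a minimal sketch; assumes network access for
# the arXiv download and for fetching model weights on first run):
#
#   agent = RAGAgent()
#   print(agent.initialize_knowledge_base("large language models", max_papers=2))
#   print(agent.query_agent("What is the transformer architecture?"))
#   print(agent.save_knowledge_base())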

# --- Gradio Interface ---
# Instantiate the RAGAgent
rag_agent_instance = RAGAgent()

print("Setting up Gradio interface...")
with gr.Blocks() as demo:
    gr.Markdown("# 📚 Educational RAG Agent with Persistent Knowledge Base")
    gr.Markdown("First, load a knowledge base from arXiv; then you can save it or load a previously saved one. Finally, ask questions!")
    with gr.Row():
        arxiv_input = gr.Textbox(
            label="arXiv Search Query (e.g., 'Large Language Models', 'Reinforcement Learning')",
            placeholder="Enter a topic to search for papers on arXiv...",
            lines=1
        )
        max_papers_slider = gr.Slider(
            minimum=1,
            maximum=10,
            step=1,
            value=3,
            label="Max Papers to Download"
        )
        load_kb_from_arxiv_button = gr.Button("Load KB from arXiv")
    kb_status_output = gr.Textbox(label="Knowledge Base Status", interactive=False)
    with gr.Row():
        save_kb_button = gr.Button("Save Knowledge Base to Disk")
        load_kb_from_disk_button = gr.Button("Load Knowledge Base from Disk")
    with gr.Row():
        question_input = gr.Textbox(
            lines=3,
            placeholder="Ask a question based on the loaded knowledge base...",
            label="Your Question"
        )
    answer_output = gr.Textbox(label="Answer", lines=7, interactive=False)
    submit_button = gr.Button("Get Answer")

    # Wire the buttons to the agent's methods
    load_kb_from_arxiv_button.click(
        fn=rag_agent_instance.initialize_knowledge_base,
        inputs=[arxiv_input, max_papers_slider],
        outputs=kb_status_output
    )
    save_kb_button.click(
        fn=rag_agent_instance.save_knowledge_base,
        inputs=[],
        outputs=kb_status_output
    )
    load_kb_from_disk_button.click(
        fn=rag_agent_instance.load_knowledge_base,
        inputs=[],
        outputs=kb_status_output
    )
    submit_button.click(
        fn=rag_agent_instance.query_agent,
        inputs=question_input,
        outputs=answer_output
    )
    gr.Examples(
        examples=[
            ["What is the transformer architecture?"],
            ["Explain attention mechanisms in deep learning."],
            ["What are the challenges in reinforcement learning?"],
        ],
        inputs=question_input
    )

# Launch the Gradio app
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(share=False)
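
# Note: share=False keeps the app local. share=True asks Gradio to create a
# temporary public URL, which is generally unnecessary when hosting on Spaces.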