# app.py
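"""Educational RAG agent with an arXiv knowledge base.

Searches arXiv for a topic, downloads the matching PDFs, splits them into
overlapping chunks, embeds the chunks into a FAISS vectorstore, and answers
questions through a LangChain RetrievalQA chain backed by google/flan-t5-small,
all behind a small Gradio interface.
"""
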
import gradio as gr
import os
import re
import shutil
import torch

import arxiv
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
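
# Assumed pip dependencies for the imports above: gradio, torch, arxiv, pypdf,
# langchain, langchain-community, transformers, sentence-transformers, faiss-cpu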

# --- Configuration ---
ARXIV_DIR = "./arxiv_papers"  # Directory to save downloaded papers
CHUNK_SIZE = 500              # Characters per chunk
CHUNK_OVERLAP = 50            # Overlap between chunks
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
LLM_MODEL_NAME = "google/flan-t5-small"
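# Both models are small enough for CPU-only hardware: all-MiniLM-L6-v2 is a
# compact sentence-transformers encoder, and flan-t5-small is an
# instruction-tuned seq2seq generator.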

# --- RAGAgent Class ---
class RAGAgent:
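    """Bundles the embedding model, LLM, FAISS vectorstore, and RetrievalQA
    chain, and exposes knowledge-base loading and querying to the Gradio UI."""
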
    def __init__(self):
        self.embedding_model = None
        self.llm = None
        self.vectorstore = None
        self.qa_chain = None
        self.knowledge_base_chunks = []  # Populated by initialize_knowledge_base()
        self.is_initialized = False

    def _load_models(self):
        """Loads the embedding and generation models if not already loaded."""
        if self.embedding_model is None:
            print(f"Loading Embedding Model: {EMBEDDING_MODEL_NAME}...")
            self.embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
        if self.llm is None:
            print(f"Loading LLM Model: {LLM_MODEL_NAME}...")
            tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
            model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME)
            # Determine device for the pipeline (GPU 0 if available, else CPU)
            device = 0 if torch.cuda.is_available() else -1
            # Create a Hugging Face pipeline for text generation
            text_generation_pipeline = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=150,  # Default cap on generated tokens
                min_length=20,
                num_beams=5,
                early_stopping=True,
                device=device,
            )
            self.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
        self.is_initialized = True

    def initialize_knowledge_base(self, arxiv_query: str, max_papers: int = 5) -> str:
        """
        Initializes the knowledge base by downloading, extracting, and chunking
        arXiv papers using LangChain components, then building a FAISS vectorstore.
        """
        self._load_models()  # Ensure models are loaded first
        # Clear existing papers before downloading new ones
        if os.path.exists(ARXIV_DIR):
            shutil.rmtree(ARXIV_DIR)
        os.makedirs(ARXIV_DIR, exist_ok=True)
        self.vectorstore = None
        self.qa_chain = None
print(f"Searching arXiv for '{arxiv_query}' and downloading up to {max_papers} papers...") | |
try: | |
# Use LangChain's ArxivLoader | |
# ArxivLoader downloads PDFs to a temporary directory by default, | |
# but we can specify a custom path to ensure cleanup. | |
# For simplicity, we'll let it download to its default temp dir | |
# and then process. Or, we can manually download and use PyPDFLoader. | |
# Let's stick to manual download for better control and consistency with previous code. | |
# Manual download using arxiv library (as it offers more control over filenames) | |
search_results = arxiv.Search( | |
query=arxiv_query, | |
max_results=max_papers, | |
sort_by=arxiv.SortCriterion.Relevance, | |
sort_order=arxiv.SortOrder.Descending | |
) | |
            client = arxiv.Client()  # Client.results() replaces the deprecated Search.results()
            pdf_paths = []
            for i, result in enumerate(client.results(search)):
                try:
                    safe_title = re.sub(r'[\\/:*?"<>|]', '', result.title)
                    # get_short_id() yields e.g. "2101.00001v1"
                    filename = f"{ARXIV_DIR}/{safe_title[:100]}_{result.get_short_id()}.pdf"
                    print(f"Downloading paper {i + 1}/{max_papers}: {result.title}")
                    result.download_pdf(filename=filename)
                    pdf_paths.append(filename)
                except Exception as e:
                    print(f"Could not download {result.title}: {e}")
            if not pdf_paths:
                return "No papers found or downloaded for the given query. Please try a different query."
            # Load documents from downloaded PDFs using PyPDFLoader
            all_documents = []
            for pdf_path in pdf_paths:
                try:
                    loader = PyPDFLoader(pdf_path)
                    all_documents.extend(loader.load())
                except Exception as e:
                    print(f"Error loading PDF {pdf_path}: {e}")
            if not all_documents:
                return "Could not load any documents from downloaded PDFs. Please try a different query or fewer papers."
            print(f"Loaded {len(all_documents)} raw documents from PDFs.")
            # Split documents into chunks using RecursiveCharacterTextSplitter
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=CHUNK_SIZE,
                chunk_overlap=CHUNK_OVERLAP,
                length_function=len,
                is_separator_regex=False,
            )
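            # split_documents preserves each chunk's source metadata (file, page);
            # the recursive splitter falls back from paragraph to line to word
            # boundaries so chunks stay near CHUNK_SIZE characters.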
            self.knowledge_base_chunks = text_splitter.split_documents(all_documents)
            if not self.knowledge_base_chunks:
                return "No meaningful text chunks could be created from the papers after splitting."
            print(f"Total chunks created: {len(self.knowledge_base_chunks)}")
            # Create FAISS vectorstore from chunks and embeddings
            print("Creating FAISS vectorstore from chunks...")
            self.vectorstore = FAISS.from_documents(self.knowledge_base_chunks, self.embedding_model)
            print(f"FAISS vectorstore created with {len(self.knowledge_base_chunks)} documents.")
            # Create RetrievalQA chain
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",  # "stuff" puts all retrieved docs into one prompt
                retriever=self.vectorstore.as_retriever(search_kwargs={"k": 3}),  # Retrieve top 3 docs
                return_source_documents=False,  # Set to True if you want to return source docs
            )
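            # Note: with k=3 chunks of ~CHUNK_SIZE characters, the stuffed prompt
            # should stay roughly within flan-t5-small's 512-token input window.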
return f"Knowledge base loaded with {len(self.knowledge_base_chunks)} chunks from {len(pdf_paths)} arXiv papers on '{arxiv_query}'." | |
        except Exception as e:
            print(f"Error during knowledge base initialization: {e}")
            return f"An error occurred during knowledge base initialization: {e}"

    def query_agent(self, query: str) -> str:
        """
        Retrieves relevant information from the knowledge base and generates an
        answer using the LangChain RetrievalQA chain.
        """
        if not query.strip():
            return "Please enter a question."
        if not self.is_initialized or self.qa_chain is None:
            return "Knowledge base not loaded. Please initialize it by providing an arXiv query."
        print(f"\n--- Querying LLM with LangChain QA Chain ---\nQuestion: {query}\n----------------------")
        try:
            # Use the RetrievalQA chain to get the answer
            result = self.qa_chain.invoke({"query": query})
            answer = result["result"].strip()
        except Exception as e:
            print(f"Error during generation: {e}")
            answer = "I apologize, but I encountered an error while generating the answer. Please try again or rephrase your question."
        return answer

# --- Gradio Interface ---
# Instantiate the RAGAgent
rag_agent_instance = RAGAgent()

print("Setting up Gradio interface...")
with gr.Blocks() as demo:
    gr.Markdown("# 📚 Educational RAG Agent with arXiv Knowledge Base (LangChain)")
    gr.Markdown("First, load a knowledge base by specifying an arXiv search query. Then, ask questions!")
    with gr.Row():
        arxiv_input = gr.Textbox(
            label="arXiv Search Query (e.g., 'Large Language Models', 'Reinforcement Learning')",
            placeholder="Enter a topic to search for papers on arXiv...",
            lines=1,
        )
        max_papers_slider = gr.Slider(
            minimum=1,
            maximum=10,
            step=1,
            value=3,
            label="Max Papers to Download",
        )
    load_kb_button = gr.Button("Load Knowledge Base from arXiv")
    kb_status_output = gr.Textbox(label="Knowledge Base Status", interactive=False)
    with gr.Row():
        question_input = gr.Textbox(
            lines=3,
            placeholder="Ask a question based on the loaded arXiv papers...",
            label="Your Question",
        )
        answer_output = gr.Textbox(label="Answer", lines=7, interactive=False)
    submit_button = gr.Button("Get Answer")

    load_kb_button.click(
        fn=rag_agent_instance.initialize_knowledge_base,  # Method of the shared instance
        inputs=[arxiv_input, max_papers_slider],
        outputs=kb_status_output,
    )
    submit_button.click(
        fn=rag_agent_instance.query_agent,  # Method of the shared instance
        inputs=question_input,
        outputs=answer_output,
    )
    gr.Examples(
        examples=[
            ["What is the transformer architecture?"],
            ["Explain attention mechanisms in deep learning."],
            ["What are the challenges in reinforcement learning?"],
        ],
        inputs=question_input,
    )

# Launch the Gradio app
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(share=False)