# app.py
import os
import re
import shutil

import arxiv
import gradio as gr
import torch
from langchain_community.document_loaders import ArxivLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# --- Configuration ---
ARXIV_DIR = "./arxiv_papers"  # Directory to save downloaded papers
CHUNK_SIZE = 500              # Characters per chunk
CHUNK_OVERLAP = 50            # Overlap between chunks
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
LLM_MODEL_NAME = "google/flan-t5-small"


# --- RAGAgent Class ---
class RAGAgent:
    def __init__(self):
        self.embedding_model = None
        self.llm = None
        self.vectorstore = None
        self.qa_chain = None
        self.knowledge_base_chunks = []
        self.is_initialized = False

    def _load_models(self):
        """Loads the embedding and generation models if not already loaded."""
        if self.embedding_model is None:
            print(f"Loading Embedding Model: {EMBEDDING_MODEL_NAME}...")
            self.embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)

        if self.llm is None:
            print(f"Loading LLM Model: {LLM_MODEL_NAME}...")
            tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
            model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL_NAME)

            # Determine device for the pipeline (GPU if available, otherwise CPU)
            device = 0 if torch.cuda.is_available() else -1

            # Create a Hugging Face pipeline for text generation
            text_generation_pipeline = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=150,  # Default max_new_tokens for the pipeline
                min_length=20,
                num_beams=5,
                early_stopping=True,
                device=device,
            )
            self.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

        self.is_initialized = True

    def initialize_knowledge_base(self, arxiv_query: str, max_papers: int = 5) -> str:
        """
        Initializes the knowledge base by downloading, extracting, and chunking
        arXiv papers using LangChain components, then building a FAISS vectorstore.
        """
        self._load_models()  # Ensure models are loaded first

        # Clear existing papers before downloading new ones
        if os.path.exists(ARXIV_DIR):
            shutil.rmtree(ARXIV_DIR)
        os.makedirs(ARXIV_DIR, exist_ok=True)

        self.vectorstore = None
        self.qa_chain = None

        print(f"Searching arXiv for '{arxiv_query}' and downloading up to {max_papers} papers...")
        try:
            # LangChain's ArxivLoader could fetch papers directly, but it downloads
            # PDFs to a temporary directory by default. We download manually with the
            # arxiv library instead, which gives better control over filenames and
            # cleanup, and then load the PDFs with PyPDFLoader.
            search_results = arxiv.Search(
                query=arxiv_query,
                max_results=max_papers,
                sort_by=arxiv.SortCriterion.Relevance,
                sort_order=arxiv.SortOrder.Descending,
            )

            pdf_paths = []
            for i, result in enumerate(search_results.results()):
                try:
                    safe_title = re.sub(r'[\\/:*?"<>|]', '', result.title)
                    filename = f"{ARXIV_DIR}/{safe_title[:100]}_{result.get_short_id()}.pdf"
                    print(f"Downloading paper {i+1}/{max_papers}: {result.title}")
                    result.download_pdf(filename=filename)
                    pdf_paths.append(filename)
                except Exception as e:
                    print(f"Could not download {result.title}: {e}")

            if not pdf_paths:
                return "No papers found or downloaded for the given query. Please try a different query."
            # Load documents from the downloaded PDFs using PyPDFLoader
            all_documents = []
            for pdf_path in pdf_paths:
                try:
                    loader = PyPDFLoader(pdf_path)
                    all_documents.extend(loader.load())
                except Exception as e:
                    print(f"Error loading PDF {pdf_path}: {e}")

            if not all_documents:
                return "Could not load any documents from the downloaded PDFs. Please try a different query or fewer papers."
            print(f"Loaded {len(all_documents)} raw documents from PDFs.")

            # Split documents into chunks using RecursiveCharacterTextSplitter
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=CHUNK_SIZE,
                chunk_overlap=CHUNK_OVERLAP,
                length_function=len,
                is_separator_regex=False,
            )
            self.knowledge_base_chunks = text_splitter.split_documents(all_documents)
            if not self.knowledge_base_chunks:
                return "No meaningful text chunks could be created from the papers after splitting."
            print(f"Total chunks created: {len(self.knowledge_base_chunks)}")

            # Create FAISS vectorstore from chunks and embeddings
            print("Creating FAISS vectorstore from chunks...")
            self.vectorstore = FAISS.from_documents(self.knowledge_base_chunks, self.embedding_model)
            print(f"FAISS vectorstore created with {len(self.knowledge_base_chunks)} documents.")

            # Create RetrievalQA chain
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",  # "stuff" puts all retrieved docs into one prompt
                retriever=self.vectorstore.as_retriever(search_kwargs={"k": 3}),  # Retrieve top 3 docs
                return_source_documents=False,  # Set to True to return source docs as well
            )

            return (
                f"Knowledge base loaded with {len(self.knowledge_base_chunks)} chunks "
                f"from {len(pdf_paths)} arXiv papers on '{arxiv_query}'."
            )

        except Exception as e:
            print(f"Error during knowledge base initialization: {e}")
            return f"An error occurred during knowledge base initialization: {e}"

    def query_agent(self, query: str) -> str:
        """
        Retrieves relevant information from the knowledge base and generates an
        answer using the LangChain RetrievalQA chain.
        """
        if not query.strip():
            return "Please enter a question."
        if not self.is_initialized or self.qa_chain is None:
            return "Knowledge base not loaded. Please initialize it by providing an arXiv query."

        print(f"\n--- Querying LLM with LangChain QA Chain ---\nQuestion: {query}\n----------------------")
        try:
            # Use the RetrievalQA chain to get the answer
            result = self.qa_chain.invoke({"query": query})
            answer = result["result"].strip()
        except Exception as e:
            print(f"Error during generation: {e}")
            answer = (
                "I apologize, but I encountered an error while generating the answer. "
                "Please try again or rephrase your question."
            )
        return answer


# --- Gradio Interface ---
# Instantiate the RAGAgent
rag_agent_instance = RAGAgent()

print("Setting up Gradio interface...")
with gr.Blocks() as demo:
    gr.Markdown("# 📚 Educational RAG Agent with arXiv Knowledge Base (LangChain)")
    gr.Markdown("First, load a knowledge base by specifying an arXiv search query. Then, ask questions!")
Then, ask questions!") with gr.Row(): arxiv_input = gr.Textbox( label="arXiv Search Query (e.g., 'Large Language Models', 'Reinforcement Learning')", placeholder="Enter a topic to search for papers on arXiv...", lines=1 ) max_papers_slider = gr.Slider( minimum=1, maximum=10, step=1, value=3, label="Max Papers to Download" ) load_kb_button = gr.Button("Load Knowledge Base from arXiv") kb_status_output = gr.Textbox(label="Knowledge Base Status", interactive=False) with gr.Row(): question_input = gr.Textbox( lines=3, placeholder="Ask a question based on the loaded arXiv papers...", label="Your Question" ) answer_output = gr.Textbox(label="Answer", lines=7, interactive=False) submit_button = gr.Button("Get Answer") load_kb_button.click( fn=rag_agent_instance.initialize_knowledge_base, # Call method of instance inputs=[arxiv_input, max_papers_slider], outputs=kb_status_output ) submit_button.click( fn=rag_agent_instance.query_agent, # Call method of instance inputs=question_input, outputs=answer_output ) gr.Examples( examples=[ ["What is the transformer architecture?"], ["Explain attention mechanisms in deep learning."], ["What are the challenges in reinforcement learning?"], ], inputs=question_input ) # Launch the Gradio app if __name__ == "__main__": print("Launching Gradio app...") demo.launch(share=False)