APrmn8 committed on
Commit
08a44c9
·
verified ·
1 Parent(s): 8a7cf31

revise app

Browse files
Files changed (1) hide show
  1. app.py +124 -32
app.py CHANGED
@@ -4,7 +4,10 @@ import os
4
  import re
5
  import shutil
6
  import torch
7
- from langchain_community.document_loaders import ArxivLoader, PyPDFLoader
 
 
 
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain_community.embeddings import HuggingFaceEmbeddings
10
  from langchain_community.vectorstores import FAISS
@@ -14,11 +17,57 @@ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
14
 
15
  # --- Configuration ---
16
  ARXIV_DIR = "./arxiv_papers" # Directory to save downloaded papers
 
 
 
 
17
  CHUNK_SIZE = 500 # Characters per chunk
18
  CHUNK_OVERLAP = 50 # Overlap between chunks
19
  EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
20
  LLM_MODEL_NAME = "google/flan-t5-small"
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # --- RAGAgent Class ---
23
 
24
  class RAGAgent:
@@ -72,33 +121,12 @@ class RAGAgent:
72
 
73
  self.vectorstore = None
74
  self.qa_chain = None
 
75
 
76
  print(f"Searching arXiv for '{arxiv_query}' and downloading up to {max_papers} papers...")
77
  try:
78
- # Use LangChain's ArxivLoader
79
- # ArxivLoader downloads PDFs to a temporary directory by default,
80
- # but we can specify a custom path to ensure cleanup.
81
- # For simplicity, we'll let it download to its default temp dir
82
- # and then process. Or, we can manually download and use PyPDFLoader.
83
- # Let's stick to manual download for better control and consistency with previous code.
84
-
85
  # Manual download using arxiv library (as it offers more control over filenames)
86
- search_results = arxiv.Search(
87
- query=arxiv_query,
88
- max_results=max_papers,
89
- sort_by=arxiv.SortCriterion.Relevance,
90
- sort_order=arxiv.SortOrder.Descending
91
- )
92
- pdf_paths = []
93
- for i, result in enumerate(search_results.results()):
94
- try:
95
- safe_title = re.sub(r'[\\/:*?"<>|]', '', result.title)
96
- filename = f"{ARXIV_DIR}/{safe_title[:100]}_{result.arxiv_id}.pdf"
97
- print(f"Downloading paper {i+1}/{max_papers}: {result.title}")
98
- result.download_pdf(filename=filename)
99
- pdf_paths.append(filename)
100
- except Exception as e:
101
- print(f"Could not download {result.title}: {e}")
102
 
103
  if not pdf_paths:
104
  return "No papers found or downloaded for the given query. Please try a different query."
@@ -150,6 +178,54 @@ class RAGAgent:
150
  print(f"Error during knowledge base initialization: {e}")
151
  return f"An error occurred during knowledge base initialization: {e}"
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def query_agent(self, query: str) -> str:
154
  """
155
  Retrieves relevant information from the knowledge base and generates an answer
@@ -158,7 +234,7 @@ class RAGAgent:
158
  if not query.strip():
159
  return "Please enter a question."
160
  if not self.is_initialized or self.qa_chain is None:
161
- return "Knowledge base not loaded. Please initialize it by providing an arXiv query."
162
 
163
  print(f"\n--- Querying LLM with LangChain QA Chain ---\nQuestion: {query}\n----------------------")
164
 
@@ -180,8 +256,8 @@ rag_agent_instance = RAGAgent()
180
  print("Setting up Gradio interface...")
181
 
182
  with gr.Blocks() as demo:
183
- gr.Markdown("# 📚 Educational RAG Agent with arXiv Knowledge Base (LangChain)")
184
- gr.Markdown("First, load a knowledge base by specifying an arXiv search query. Then, ask questions!")
185
 
186
  with gr.Row():
187
  arxiv_input = gr.Textbox(
@@ -196,28 +272,44 @@ with gr.Blocks() as demo:
196
  value=3,
197
  label="Max Papers to Download"
198
  )
199
- load_kb_button = gr.Button("Load Knowledge Base from arXiv")
200
 
201
  kb_status_output = gr.Textbox(label="Knowledge Base Status", interactive=False)
202
 
 
 
 
 
203
  with gr.Row():
204
  question_input = gr.Textbox(
205
  lines=3,
206
- placeholder="Ask a question based on the loaded arXiv papers...",
207
  label="Your Question"
208
  )
209
  answer_output = gr.Textbox(label="Answer", lines=7, interactive=False)
210
 
211
  submit_button = gr.Button("Get Answer")
212
 
213
- load_kb_button.click(
214
- fn=rag_agent_instance.initialize_knowledge_base, # Call method of instance
215
  inputs=[arxiv_input, max_papers_slider],
216
  outputs=kb_status_output
217
  )
218
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  submit_button.click(
220
- fn=rag_agent_instance.query_agent, # Call method of instance
221
  inputs=question_input,
222
  outputs=answer_output
223
  )
 
4
  import re
5
  import shutil
6
  import torch
7
+ import pickle # For saving/loading Python objects
8
+
9
+ # LangChain imports
10
+ from langchain_community.document_loaders import PyPDFLoader
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
12
  from langchain_community.embeddings import HuggingFaceEmbeddings
13
  from langchain_community.vectorstores import FAISS
 
17
 
18
  # --- Configuration ---
19
  ARXIV_DIR = "./arxiv_papers" # Directory to save downloaded papers
20
+ KB_STORAGE_DIR = "./knowledge_base_storage" # Directory to save/load KB
21
+ FAISS_INDEX_PATH = os.path.join(KB_STORAGE_DIR, "faiss_index.bin")
22
+ CHUNKS_PATH = os.path.join(KB_STORAGE_DIR, "knowledge_base_chunks.pkl")
23
+
24
  CHUNK_SIZE = 500 # Characters per chunk
25
  CHUNK_OVERLAP = 50 # Overlap between chunks
26
  EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
27
  LLM_MODEL_NAME = "google/flan-t5-small"
28
 
29
+ # Ensure KB storage directory exists
30
+ os.makedirs(KB_STORAGE_DIR, exist_ok=True)
31
+
32
+ # --- Helper Functions for arXiv and PDF Processing ---
33
+
34
+ def clean_text(text: str) -> str:
35
+ """Basic text cleaning: replaces multiple spaces/newlines with single space and strips whitespace."""
36
+ text = re.sub(r'\s+', ' ', text)
37
+ text = text.strip()
38
+ return text
39
+
40
+ def get_arxiv_papers(query: str, max_papers: int = 5) -> list[str]:
41
+ """
42
+ Searches arXiv for papers, downloads their PDFs, and returns a list of file paths.
43
+ Clears the ARXIV_DIR before downloading new papers.
44
+ """
45
+ # Clear existing papers before downloading new ones
46
+ if os.path.exists(ARXIV_DIR):
47
+ shutil.rmtree(ARXIV_DIR)
48
+ os.makedirs(ARXIV_DIR, exist_ok=True)
49
+
50
+ print(f"Searching arXiv for '{query}' and downloading up to {max_papers} papers...")
51
+ import arxiv # Import here to ensure it's available when this function is called
52
+ search_results = arxiv.Search(
53
+ query=query,
54
+ max_results=max_papers,
55
+ sort_by=arxiv.SortCriterion.Relevance,
56
+ sort_order=arxiv.SortOrder.Descending
57
+ )
58
+ downloaded_files = []
59
+ for i, result in enumerate(search_results.results()):
60
+ try:
61
+ # Create a safe filename
62
+ safe_title = re.sub(r'[\\/:*?"<>|]', '', result.title) # Remove invalid characters
63
+ filename = f"{ARXIV_DIR}/{safe_title[:100]}_{result.arxiv_id}.pdf" # Limit title length
64
+ print(f"Downloading paper {i+1}/{max_papers}: {result.title}")
65
+ result.download_pdf(filename=filename)
66
+ downloaded_files.append(filename)
67
+ except Exception as e:
68
+ print(f"Could not download {result.title}: {e}")
69
+ return downloaded_files
70
+
71
  # --- RAGAgent Class ---
72
 
73
  class RAGAgent:
 
121
 
122
  self.vectorstore = None
123
  self.qa_chain = None
124
+ self.knowledge_base_chunks = [] # Reset chunks
125
 
126
  print(f"Searching arXiv for '{arxiv_query}' and downloading up to {max_papers} papers...")
127
  try:
 
 
 
 
 
 
 
128
  # Manual download using arxiv library (as it offers more control over filenames)
129
+ pdf_paths = get_arxiv_papers(arxiv_query, max_papers) # Call the helper function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  if not pdf_paths:
132
  return "No papers found or downloaded for the given query. Please try a different query."
 
178
  print(f"Error during knowledge base initialization: {e}")
179
  return f"An error occurred during knowledge base initialization: {e}"
180
 
181
+ def save_knowledge_base(self) -> str:
182
+ """Saves the current FAISS vectorstore and knowledge base chunks to disk."""
183
+ if not self.vectorstore or not self.knowledge_base_chunks:
184
+ return "No knowledge base to save. Please load one first."
185
+
186
+ try:
187
+ # Save FAISS index
188
+ self.vectorstore.save_local(KB_STORAGE_DIR, index_name="faiss_index")
189
+ # Save chunks (metadata for FAISS, or for re-building if needed)
190
+ with open(CHUNKS_PATH, 'wb') as f:
191
+ pickle.dump(self.knowledge_base_chunks, f)
192
+ print(f"Knowledge base saved to {KB_STORAGE_DIR}")
193
+ return f"Knowledge base saved successfully to {KB_STORAGE_DIR}."
194
+ except Exception as e:
195
+ print(f"Error saving knowledge base: {e}")
196
+ return f"Error saving knowledge base: {e}"
197
+
198
+ def load_knowledge_base(self) -> str:
199
+ """Loads the FAISS vectorstore and knowledge base chunks from disk."""
200
+ self._load_models() # Ensure models are loaded before loading KB
201
+
202
+ if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(CHUNKS_PATH):
203
+ return "Saved knowledge base not found. Please load or create one first."
204
+
205
+ try:
206
+ # Load FAISS index
207
+ self.vectorstore = FAISS.load_local(KB_STORAGE_DIR, self.embedding_model, index_name="faiss_index", allow_dangerous_deserialization=True)
208
+ # Load chunks
209
+ with open(CHUNKS_PATH, 'rb') as f:
210
+ self.knowledge_base_chunks = pickle.load(f)
211
+
212
+ # Re-create RetrievalQA chain after loading vectorstore
213
+ self.qa_chain = RetrievalQA.from_chain_type(
214
+ llm=self.llm,
215
+ chain_type="stuff",
216
+ retriever=self.vectorstore.as_retriever(search_kwargs={"k": 3}),
217
+ return_source_documents=False
218
+ )
219
+
220
+ print(f"Knowledge base loaded from {KB_STORAGE_DIR}")
221
+ return f"Knowledge base loaded successfully from {KB_STORAGE_DIR} with {len(self.knowledge_base_chunks)} chunks."
222
+ except Exception as e:
223
+ print(f"Error loading knowledge base: {e}")
224
+ self.vectorstore = None
225
+ self.qa_chain = None
226
+ self.knowledge_base_chunks = []
227
+ return f"Error loading knowledge base: {e}"
228
+
229
  def query_agent(self, query: str) -> str:
230
  """
231
  Retrieves relevant information from the knowledge base and generates an answer
 
234
  if not query.strip():
235
  return "Please enter a question."
236
  if not self.is_initialized or self.qa_chain is None:
237
+ return "Knowledge base not loaded. Please initialize it by providing an arXiv query or loading from disk."
238
 
239
  print(f"\n--- Querying LLM with LangChain QA Chain ---\nQuestion: {query}\n----------------------")
240
 
 
256
  print("Setting up Gradio interface...")
257
 
258
  with gr.Blocks() as demo:
259
+ gr.Markdown("# 📚 Educational RAG Agent with Persistent Knowledge Base")
260
+ gr.Markdown("First, load a knowledge base from arXiv, then you can save it or load a previously saved one. Finally, ask questions!")
261
 
262
  with gr.Row():
263
  arxiv_input = gr.Textbox(
 
272
  value=3,
273
  label="Max Papers to Download"
274
  )
275
+ load_kb_from_arxiv_button = gr.Button("Load KB from arXiv")
276
 
277
  kb_status_output = gr.Textbox(label="Knowledge Base Status", interactive=False)
278
 
279
+ with gr.Row():
280
+ save_kb_button = gr.Button("Save Knowledge Base to Disk")
281
+ load_kb_from_disk_button = gr.Button("Load Knowledge Base from Disk")
282
+
283
  with gr.Row():
284
  question_input = gr.Textbox(
285
  lines=3,
286
+ placeholder="Ask a question based on the loaded knowledge base...",
287
  label="Your Question"
288
  )
289
  answer_output = gr.Textbox(label="Answer", lines=7, interactive=False)
290
 
291
  submit_button = gr.Button("Get Answer")
292
 
293
+ load_kb_from_arxiv_button.click(
294
+ fn=rag_agent_instance.initialize_knowledge_base,
295
  inputs=[arxiv_input, max_papers_slider],
296
  outputs=kb_status_output
297
  )
298
 
299
+ save_kb_button.click(
300
+ fn=rag_agent_instance.save_knowledge_base,
301
+ inputs=[],
302
+ outputs=kb_status_output
303
+ )
304
+
305
+ load_kb_from_disk_button.click(
306
+ fn=rag_agent_instance.load_knowledge_base,
307
+ inputs=[],
308
+ outputs=kb_status_output
309
+ )
310
+
311
  submit_button.click(
312
+ fn=rag_agent_instance.query_agent,
313
  inputs=question_input,
314
  outputs=answer_output
315
  )