Jatin Mehra committed
Commit 1dc0983 · 1 Parent(s): ba76b7d

Refactor PDF processing and embedding creation; update chunking to include metadata

Files changed (2)
  1. app.py +16 -15
  2. preprocessing.py +49 -33
app.py CHANGED

@@ -2,7 +2,7 @@ import os
 import dotenv
 import pickle
 import uuid
-from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks, Request
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
@@ -16,8 +16,7 @@ from preprocessing import (
     build_faiss_index,
     retrieve_similar_chunks,
     agentic_rag,
-    tools,
-    memory
+    tools
 )
 from sentence_transformers import SentenceTransformer
 import shutil
@@ -88,8 +87,8 @@ def load_session(session_id, model_name="meta-llama/llama-4-scout-17b-16e-instru
     # Recreate non-pickled objects
     if data.get("chunks") and data.get("file_path") and os.path.exists(data["file_path"]):
         # Recreate model, embeddings and index
-        model = SentenceTransformer('all-MiniLM-L6-v2')
-        embeddings = create_embeddings(data["chunks"], model)
+        model = SentenceTransformer('BAAI/bge-large-en-v1.5')
+        embeddings, _ = create_embeddings(data["chunks"], model)  # Unpack tuple
         index = build_faiss_index(embeddings)
 
         # Recreate LLM
@@ -165,13 +164,15 @@ async def upload_pdf(
         raise ValueError("GROQ_API_KEY is not set in the environment variables")
 
     # Process the PDF
-    text = process_pdf_file(file_path)
-    chunks = chunk_text(text, max_length=1500)
+    documents = process_pdf_file(file_path)  # Returns list of Document objects
+    chunks = chunk_text(documents, max_length=1000)  # Updated to handle documents
 
     # Create embeddings
-    model = SentenceTransformer('all-MiniLM-L6-v2')
-    embeddings = create_embeddings(chunks, model)
-    index = build_faiss_index(embeddings)
+    model = SentenceTransformer('BAAI/bge-large-en-v1.5')  # Updated embedding model
+    embeddings, chunks_with_metadata = create_embeddings(chunks, model)  # Unpack tuple
+
+    # Build FAISS index
+    index = build_faiss_index(embeddings)  # Pass only embeddings array
 
     # Initialize LLM
     llm = model_selection(model_name)
@@ -180,7 +181,7 @@ async def upload_pdf(
     session_data = {
         "file_path": file_path,
         "file_name": file.filename,
-        "chunks": chunks,
+        "chunks": chunks_with_metadata,  # Store chunks with metadata
         "model": model,
         "index": index,
         "llm": llm,
@@ -224,16 +225,15 @@ async def chat(request: ChatRequest):
             session["index"],
             session["chunks"],
             session["model"],
-            k=3
+            k=10
         )
-        context = "\n".join([chunk for chunk, _ in similar_chunks])
 
         # Generate response using agentic_rag
         response = agentic_rag(
             session["llm"],
             tools,
             query=request.query,
-            context=context,
+            context_chunks=similar_chunks,  # Pass the list of tuples
             Use_Tavily=request.use_search
         )
 
@@ -244,12 +244,13 @@ async def chat(request: ChatRequest):
         return {
             "status": "success",
            "answer": response["output"],
-            "context_used": [{"text": chunk, "score": float(score)} for chunk, score in similar_chunks]
+            "context_used": [{"text": chunk, "score": float(score)} for chunk, score, _ in similar_chunks]
         }
 
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
 
+
 # Route to get chat history
 @app.post("/chat-history")
 async def get_chat_history(request: SessionRequest):
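To make the new interfaces in app.py easier to follow, here is a minimal sketch of the data shapes the chat route now works with; the sample values and metadata keys are illustrative, and the layouts mirror the preprocessing.py changes shown below.

# Chunk records stored in the session after upload (produced by chunk_text/create_embeddings):
chunk = {"text": "First paragraph of the PDF...", "metadata": {"source": "doc.pdf", "page": 0}}

# retrieve_similar_chunks now returns (text, distance, metadata) triples:
similar_chunks = [("First paragraph of the PDF...", 0.42, {"source": "doc.pdf", "page": 0})]

# The chat route forwards the raw triples to agentic_rag and reports text/score pairs to the client:
context_used = [{"text": text, "score": float(score)} for text, score, _ in similar_chunks]

Because the score is a FAISS distance, smaller values mean more relevant chunks, which is why agentic_rag sorts ascending before assembling the context.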
preprocessing.py CHANGED

@@ -25,49 +25,70 @@ def estimate_tokens(text):
     return len(text) // 4
 
 def process_pdf_file(file_path):
-    """Load a PDF file and extract its text."""
+    """Load a PDF file and extract its text with metadata."""
     if not os.path.exists(file_path):
         raise FileNotFoundError(f"The file {file_path} does not exist.")
     loader = PyMuPDFLoader(file_path)
     documents = loader.load()
-    text = "".join(doc.page_content for doc in documents)
-    return text
+    return documents  # Return list of Document objects with metadata
 
-def chunk_text(text, max_length=1500):
-    """Split text into chunks based on paragraphs, respecting max_length."""
-    paragraphs = text.split("\n\n")
+def chunk_text(documents, max_length=1000):
+    """Split documents into chunks with metadata."""
     chunks = []
-    current_chunk = ""
-    for paragraph in paragraphs:
-        if len(current_chunk) + len(paragraph) <= max_length:
-            current_chunk += paragraph + "\n\n"
-        else:
-            chunks.append(current_chunk.strip())
-            current_chunk = paragraph + "\n\n"
-    if current_chunk:
-        chunks.append(current_chunk.strip())
+    for doc in documents:
+        text = doc.page_content
+        metadata = doc.metadata
+        paragraphs = text.split("\n\n")
+        current_chunk = ""
+        current_metadata = metadata.copy()
+        for paragraph in paragraphs:
+            if estimate_tokens(current_chunk + paragraph) <= max_length // 4:
+                current_chunk += paragraph + "\n\n"
+            else:
+                chunks.append({"text": current_chunk.strip(), "metadata": current_metadata})
+                current_chunk = paragraph + "\n\n"
+        if current_chunk:
+            chunks.append({"text": current_chunk.strip(), "metadata": current_metadata})
     return chunks
 
-def create_embeddings(texts, model):
-    """Create embeddings for a list of texts using the provided model."""
+def create_embeddings(chunks, model):
+    """Create embeddings for a list of chunk texts."""
+    texts = [chunk["text"] for chunk in chunks]
     embeddings = model.encode(texts, show_progress_bar=True, convert_to_tensor=True)
-    return embeddings.cpu().numpy()
+    return embeddings.cpu().numpy(), chunks
 
 def build_faiss_index(embeddings):
-    """Build a FAISS index from embeddings for similarity search."""
+    """Build a FAISS HNSW index from embeddings for similarity search."""
     dim = embeddings.shape[1]
-    index = faiss.IndexFlatL2(dim)
+    index = faiss.IndexHNSWFlat(dim, 32)  # 32 = number of neighbors in HNSW graph
+    index.hnsw.efConstruction = 200  # Higher = better quality, slower build
+    index.hnsw.efSearch = 50  # Higher = better accuracy, slower search
     index.add(embeddings)
     return index
 
-def retrieve_similar_chunks(query, index, texts, model, k=3, max_chunk_length=3500):
+def retrieve_similar_chunks(query, index, chunks, model, k=10, max_chunk_length=1000):
     """Retrieve top k similar chunks to the query from the FAISS index."""
     query_embedding = model.encode([query], convert_to_tensor=True).cpu().numpy()
     distances, indices = index.search(query_embedding, k)
-    return [(texts[i][:max_chunk_length], distances[0][j]) for j, i in enumerate(indices[0])]
+    return [(chunks[i]["text"][:max_chunk_length], distances[0][j], chunks[i]["metadata"]) for j, i in enumerate(indices[0])]
 
-def agentic_rag(llm, tools, query, context, Use_Tavily=False):
-    # Define the prompt template for the agent
+def agentic_rag(llm, tools, query, context_chunks, Use_Tavily=False):
+    # Sort chunks by relevance (lower distance = more relevant)
+    context_chunks = sorted(context_chunks, key=lambda x: x[1])  # Sort by distance
+    context = ""
+    total_tokens = 0
+    max_tokens = 7000  # Leave room for prompt and response
+
+    # Aggregate relevant chunks until token limit is reached
+    for chunk, _, _ in context_chunks:  # Unpack three elements
+        chunk_tokens = estimate_tokens(chunk)
+        if total_tokens + chunk_tokens <= max_tokens:
+            context += chunk + "\n\n"
+            total_tokens += chunk_tokens
+        else:
+            break
+
+    # Define prompt template
     search_instructions = (
         "Use the search tool if the context is insufficient to answer the question or you are unsure. Give source links if you use the search tool."
         if Use_Tavily
@@ -80,35 +101,30 @@ def agentic_rag(llm, tools, query, context, Use_Tavily=False):
     Instructions:
     1. Use the provided context to answer the user's question.
     2. Provide a clear answer, if you don't know the answer, say 'I don't know'.
+    3. Prioritize information from the most relevant context chunks.
     """),
         ("human", "Context: {context}\n\nQuestion: {input}"),
         MessagesPlaceholder(variable_name="chat_history"),
         MessagesPlaceholder(variable_name="agent_scratchpad"),
     ])
-
-    # Only use tools when Tavily is enabled
-    agent_tools = tools if Use_Tavily else []
 
+    agent_tools = tools if Use_Tavily else []
     try:
-        # Create the agent and executor with appropriate tools
         agent = create_tool_calling_agent(llm, agent_tools, prompt)
         agent_executor = AgentExecutor(agent=agent, tools=agent_tools, memory=memory, verbose=True)
-
-        # Execute the agent
         return agent_executor.invoke({
-            "input": query,
+            "input": query,
             "context": context,
             "search_instructions": search_instructions
         })
     except Exception as e:
         print(f"Error during agent execution: {str(e)}")
-        # Fallback to direct LLM call without agent framework
         fallback_prompt = ChatPromptTemplate.from_messages([
             ("system", "You are a helpful assistant. Use the provided context to answer the user's question."),
             ("human", "Context: {context}\n\nQuestion: {input}")
        ])
         response = llm.invoke(fallback_prompt.format(context=context, input=query))
-    return {"output": response.content}
+        return {"output": response.content}
 
 if __name__ == "__main__":
     # Process PDF and prepare index
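Taken together, the refactored helpers chain into a single retrieval pipeline. A minimal usage sketch, assuming a local sample.pdf, a configured GROQ_API_KEY, and the project's model_selection helper for building the LLM (file name, query, and model choice are illustrative, not part of this diff):

from sentence_transformers import SentenceTransformer
from preprocessing import (
    process_pdf_file,
    chunk_text,
    create_embeddings,
    build_faiss_index,
    retrieve_similar_chunks,
    agentic_rag,
    tools,
)

documents = process_pdf_file("sample.pdf")         # list of Document objects with metadata
chunks = chunk_text(documents, max_length=1000)    # list of {"text": ..., "metadata": ...} dicts
model = SentenceTransformer("BAAI/bge-large-en-v1.5")
embeddings, chunks_with_metadata = create_embeddings(chunks, model)
index = build_faiss_index(embeddings)              # HNSW index over the embedding matrix

query = "What is this document about?"
similar = retrieve_similar_chunks(query, index, chunks_with_metadata, model, k=10)
llm = model_selection(model_name)                  # project helper used in app.py; configure as in your setup
response = agentic_rag(llm, tools, query=query, context_chunks=similar, Use_Tavily=False)
print(response["output"])

The switch from IndexFlatL2 to IndexHNSWFlat trades exact search for approximate nearest-neighbour search, which is why the new efConstruction and efSearch knobs appear in build_faiss_index.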