Jatin Mehra committed · Commit 1dc0983
1 Parent(s): ba76b7d
Refactor PDF processing and embedding creation; update chunking to include metadata
Files changed:
- app.py +16 -15
- preprocessing.py +49 -33
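At a glance, the refactor threads source metadata through the whole pipeline: chunk_text now emits dicts instead of bare strings, and retrieve_similar_chunks returns (text, distance, metadata) triples. A minimal sketch of the new data shapes, inferred from the diffs below (the literal values are illustrative, not taken from the repo):

# Shape of one chunk produced by the new chunk_text (values are made up)
chunk = {"text": "First paragraphs of the PDF ...", "metadata": {"source": "doc.pdf", "page": 0}}

# retrieve_similar_chunks now yields (text, distance, metadata) triples
similar_chunks = [(chunk["text"], 0.42, chunk["metadata"])]
for text, score, meta in similar_chunks:
    print(float(score), meta.get("page"), text[:60])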
app.py
CHANGED
@@ -2,7 +2,7 @@ import os
 import dotenv
 import pickle
 import uuid
-from fastapi import FastAPI, UploadFile, File, Form, HTTPException
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
@@ -16,8 +16,7 @@ from preprocessing import (
     build_faiss_index,
     retrieve_similar_chunks,
     agentic_rag,
-    tools
-    memory
+    tools
 )
 from sentence_transformers import SentenceTransformer
 import shutil
@@ -88,8 +87,8 @@ def load_session(session_id, model_name="meta-llama/llama-4-scout-17b-16e-instru
     # Recreate non-pickled objects
     if data.get("chunks") and data.get("file_path") and os.path.exists(data["file_path"]):
         # Recreate model, embeddings and index
-        model = SentenceTransformer('
-        embeddings = create_embeddings(data["chunks"], model)
+        model = SentenceTransformer('BAAI/bge-large-en-v1.5')
+        embeddings, _ = create_embeddings(data["chunks"], model)  # Unpack tuple
         index = build_faiss_index(embeddings)
 
         # Recreate LLM
@@ -165,13 +164,15 @@ async def upload_pdf(
         raise ValueError("GROQ_API_KEY is not set in the environment variables")
 
         # Process the PDF
-
-        chunks = chunk_text(
+        documents = process_pdf_file(file_path)  # Returns list of Document objects
+        chunks = chunk_text(documents, max_length=1000)  # Updated to handle documents
 
         # Create embeddings
-        model = SentenceTransformer('
-        embeddings = create_embeddings(chunks, model)
-
+        model = SentenceTransformer('BAAI/bge-large-en-v1.5')  # Updated embedding model
+        embeddings, chunks_with_metadata = create_embeddings(chunks, model)  # Unpack tuple
+
+        # Build FAISS index
+        index = build_faiss_index(embeddings)  # Pass only embeddings array
 
         # Initialize LLM
         llm = model_selection(model_name)
@@ -180,7 +181,7 @@ async def upload_pdf(
         session_data = {
             "file_path": file_path,
             "file_name": file.filename,
-            "chunks": chunks
+            "chunks": chunks_with_metadata,  # Store chunks with metadata
             "model": model,
             "index": index,
             "llm": llm,
@@ -224,16 +225,15 @@ async def chat(request: ChatRequest):
             session["index"],
             session["chunks"],
             session["model"],
-            k=
+            k=10
         )
-        context = "\n".join([chunk for chunk, _ in similar_chunks])
 
         # Generate response using agentic_rag
         response = agentic_rag(
             session["llm"],
             tools,
             query=request.query,
-
+            context_chunks=similar_chunks,  # Pass the list of tuples
             Use_Tavily=request.use_search
         )
 
@@ -244,12 +244,13 @@ async def chat(request: ChatRequest):
         return {
             "status": "success",
             "answer": response["output"],
-            "context_used": [{"text": chunk, "score": float(score)} for chunk, score in similar_chunks]
+            "context_used": [{"text": chunk, "score": float(score)} for chunk, score, _ in similar_chunks]
         }
 
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
 
+
 # Route to get chat history
 @app.post("/chat-history")
 async def get_chat_history(request: SessionRequest):
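Taken together, the new upload path in app.py wires preprocessing as follows. A minimal sketch under the assumption that this repo's preprocessing.py is importable and that "example.pdf" exists (the file name and query string are placeholders; FastAPI plumbing, session handling, and error handling are omitted):

from sentence_transformers import SentenceTransformer
from preprocessing import (
    process_pdf_file,
    chunk_text,
    create_embeddings,
    build_faiss_index,
    retrieve_similar_chunks,
)

documents = process_pdf_file("example.pdf")          # list of Document objects
chunks = chunk_text(documents, max_length=1000)      # [{"text": ..., "metadata": ...}, ...]
model = SentenceTransformer("BAAI/bge-large-en-v1.5")
embeddings, chunks_with_metadata = create_embeddings(chunks, model)
index = build_faiss_index(embeddings)
similar_chunks = retrieve_similar_chunks(
    "What is this document about?", index, chunks_with_metadata, model, k=10
)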
preprocessing.py
CHANGED
@@ -25,49 +25,70 @@ def estimate_tokens(text):
     return len(text) // 4
 
 def process_pdf_file(file_path):
-    """Load a PDF file and extract its text."""
+    """Load a PDF file and extract its text with metadata."""
     if not os.path.exists(file_path):
         raise FileNotFoundError(f"The file {file_path} does not exist.")
     loader = PyMuPDFLoader(file_path)
     documents = loader.load()
-
-    return text
+    return documents  # Return list of Document objects with metadata
 
-def chunk_text(
-    """Split
-    paragraphs = text.split("\n\n")
+def chunk_text(documents, max_length=1000):
+    """Split documents into chunks with metadata."""
     chunks = []
-
-
-
-
-
-
-
-
-
+    for doc in documents:
+        text = doc.page_content
+        metadata = doc.metadata
+        paragraphs = text.split("\n\n")
+        current_chunk = ""
+        current_metadata = metadata.copy()
+        for paragraph in paragraphs:
+            if estimate_tokens(current_chunk + paragraph) <= max_length // 4:
+                current_chunk += paragraph + "\n\n"
+            else:
+                chunks.append({"text": current_chunk.strip(), "metadata": current_metadata})
+                current_chunk = paragraph + "\n\n"
+        if current_chunk:
+            chunks.append({"text": current_chunk.strip(), "metadata": current_metadata})
     return chunks
 
-def create_embeddings(
-    """Create embeddings for a list of texts
+def create_embeddings(chunks, model):
+    """Create embeddings for a list of chunk texts."""
+    texts = [chunk["text"] for chunk in chunks]
     embeddings = model.encode(texts, show_progress_bar=True, convert_to_tensor=True)
-    return embeddings.cpu().numpy()
+    return embeddings.cpu().numpy(), chunks
 
 def build_faiss_index(embeddings):
-    """Build a FAISS index from embeddings for similarity search."""
+    """Build a FAISS HNSW index from embeddings for similarity search."""
     dim = embeddings.shape[1]
-    index = faiss.
+    index = faiss.IndexHNSWFlat(dim, 32)  # 32 = number of neighbors in HNSW graph
+    index.hnsw.efConstruction = 200  # Higher = better quality, slower build
+    index.hnsw.efSearch = 50  # Higher = better accuracy, slower search
     index.add(embeddings)
     return index
 
-def retrieve_similar_chunks(query, index,
+def retrieve_similar_chunks(query, index, chunks, model, k=10, max_chunk_length=1000):
     """Retrieve top k similar chunks to the query from the FAISS index."""
     query_embedding = model.encode([query], convert_to_tensor=True).cpu().numpy()
     distances, indices = index.search(query_embedding, k)
-    return [(
+    return [(chunks[i]["text"][:max_chunk_length], distances[0][j], chunks[i]["metadata"]) for j, i in enumerate(indices[0])]
 
-def agentic_rag(llm, tools, query,
-    #
+def agentic_rag(llm, tools, query, context_chunks, Use_Tavily=False):
+    # Sort chunks by relevance (lower distance = more relevant)
+    context_chunks = sorted(context_chunks, key=lambda x: x[1])  # Sort by distance
+    context = ""
+    total_tokens = 0
+    max_tokens = 7000  # Leave room for prompt and response
+
+    # Aggregate relevant chunks until token limit is reached
+    for chunk, _, _ in context_chunks:  # Unpack three elements
+        chunk_tokens = estimate_tokens(chunk)
+        if total_tokens + chunk_tokens <= max_tokens:
+            context += chunk + "\n\n"
+            total_tokens += chunk_tokens
+        else:
+            break
+
+    # Define prompt template
     search_instructions = (
         "Use the search tool if the context is insufficient to answer the question or you are unsure. Give source links if you use the search tool."
         if Use_Tavily
@@ -80,35 +101,30 @@ def agentic_rag(llm, tools, query, context, Use_Tavily=False):
     Instructions:
     1. Use the provided context to answer the user's question.
     2. Provide a clear answer, if you don't know the answer, say 'I don't know'.
+    3. Prioritize information from the most relevant context chunks.
     """),
     ("human", "Context: {context}\n\nQuestion: {input}"),
     MessagesPlaceholder(variable_name="chat_history"),
     MessagesPlaceholder(variable_name="agent_scratchpad"),
     ])
-
-    # Only use tools when Tavily is enabled
-    agent_tools = tools if Use_Tavily else []
 
+    agent_tools = tools if Use_Tavily else []
     try:
-        # Create the agent and executor with appropriate tools
         agent = create_tool_calling_agent(llm, agent_tools, prompt)
         agent_executor = AgentExecutor(agent=agent, tools=agent_tools, memory=memory, verbose=True)
-
-        # Execute the agent
         return agent_executor.invoke({
-            "input": query,
+            "input": query,
             "context": context,
             "search_instructions": search_instructions
         })
     except Exception as e:
         print(f"Error during agent execution: {str(e)}")
-        # Fallback to direct LLM call without agent framework
        fallback_prompt = ChatPromptTemplate.from_messages([
             ("system", "You are a helpful assistant. Use the provided context to answer the user's question."),
             ("human", "Context: {context}\n\nQuestion: {input}")
         ])
         response = llm.invoke(fallback_prompt.format(context=context, input=query))
-        return {"output": response.content}
+        return {"output": response.content}
 
 if __name__ == "__main__":
     # Process PDF and prepare index
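The switch to an HNSW index is the main change in build_faiss_index. A standalone sketch of that index type using random vectors (parameters mirror the diff; 1024 matches the output size of BAAI/bge-large-en-v1.5, and FAISS expects float32 arrays; the corpus and query here are synthetic):

import numpy as np
import faiss

dim = 1024                                         # embedding size of BAAI/bge-large-en-v1.5
xb = np.random.rand(500, dim).astype("float32")    # stand-in corpus embeddings
xq = np.random.rand(1, dim).astype("float32")      # stand-in query embedding

index = faiss.IndexHNSWFlat(dim, 32)               # 32 neighbors per node in the HNSW graph
index.hnsw.efConstruction = 200                    # build-time quality vs. speed
index.hnsw.efSearch = 50                           # query-time accuracy vs. speed
index.add(xb)
distances, indices = index.search(xq, 10)          # top-10 approximate neighbors
print(indices[0], distances[0])

Note that HNSW search is approximate, so the top-k list can occasionally differ from an exact scan; raising efSearch trades latency for recall.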