Jatin Mehra committed on
Commit 33c5afb · 1 Parent(s): 447c09c

Refactor FastAPI application for improved modularity and maintainability

Files changed (2)
  1. app.py +0 -357
  2. app_refactored.py +107 -0
app.py DELETED
@@ -1,357 +0,0 @@
- import os
- import dotenv
- import pickle
- import uuid
- import shutil
- import traceback
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
- from fastapi.responses import JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.staticfiles import StaticFiles
- from pydantic import BaseModel
- import uvicorn
- from preprocessing import (
-     model_selection,
-     process_pdf_file,
-     chunk_text,
-     create_embeddings,
-     build_faiss_index,
-     retrieve_similar_chunks,
-     agentic_rag,
-     tools as global_base_tools,
-     create_vector_search_tool
- )
- from sentence_transformers import SentenceTransformer
- from langchain.memory import ConversationBufferMemory
-
- # Load environment variables
- dotenv.load_dotenv()
-
- # Initialize FastAPI app
- app = FastAPI(title="PDF Insight Beta", description="Agentic RAG for PDF documents")
-
- # Add CORS middleware
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Create upload directory if it doesn't exist
- UPLOAD_DIR = "uploads"
- if not os.path.exists(UPLOAD_DIR):
-     os.makedirs(UPLOAD_DIR)
-
- # Store active sessions
- sessions = {}
-
- # Define model for chat request
- class ChatRequest(BaseModel):
-     session_id: str
-     query: str
-     use_search: bool = False
-     model_name: str = "meta-llama/llama-4-scout-17b-16e-instruct"
-
- class SessionRequest(BaseModel):
-     session_id: str
-
- # Function to save session data
- def save_session(session_id, data):
-     sessions[session_id] = data # Keep non-picklable in memory for active session
-
-     pickle_safe_data = {
-         "file_path": data.get("file_path"),
-         "file_name": data.get("file_name"),
-         "chunks": data.get("chunks"), # Chunks with metadata (list of dicts)
-         "chat_history": data.get("chat_history", [])
-         # FAISS index, embedding model, and LLM model are not pickled, will be reloaded/recreated
-     }
-
-     with open(f"{UPLOAD_DIR}/{session_id}_session.pkl", "wb") as f:
-         pickle.dump(pickle_safe_data, f)
-
-
- # Function to load session data
- def load_session(session_id, model_name="llama3-8b-8192"): # Ensure model_name matches default
-     try:
-         if session_id in sessions:
-             cached_session = sessions[session_id]
-             # Ensure LLM and potentially other non-pickled parts are up-to-date or loaded
-             if cached_session.get("llm") is None or (hasattr(cached_session["llm"], "model_name") and cached_session["llm"].model_name != model_name):
-                 cached_session["llm"] = model_selection(model_name)
-             if cached_session.get("model") is None: # Embedding model
-                 cached_session["model"] = SentenceTransformer('BAAI/bge-large-en-v1.5')
-             if cached_session.get("index") is None and cached_session.get("chunks"): # FAISS index
-                 embeddings, _ = create_embeddings(cached_session["chunks"], cached_session["model"])
-                 cached_session["index"] = build_faiss_index(embeddings)
-             return cached_session, True
-
-         file_path_pkl = f"{UPLOAD_DIR}/{session_id}_session.pkl"
-         if os.path.exists(file_path_pkl):
-             with open(file_path_pkl, "rb") as f:
-                 data = pickle.load(f)
-
-             original_pdf_path = data.get("file_path")
-             if data.get("chunks") and original_pdf_path and os.path.exists(original_pdf_path):
-                 embedding_model_instance = SentenceTransformer('BAAI/bge-large-en-v1.5')
-                 # Chunks are already {text: ..., metadata: ...}
-                 recreated_embeddings, _ = create_embeddings(data["chunks"], embedding_model_instance)
-                 recreated_index = build_faiss_index(recreated_embeddings)
-                 recreated_llm = model_selection(model_name)
-
-                 full_session_data = {
-                     "file_path": original_pdf_path,
-                     "file_name": data.get("file_name"),
-                     "chunks": data.get("chunks"), # chunks_with_metadata
-                     "chat_history": data.get("chat_history", []),
-                     "model": embedding_model_instance, # SentenceTransformer model
-                     "index": recreated_index, # FAISS index
-                     "llm": recreated_llm # LLM
-                 }
-                 sessions[session_id] = full_session_data
-                 return full_session_data, True
-             else:
-                 print(f"Warning: Session data for {session_id} is incomplete or PDF missing. Cannot reconstruct.")
-                 if os.path.exists(file_path_pkl): os.remove(file_path_pkl) # Clean up stale pkl
-                 return None, False
-
-         return None, False
-     except Exception as e:
-         print(f"Error loading session {session_id}: {str(e)}")
-         print(traceback.format_exc())
-         return None, False
-
- # Function to remove PDF file
- def remove_pdf_file(session_id):
-     try:
-         # Check if the session exists
-         session_path = f"{UPLOAD_DIR}/{session_id}_session.pkl"
-         if os.path.exists(session_path):
-             # Load session data
-             with open(session_path, "rb") as f:
-                 data = pickle.load(f)
-
-             # Delete PDF file if it exists
-             if data.get("file_path") and os.path.exists(data["file_path"]):
-                 os.remove(data["file_path"])
-
-             # Remove session file
-             os.remove(session_path)
-
-         # Remove from memory if exists
-         if session_id in sessions:
-             del sessions[session_id]
-
-         return True
-     except Exception as e:
-         print(f"Error removing PDF file: {str(e)}")
-         return False
-
- # Mount static files (we'll create these later)
- app.mount("/static", StaticFiles(directory="static"), name="static")
-
- # Route for the home page
- @app.get("/")
- async def read_root():
-     from fastapi.responses import RedirectResponse
-     return RedirectResponse(url="/static/index.html")
-
- # Route to upload a PDF file
- @app.post("/upload-pdf")
- async def upload_pdf(
-     file: UploadFile = File(...),
-     model_name: str = Form("llama3-8b-8192") # Default model
- ):
-     session_id = str(uuid.uuid4())
-     file_path = None
-
-     try:
-         file_path = f"{UPLOAD_DIR}/{session_id}_{file.filename}"
-         with open(file_path, "wb") as buffer:
-             shutil.copyfileobj(file.file, buffer)
-
-         if not os.getenv("GROQ_API_KEY") and "llama" in model_name: # Llama specific check for Groq
-             raise ValueError("GROQ_API_KEY is not set for Groq Llama models.")
-         if not os.getenv("TAVILY_API_KEY"): # Needed for TavilySearchResults
-             print("Warning: TAVILY_API_KEY is not set. Web search will not function.")
-
-         documents = process_pdf_file(file_path)
-         chunks_with_metadata = chunk_text(documents, max_length=1000) # Increased from 256 to 1000 tokens for better context
-
-         embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5')
-         embeddings, _ = create_embeddings(chunks_with_metadata, embedding_model) # Chunks are already with metadata
-
-         index = build_faiss_index(embeddings)
-         llm = model_selection(model_name)
-
-         session_data = {
-             "file_path": file_path,
-             "file_name": file.filename,
-             "chunks": chunks_with_metadata, # Store chunks with metadata
-             "model": embedding_model, # SentenceTransformer instance
-             "index": index, # FAISS index instance
-             "llm": llm, # LLM instance
-             "chat_history": []
-         }
-         save_session(session_id, session_data)
-
-         return {"status": "success", "session_id": session_id, "message": f"Processed {file.filename}"}
-
-     except Exception as e:
-         if file_path and os.path.exists(file_path):
-             os.remove(file_path)
-         error_msg = str(e)
-         stack_trace = traceback.format_exc()
-         print(f"Error processing PDF: {error_msg}\nStack trace: {stack_trace}")
-         return JSONResponse(
-             status_code=500, # Internal server error for processing issues
-             content={"status": "error", "detail": error_msg, "type": type(e).__name__}
-         )
-
- # Route to chat with the document
- @app.post("/chat")
- async def chat(request: ChatRequest):
-     # Validate query
-     if not request.query or not request.query.strip():
-         raise HTTPException(status_code=400, detail="Query cannot be empty")
-
-     if len(request.query.strip()) < 3:
-         raise HTTPException(status_code=400, detail="Query must be at least 3 characters long")
-
-     session, found = load_session(request.session_id, model_name=request.model_name)
-     if not found:
-         raise HTTPException(status_code=404, detail="Session not found or expired. Please upload a document first.")
-
-     try:
-         # Validate session data integrity
-         required_keys = ["index", "chunks", "model", "llm"]
-         missing_keys = [key for key in required_keys if key not in session]
-         if missing_keys:
-             print(f"Warning: Session {request.session_id} missing required data: {missing_keys}")
-             raise HTTPException(status_code=500, detail="Session data is incomplete. Please upload the document again.")
-
-         # Per-request memory to ensure chat history is correctly loaded for the agent
-         agent_memory = ConversationBufferMemory(memory_key="chat_history", input_key="input", return_messages=True)
-         for entry in session.get("chat_history", []):
-             agent_memory.chat_memory.add_user_message(entry["user"])
-             agent_memory.chat_memory.add_ai_message(entry["assistant"])
-
-         # Prepare tools for the agent for THIS request
-         current_request_tools = []
-
-         # 1. Add the document-specific vector search tool
-         vector_search_tool_instance = create_vector_search_tool(
-             faiss_index=session["index"],
-             document_chunks_with_metadata=session["chunks"], # Pass the correct variable
-             embedding_model=session["model"], # This is the SentenceTransformer model
-             max_chunk_length=1000,
-             k=10
-         )
-         current_request_tools.append(vector_search_tool_instance)
-
-         # 2. Conditionally add Tavily (web search) tool
-         if request.use_search:
-             if os.getenv("TAVILY_API_KEY"):
-                 tavily_tool = next((tool for tool in global_base_tools if tool.name == "tavily_search_results_json"), None)
-                 if tavily_tool:
-                     current_request_tools.append(tavily_tool)
-                 else: # Should not happen if global_base_tools is defined correctly
-                     print("Warning: Tavily search requested, but tool misconfigured.")
-             else:
-                 print("Warning: Tavily search requested, but TAVILY_API_KEY is not set.")
-
-         # Retrieve initial similar chunks for RAG context (can be empty if no good match)
-         # This context is given to the agent *before* it decides to use tools.
-         # k=5 means we retrieve up to 5 chunks for initial context.
-         # The agent can then use `vector_database_search` to search more if needed.
-         initial_similar_chunks = retrieve_similar_chunks(
-             request.query,
-             session["index"],
-             session["chunks"], # list of dicts {text:..., metadata:...}
-             session["model"], # SentenceTransformer model
-             k=5 # Number of chunks for initial context
-         )
-
-         print(f"Query: '{request.query}' - Found {len(initial_similar_chunks)} initial chunks")
-         if initial_similar_chunks:
-             print(f"Best chunk score: {initial_similar_chunks[0][1]:.4f}")
-
-         response = agentic_rag(
-             session["llm"],
-             current_request_tools, # Pass the dynamically assembled list of tools
-             query=request.query,
-             context_chunks=initial_similar_chunks,
-             Use_Tavily=request.use_search, # Still passed to agentic_rag for potential fine-grained logic, though prompt adapts to tools
-             memory=agent_memory
-         )
-
-         response_output = response.get("output", "Sorry, I could not generate a response.")
-         print(f"Generated response length: {len(response_output)} characters")
-
-         session["chat_history"].append({"user": request.query, "assistant": response_output})
-         save_session(request.session_id, session) # Save updated history and potentially other modified session state
-
-         return {
-             "status": "success",
-             "answer": response_output,
-             # Return context that was PRE-FETCHED for the agent, not necessarily all context it might have used via tools
-             "context_used": [{"text": chunk, "score": float(score), "metadata": meta} for chunk, score, meta in initial_similar_chunks]
-         }
-
-     except Exception as e:
-         print(f"Error processing chat query: {str(e)}\nTraceback: {traceback.format_exc()}")
-         raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
-
-
- # Route to get chat history
- @app.post("/chat-history")
- async def get_chat_history(request: SessionRequest):
-     # Try to load session if not in memory
-     session, found = load_session(request.session_id)
-     if not found:
-         raise HTTPException(status_code=404, detail="Session not found")
-
-     return {
-         "status": "success",
-         "history": session.get("chat_history", [])
-     }
-
- # Route to clear chat history
- @app.post("/clear-history")
- async def clear_history(request: SessionRequest):
-     # Try to load session if not in memory
-     session, found = load_session(request.session_id)
-     if not found:
-         raise HTTPException(status_code=404, detail="Session not found")
-
-     session["chat_history"] = []
-     save_session(request.session_id, session)
-
-     return {"status": "success", "message": "Chat history cleared"}
-
- # Route to remove PDF from session
- @app.post("/remove-pdf")
- async def remove_pdf(request: SessionRequest):
-     success = remove_pdf_file(request.session_id)
-
-     if success:
-         return {"status": "success", "message": "PDF file and session removed successfully"}
-     else:
-         raise HTTPException(status_code=404, detail="Session not found or could not be removed")
-
- # Route to list available models
- @app.get("/models")
- async def get_models():
-     # You can expand this list as needed
-     models = [
-         {"id": "meta-llama/llama-4-scout-17b-16e-instruct", "name": "Llama 4 Scout 17B"},
-         {"id": "llama-3.1-8b-instant", "name": "Llama 3.1 8B Instant"},
-         {"id": "llama-3.3-70b-versatile", "name": "Llama 3.3 70B Versatile"},
-     ]
-     return {"models": models}
-
- # Run the application if this file is executed directly
- if __name__ == "__main__":
-     uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
 
app_refactored.py ADDED
@@ -0,0 +1,107 @@
+ """
+ Refactored FastAPI application for PDF Insight Beta.
+
+ This is the main application file that sets up the FastAPI app with modular components.
+ The core logic has been preserved while improving code organization and maintainability.
+ """
+
+ import uvicorn
+ from fastapi import FastAPI, UploadFile, File, Form
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.staticfiles import StaticFiles
+
+ from configs.config import Config
+ from models.models import (
+     ChatRequest, SessionRequest, UploadResponse, ChatResponse,
+     ChatHistoryResponse, StatusResponse, ModelsResponse
+ )
+ from api import (
+     upload_pdf_handler, chat_handler, get_chat_history_handler,
+     clear_history_handler, remove_pdf_handler, root_handler, get_models_handler
+ )
+
+
+ def create_app() -> FastAPI:
+     """
+     Create and configure the FastAPI application.
+
+     Returns:
+         Configured FastAPI application instance
+     """
+     # Initialize FastAPI app
+     app = FastAPI(
+         title="PDF Insight Beta",
+         description="Agentic RAG for PDF documents"
+     )
+
+     # Add CORS middleware
+     app.add_middleware(
+         CORSMiddleware,
+         allow_origins=Config.CORS_ORIGINS,
+         allow_credentials=Config.CORS_CREDENTIALS,
+         allow_methods=Config.CORS_METHODS,
+         allow_headers=Config.CORS_HEADERS,
+     )
+
+     # Mount static files
+     app.mount("/static", StaticFiles(directory="static"), name="static")
+
+     return app
+
+
+ # Create app instance
+ app = create_app()
+
+
+ # Route definitions
+ @app.get("/")
+ async def read_root():
+     """Root endpoint that redirects to the main application."""
+     return await root_handler()
+
+
+ @app.post("/upload-pdf", response_model=UploadResponse)
+ async def upload_pdf(file: UploadFile = File(...), model_name: str = Form(Config.DEFAULT_MODEL)):
+     """Upload and process a PDF file."""
+     return await upload_pdf_handler(file, model_name)
+
+
+ @app.post("/chat", response_model=ChatResponse)
+ async def chat(request: ChatRequest):
+     """Chat with the uploaded document."""
+     return await chat_handler(request)
+
+
+ @app.post("/chat-history", response_model=ChatHistoryResponse)
+ async def get_chat_history(request: SessionRequest):
+     """Get chat history for a session."""
+     return await get_chat_history_handler(request)
+
+
+ @app.post("/clear-history", response_model=StatusResponse)
+ async def clear_history(request: SessionRequest):
+     """Clear chat history for a session."""
+     return await clear_history_handler(request)
+
+
+ @app.post("/remove-pdf", response_model=StatusResponse)
+ async def remove_pdf(request: SessionRequest):
+     """Remove PDF file and session data."""
+     return await remove_pdf_handler(request)
+
+
+ @app.get("/models", response_model=ModelsResponse)
+ async def get_models():
+     """Get list of available models."""
+     return await get_models_handler()
+
+
+ def main():
+     """
+     Main entry point for running the application.
+     """
+     uvicorn.run("app_refactored:app", host="0.0.0.0", port=8000, reload=True)
+
+
+ if __name__ == "__main__":
+     main()
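
Note: the companion modules this entry point imports (configs/config.py, models/models.py, and the api package) are not part of this commit, so their contents are not shown here. As a rough orientation only, a minimal sketch of what the Config class might contain — assuming the refactor simply lifts the values that were hard-coded in the deleted app.py into class attributes — could look like this:

# Hypothetical sketch of configs/config.py (not part of this commit).
# Values are assumed to be lifted from the deleted app.py, where the CORS
# settings, default model, and upload directory were hard-coded inline.
class Config:
    # CORS settings, previously passed directly to app.add_middleware(...)
    CORS_ORIGINS = ["*"]
    CORS_CREDENTIALS = True
    CORS_METHODS = ["*"]
    CORS_HEADERS = ["*"]

    # Default LLM, e.g. matching ChatRequest.model_name in the old app.py
    DEFAULT_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"

    # Upload directory, matching UPLOAD_DIR in the old app.py
    UPLOAD_DIR = "uploads"

Centralizing these values in one class means the CORS policy, default model, and storage path can change without touching the route definitions in app_refactored.py, which is the modularity gain the commit message describes.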