syedMohib44 committed on
Commit
729a6b2
·
1 Parent(s): 4be8fe7
Files changed (1)
  1. app.py +234 -63
app.py CHANGED
@@ -1,83 +1,254 @@
- import os
  import json
- from fastapi import FastAPI
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
- from typing import List
- from transformers import pipeline
- from sentence_transformers import SentenceTransformer
  import faiss
- import gradio as gr
- from gradio import mount_gradio_app

- # ------------------- Config ------------------- #
- DATA_PATH = "/tmp/pentagon_core.json"  # Use /tmp for temporary storage
  EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"
- QA_MODEL = "./models/bart-large-cnn"
- DEVICE = "cuda" if os.environ.get("USE_CUDA") == "1" else "cpu"
-
- # ------------------- Load Models ------------------- #
- embedder = SentenceTransformer(EMBEDDING_MODEL)
- qa_model = pipeline("text2text-generation", model=QA_MODEL, device=0 if DEVICE == "cuda" else -1)
-
- # ------------------- Load Dataset + Index ------------------- #
- if os.path.exists(DATA_PATH):
-     with open(DATA_PATH, "r") as f:
-         knowledge_base = json.load(f)
- else:
-     knowledge_base = []
-
- texts = [item["content"] for item in knowledge_base]
- embeddings = embedder.encode(texts, convert_to_tensor=True)
- index = faiss.IndexFlatL2(embeddings.shape[1])
- index.add(embeddings.cpu().detach().numpy())
-
- # ------------------- FastAPI App ------------------- #
- app = FastAPI()
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],  # For development
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
  )

- # --------- Upload Endpoint --------- #
- class UploadData(BaseModel):
-     content: str

  @app.post("/upload/")
- def upload_knowledge(data: UploadData):
-     global knowledge_base, index

-     knowledge_base.append({"content": data.content})
-     with open(DATA_PATH, "w") as f:
-         json.dump(knowledge_base, f, indent=2)

-     new_embedding = embedder.encode([data.content], convert_to_numpy=True)
-     index.add(new_embedding)

-     return {"message": "Data uploaded and indexed."}

- # --------- Ask Endpoint --------- #
- @app.get("/ask/")
- def ask(question: str, top_k: int = 3):
-     question_embedding = embedder.encode([question], convert_to_numpy=True)
-     distances, indices = index.search(question_embedding, top_k)
-     context = " ".join([knowledge_base[i]["content"] for i in indices[0]])

      prompt = (
-         f"Context: {context}\n\n"
          f"Answer the following question based only on the above context:\n"
-         f"{question}\n\nAnswer:"
      )
-     output = qa_model(prompt, max_length=256, do_sample=False)[0]["generated_text"]

-     return {
-         "question": question,
-         "context_used": context,
-         "answer": output.strip()
-     }

  # --------- Gradio UI --------- #
  def gradio_upload(file):

+ import torch
  import json
+ import os
  import faiss
+ import numpy as np
+ from pptx import Presentation
+ from fastapi import FastAPI, UploadFile, File
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ from sentence_transformers import SentenceTransformer
+ from io import BytesIO

+ # ---------------------------- #
+ # CONFIGURATION
+ # ---------------------------- #
+ MODEL_NAME = "./models/facebook-opt-1.3b"
+ SUMMARIZATION_MODEL = "./models/bart-large-cnn"
  EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"
+ DATA_DIRECTORY = "./dataset/"
+
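
Note: all three model paths above point at local directories, so the checkpoints must already be on disk when the app starts. A minimal sketch of fetching them with huggingface_hub; the repo ids are inferred from the folder names and are assumptions, not part of this commit:

    # Sketch: materialize the local model folders the CONFIGURATION block expects.
    # Repo ids are guesses from the directory names; adjust to the actual sources.
    from huggingface_hub import snapshot_download

    snapshot_download(repo_id="facebook/opt-1.3b", local_dir="./models/facebook-opt-1.3b")
    snapshot_download(repo_id="facebook/bart-large-cnn", local_dir="./models/bart-large-cnn")
    snapshot_download(repo_id="sentence-transformers/all-MiniLM-L6-v2", local_dir="./models/all-MiniLM-L6-v2")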
+ # ---------------------------- #
+ # FUNCTION TO LOAD JSON FILES
+ # ---------------------------- #
+ def load_text_from_json(directory):
+     text_data = set()  # Use set to remove duplicates
+
+     for filename in os.listdir(directory):
+         if filename.endswith(".json"):
+             with open(os.path.join(directory, filename), "r", encoding="utf-8") as file:
+                 data = json.load(file)
+                 for entry in data.get("data", []):
+                     question = entry.get("question", "").strip()
+                     answer = entry.get("answer", "").strip()
+                     if question and answer:
+                         text_data.add(f"Q: {question} A: {answer}")
+
+     return list(text_data)
+
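
Note: load_text_from_json expects each file to carry a top-level "data" list of question/answer objects. A hypothetical sample file the loader would accept (contents invented for illustration):

    # Sketch: write a ./dataset/faq.json in the shape the loader above reads.
    import json

    sample = {
        "data": [
            {"question": "What is the refund policy?",
             "answer": "Refunds are issued within 30 days of purchase."}
        ]
    }
    with open("./dataset/faq.json", "w", encoding="utf-8") as f:
        json.dump(sample, f, indent=2)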
+ # ---------------------------- #
+ # FUNCTION TO LOAD POWERPOINT FILES
+ # ---------------------------- #
+ def extract_text_from_pptx(file_path):
+     prs = Presentation(file_path)
+     text_data = []
+
+     for slide in prs.slides:
+         for shape in slide.shapes:
+             if hasattr(shape, "text"):
+                 text_data.append(shape.text.strip())
+
+     return " ".join(text_data)
+
+ def load_text_from_pptx(directory):
+     text_data = set()
+
+     for filename in os.listdir(directory):
+         if filename.endswith(".pptx"):
+             pptx_text = extract_text_from_pptx(os.path.join(directory, filename))
+             text_data.add(pptx_text)
+
+     return list(text_data)
+
+ # ---------------------------- #
+ # LOAD ALL TEXT DATA
+ # ---------------------------- #
+ all_text = load_text_from_json(DATA_DIRECTORY) + load_text_from_pptx(DATA_DIRECTORY)
+
+ # ---------------------------- #
+ # CHUNK DATA PROPERLY
+ # ---------------------------- #
+ CHUNK_SIZE = 500
+ chunks = set()
+
+ for text in all_text:
+     sentences = text.split(". ")
+     temp_chunk = ""
+
+     for sentence in sentences:
+         if len(temp_chunk) + len(sentence) < CHUNK_SIZE:
+             temp_chunk += sentence + ". "
+         else:
+             chunks.add(temp_chunk.strip())  # Store chunk
+             temp_chunk = sentence + ". "
+
+     if temp_chunk:
+         chunks.add(temp_chunk.strip())  # Store last chunk
+
+ chunks = list(chunks)  # Convert to list after deduplication
+
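
Note: this greedy sentence-packing loop is repeated verbatim inside the /upload/ handler further down; factoring it into a shared helper would keep the two copies in sync. A minimal sketch (the name chunk_text is hypothetical, not from this commit):

    # Sketch: the same greedy packing as above, as a reusable helper.
    def chunk_text(text, chunk_size=500):
        chunks, current = set(), ""
        for sentence in text.split(". "):
            if len(current) + len(sentence) < chunk_size:
                current += sentence + ". "
            else:
                chunks.add(current.strip())
                current = sentence + ". "
        if current:
            chunks.add(current.strip())
        return chunks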
+ # ---------------------------- #
+ # EMBEDDING MODEL & FAISS VECTOR SEARCH
+ # ---------------------------- #
+ embedder = SentenceTransformer(EMBEDDING_MODEL, local_files_only=True)
+ chunk_embeddings = embedder.encode(chunks, convert_to_numpy=True)
+
+ # FAISS index
+ index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
+ index.add(chunk_embeddings)
+
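
Note: IndexFlatL2 ranks by Euclidean distance on the raw vectors. Cosine similarity is the more usual choice for sentence-transformers embeddings; one way to get it is to L2-normalize the vectors and use an inner-product index. A sketch of that alternative, patching the two lines above (not what this commit does):

    # Sketch: cosine-style retrieval via unit-norm vectors + inner product.
    faiss.normalize_L2(chunk_embeddings)              # in-place normalization
    index = faiss.IndexFlatIP(chunk_embeddings.shape[1])
    index.add(chunk_embeddings)
    # Query vectors must be normalized the same way before index.search(...).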
+ # ---------------------------- #
+ # LOAD LLM MODEL
+ # ---------------------------- #
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
  )

+ # Summarization pipeline
+ summarizer = pipeline("summarization", model=SUMMARIZATION_MODEL)
+
+ # ---------------------------- #
+ # FASTAPI SETUP
+ # ---------------------------- #
+ app = FastAPI()
+
+ def retrieve_relevant_text(question, top_k=3):
+     question_embedding = embedder.encode([question], convert_to_numpy=True)
+     _, idxs = index.search(question_embedding, top_k)
+
+     retrieved_texts = [chunks[idx] for idx in idxs[0]]
+
+     # Filter out chunks that contain the same question
+     filtered_chunks = [text for text in retrieved_texts if question.lower() not in text.lower()]
+     unique_texts = list(set(filtered_chunks))
+
+     context_text = " ".join(unique_texts)
+     if len(context_text) > 1000:
+         context_text = summarizer(context_text, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
+
+     return context_text
+
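
Note: because chunks are stored as "Q: ... A: ..." strings, the substring filter above throws away exactly the chunk whose stored question matches the user's question verbatim, which is often the best source for the answer. If the intent is only to avoid echoing the question back, stripping the question part while keeping the answer would be safer. A sketch (helper name is hypothetical):

    # Sketch: keep a matching chunk but return only its answer portion.
    def strip_question(chunk, question):
        lowered = chunk.lower()
        marker = " a: "
        if question.lower() in lowered and marker in lowered:
            return chunk[lowered.index(marker) + len(marker):]
        return chunk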

  @app.post("/upload/")
+ async def upload_file(file: UploadFile = File(...)):
+     global chunks, index, chunk_embeddings

+     filename = file.filename
+     content = await file.read()
+     new_texts = []

+     try:
+         # -------------------- #
+         # Process .json files
+         # -------------------- #
+         if filename.endswith(".json"):
+             data = json.loads(content)
+             for entry in data.get("data", []):
+                 question = entry.get("question", "").strip()
+                 answer = entry.get("answer", "").strip()
+                 if question and answer:
+                     new_texts.append(f"Q: {question} A: {answer}")

+         # -------------------- #
+         # Process .pptx files
+         # -------------------- #
+         elif filename.endswith(".pptx"):
+             prs = Presentation(BytesIO(content))
+             ppt_text = []
+             for slide in prs.slides:
+                 for shape in slide.shapes:
+                     if hasattr(shape, "text"):
+                         ppt_text.append(shape.text.strip())
+             new_texts.append(" ".join(ppt_text))

+         else:
+             return {"error": "Unsupported file type. Use .json or .pptx"}
+
+         # -------------------- #
+         # Chunk and embed
+         # -------------------- #
+         new_chunks = set()
+         for text in new_texts:
+             sentences = text.split(". ")
+             temp = ""
+             for s in sentences:
+                 if len(temp) + len(s) < CHUNK_SIZE:
+                     temp += s + ". "
+                 else:
+                     new_chunks.add(temp.strip())
+                     temp = s + ". "
+             if temp:
+                 new_chunks.add(temp.strip())
+
+         # Remove existing chunks (dedup)
+         new_chunks = list(set(new_chunks) - set(chunks))
+
+         if not new_chunks:
+             return {"message": "No new unique chunks to add."}
+
+         # Encode and update FAISS
+         new_embeddings = embedder.encode(new_chunks, convert_to_numpy=True)
+         index.add(new_embeddings)
+         chunks.extend(new_chunks)
+
+         return {
+             "status": "success",
+             "new_chunks_added": len(new_chunks),
+             "total_chunks": len(chunks)
+         }
+
+     except Exception as e:
+         return {"error": str(e)}
+
+
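
Note: for completeness, a client-side call against the endpoint above. The base URL is an assumption (adjust host/port to the actual deployment), and requests is used purely for illustration:

    # Sketch: upload a FAQ file, then inspect the indexing summary.
    import requests

    base = "http://localhost:8000"
    with open("faq.json", "rb") as f:
        r = requests.post(f"{base}/upload/", files={"file": ("faq.json", f)})
    print(r.json())  # e.g. {"status": "success", "new_chunks_added": 12, "total_chunks": 250}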
+ @app.get("/faq/")
+ def faq(question: str):
+     """Answer user queries using retrieved knowledge."""
+     retrieved_text = retrieve_relevant_text(question)

      prompt = (
+         f"{retrieved_text.strip()}\n\n"
          f"Answer the following question based only on the above context:\n"
+         f"{question.strip()}\n\n"
+         f"Answer:"
      )

+     inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
+
+     with torch.no_grad():
+         output = model.generate(
+             **inputs,
+             max_length=200,
+             repetition_penalty=1.3,
+             no_repeat_ngram_size=4,
+             temperature=0.7,
+             do_sample=False,
+             pad_token_id=tokenizer.eos_token_id,
+         )
+
+     raw_answer = tokenizer.decode(output[0], skip_special_tokens=True)
+
+     # ---------------------------- #
+     # POST-PROCESSING CLEANUP
+     # ---------------------------- #
+     cleaned_answer = raw_answer
+
+     # Remove the prompt (everything before the final 'Answer:' keyword)
+     if "Answer:" in cleaned_answer:
+         cleaned_answer = cleaned_answer.split("Answer:")[-1]
+
+     # Remove repeated question (case-insensitive)
+     question_lower = question.strip().lower()
+     cleaned_answer = cleaned_answer.strip()
+     if cleaned_answer.lower().startswith(question_lower):
+         cleaned_answer = cleaned_answer[len(question.strip()):].strip()
+
+     # Final touch: remove context/prompt tokens if they leaked
+     for token in ["Context:", "Question:", "Answer:"]:
+         cleaned_answer = cleaned_answer.replace(token, "").strip()
+
+     return {"answer": cleaned_answer}
+
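
Note: two generation details in faq() worth flagging. With do_sample=False decoding is greedy, so temperature=0.7 has no effect; and max_length=200 counts prompt plus completion tokens, so a long retrieved context can leave little or no room for the answer, whereas max_new_tokens bounds only the completion. A sketch of the call with those adjustments (an alternative, not what this commit ships):

    # Sketch: budget only the completion and drop the ignored temperature.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=128,      # completion budget, independent of prompt length
            repetition_penalty=1.3,
            no_repeat_ngram_size=4,
            do_sample=False,         # greedy decoding
            pad_token_id=tokenizer.eos_token_id,
        )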

  # --------- Gradio UI --------- #
  def gradio_upload(file):