zamal committed on
Commit 3ad87bd · verified · 1 Parent(s): 6177313

Update app.py

Files changed (1)
  1. app.py +69 -75
app.py CHANGED
@@ -41,6 +41,9 @@ load_dotenv()
 HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 processor = None
 vision_model = None
+# hold the in-memory vectordb
+CURRENT_VDB = None
+
 # OCR + multimodal image description setup
 ocr_model = ocr_predictor(
     "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
@@ -97,52 +100,34 @@ SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
 )

 def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
-    """
-    Build a *persistent* ChromaDB instance on disk, with two collections:
-      text_db (chunks of the PDF text)
-      image_db (image descriptions + raw image bytes)
-    """
-    # 1) Make or clean the on-disk folder
-    shutil.rmtree(PERSIST_DIR, ignore_errors=True)
-    os.makedirs(PERSIST_DIR, exist_ok=True)
-
-    client = chromadb.PersistentClient(
-        path=PERSIST_DIR,
-        settings=Settings(),
-        tenant=DEFAULT_TENANT,
-        database=DEFAULT_DATABASE
-    )
-
-    # 3) Create / wipe collections
-    for col in ("text_db", "image_db"):
-        if col in [c.name for c in client.list_collections()]:
-            client.delete_collection(col)
-
-    text_col = client.get_or_create_collection(
-        name="text_db",
-        embedding_function=SHARED_EMB_FN
-    )
-    img_col = client.get_or_create_collection(
-        name="image_db",
+    client = chromadb.EphemeralClient()
+    # wipe old
+    for name in ("text_db", "image_db"):
+        if name in [c.name for c in client.list_collections()]:
+            client.delete_collection(name)
+
+    text_col = client.get_or_create_collection("text_db", embedding_function=SHARED_EMB_FN)
+    img_col = client.get_or_create_collection(
+        "image_db",
         embedding_function=SHARED_EMB_FN,
         metadata={"hnsw:space": "cosine"}
     )

-    # 4) Add images
+    # add images
     if images:
         descs, metas = [], []
-        for idx, img in enumerate(images):
+        for i, img in enumerate(images):
             try:
                 cap = get_image_description(img)
             except:
                 cap = "⚠️ could not describe image"
-            descs.append(f"{img_names[idx]}: {cap}")
+            descs.append(f"{img_names[i]}: {cap}")
             metas.append({"image": image_to_bytes(img)})
         img_col.add(ids=[str(i) for i in range(len(images))],
                     documents=descs,
                     metadatas=metas)

-    # 5) Chunk & add text
+    # chunk + add text
     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
     docs = splitter.create_documents([text])
     text_col.add(ids=[str(i) for i in range(len(docs))],
@@ -153,6 +138,7 @@ def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):



+
 # Text extraction
 def result_to_text(result, as_text=False):
     pages = []
@@ -224,15 +210,15 @@ def extract_data_from_pdfs(
     progress(0.6, "Indexing in vector DB…")
     client = get_vectordb(all_text, images, names)

-    # 6) Mark session and return UI outputs
-    session["processed"] = True
-    session["persist_directory"] = PERSIST_DIR
-    sample_imgs = images[:4] if include_images == "Include Images" else []
+    global CURRENT_VDB
+    CURRENT_VDB = get_vectordb(all_text, images, names)

+    session["processed"] = True
+    sample = images[:4] if include_images=="Include Images" else []
     return (
-        session, # gr.State
+        session,
         all_text[:2000] + "...",
-        sample_imgs,
+        sample,
         "<h3>Done!</h3>"
     )

@@ -250,49 +236,41 @@ def conversation(
     max_tok: int,
     model_id: str
 ):
-    pd = session.get("persist_directory")
-    if not session.get("processed") or not pd:
+    """
+    Uses the in-memory CURRENT_VDB (set by extract_data_from_pdfs) to answer the user.
+    """
+    global CURRENT_VDB
+
+    # 0) Guard: make sure we've extracted at least once
+    if not session.get("processed") or CURRENT_VDB is None:
         raise gr.Error("Please extract data first")

-    # 1) Reopen the same persistent client (new API)
-    client = chromadb.PersistentClient(
-        path=pd,
-        settings=Settings(),
-        tenant=DEFAULT_TENANT,
-        database=DEFAULT_DATABASE
+    # 1) Retrieve top-k text chunks
+    text_col = CURRENT_VDB.get_collection("text_db")
+    docs = text_col.query(
+        query_texts=[question],
+        n_results=int(num_ctx),
+        include=["documents"]
+    )["documents"][0]
+
+    # 2) Retrieve top-k images
+    img_col = CURRENT_VDB.get_collection("image_db")
+    img_q = img_col.query(
+        query_texts=[question],
+        n_results=int(img_ctx),
+        include=["metadatas", "documents"]
     )
-
-
-    # 2) Text retrieval
-    text_col = client.get_collection("text_db")
-    docs = text_col.query(query_texts=[question],
-                          n_results=int(num_ctx),
-                          include=["documents"])["documents"][0]
-
-    # 3) Image retrieval
-    img_col = client.get_collection("image_db")
-    img_q = img_col.query(query_texts=[question],
-                          n_results=int(img_ctx),
-                          include=["metadatas","documents"])
     img_descs = img_q["documents"][0] or ["No images found"]
     images = []
     for meta in img_q["metadatas"][0]:
-        b64 = meta.get("image","")
+        b64 = meta.get("image", "")
         try:
             images.append(Image.open(io.BytesIO(base64.b64decode(b64))))
         except:
             pass
     img_desc = "\n".join(img_descs)

-    # 4) Build prompt & call LLM
-    llm = HuggingFaceEndpoint(
-        repo_id=model_id,
-        task="text-generation",
-        temperature=temp,
-        max_new_tokens=max_tok,
-        huggingfacehub_api_token=HF_TOKEN
-    )
-
+    # 3) Build the prompt
     prompt = PromptTemplate(
         template="""
 Context:
@@ -305,27 +283,43 @@ Question:
 {q}

 Answer:
-    """, input_variables=["text","img_desc","q"]
+    """,
+        input_variables=["text", "img_desc", "q"],
+    )
+    user_input = prompt.format(
+        text="\n\n".join(docs),
+        img_desc=img_desc,
+        q=question
     )
-    inp = prompt.format(text="\n\n".join(docs), img_desc=img_desc, q=question)

+    # 4) Call the LLM
+    llm = HuggingFaceEndpoint(
+        repo_id=model_id,
+        task="text-generation",
+        temperature=temp,
+        max_new_tokens=max_tok,
+        # the client will pick up HUGGINGFACEHUB_API_TOKEN from env automatically
+    )
     try:
-        answer = llm.invoke(inp)
+        answer = llm.invoke(user_input)
     except HfHubHTTPError as e:
-        answer = "❌ Model not hosted" if e.response.status_code==404 else f"⚠️ HF error: {e}"
+        if e.response.status_code == 404:
+            answer = f"❌ Model `{model_id}` not hosted on HF Inference API."
+        else:
+            answer = f"⚠️ HF API error: {e}"
     except Exception as e:
         answer = f"⚠️ Unexpected error: {e}"

+    # 5) Append to chat history and return
     new_history = history + [
-        {"role":"user", "content":question},
-        {"role":"assistant","content":answer}
+        {"role": "user", "content": question},
+        {"role": "assistant", "content": answer}
    ]
     return new_history, docs, images




-
 # ─────────────────────────────────────────────────────────────────────────────
 # Gradio UI
  CSS = """