zamal committed · verified
Commit 08d9c00 · 1 Parent(s): 82895ea

Update app.py

Files changed (1)
  1. app.py +47 -48
app.py CHANGED
@@ -58,12 +58,12 @@ CURRENT_VDB = None
 @spaces.GPU()
 def get_image_description(image: Image.Image) -> str:
     """
-    Lazy-loads the Llava processor + model into the GPU worker,
+    Lazy-loads the Llava processor + model inside the GPU worker,
     runs captioning, and returns a one-sentence description.
     """
     global processor, vision_model
 
-    # First-call: instantiate + move to CUDA
+    # On first call, instantiate + move to CUDA
     if processor is None or vision_model is None:
         processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
         vision_model = LlavaNextForConditionalGeneration.from_pretrained(
@@ -72,9 +72,9 @@ def get_image_description(image: Image.Image) -> str:
             low_cpu_mem_usage=True
         ).to("cuda")
 
-    # clear and run
     torch.cuda.empty_cache()
     gc.collect()
+
     prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
     inputs = processor(prompt, image, return_tensors="pt").to("cuda")
     output = vision_model.generate(**inputs, max_new_tokens=100)
@@ -175,21 +175,21 @@ def extract_data_from_pdfs(
 ):
     """
     1) (Optional) OCR setup
-    2) V+L model setup & monkey-patch get_image_description
-    3) Extract text and images
-    4) Build and store vector DB in global CURRENT_VDB
+    2) Vision+Lang model setup & monkey-patch get_image_description
+    3) Extract text & images
+    4) Build and stash vector DB in CURRENT_VDB
     """
     if not docs:
         raise gr.Error("No documents to process")
 
-    # 1) OCR instantiation if requested
+    # 1) OCR pipeline if requested
     if do_ocr == "Get Text With OCR":
         db_m, crnn_m = OCR_CHOICES[ocr_choice]
         local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
     else:
         local_ocr = None
 
-    # 2) Vision–language model instantiation
+    # 2) Vision–language model
     proc = LlavaNextProcessor.from_pretrained(vlm_choice)
     vis = (
         LlavaNextForConditionalGeneration
@@ -197,9 +197,10 @@ def extract_data_from_pdfs(
         .to("cuda")
     )
 
-    # Monkey-patch global captioning fn
+    # Monkey-patch our pipeline for image captions
     def describe(img: Image.Image) -> str:
-        torch.cuda.empty_cache(); gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
         prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
         inputs = proc(prompt, img, return_tensors="pt").to("cuda")
         output = vis.generate(**inputs, max_new_tokens=100)
@@ -208,13 +209,12 @@ def extract_data_from_pdfs(
     global get_image_description, CURRENT_VDB
     get_image_description = describe
 
-    # 3) Extract text & images
+    # 3) Extract text + images
     progress(0.2, "Extracting text and images…")
     all_text = ""
     images, names = [], []
 
     for path in docs:
-        # text
         if local_ocr:
             pdf = DocumentFile.from_pdf(path)
             res = local_ocr(pdf)
@@ -223,29 +223,28 @@ def extract_data_from_pdfs(
             txt = PdfReader(path).pages[0].extract_text() or ""
         all_text += txt + "\n\n"
 
-        # images
         if include_images == "Include Images":
             imgs = extract_images([path])
             images.extend(imgs)
             names.extend([os.path.basename(path)] * len(imgs))
 
-    # 4) Build and stash the vector DB
+    # 4) Build + store the vector DB
     progress(0.6, "Indexing in vector DB…")
     CURRENT_VDB = get_vectordb(all_text, images, names)
 
-    # mark done & return only picklable outputs
     session["processed"] = True
     sample_imgs = images[:4] if include_images == "Include Images" else []
 
+    # ─── return *exactly four* picklable outputs ───
     return (
-        session,                  # gr.State for “processed”
-        gr.Row(visible=True),     # to un‐hide your chat UI
-        all_text[:2000] + "...",
-        sample_imgs,
-        "<h3>Done!</h3>"
+        session,                  # gr.State: so UI knows we're ready
+        all_text[:2000] + "...",  # preview text
+        sample_imgs,              # preview images
+        "<h3>Done!</h3>"          # Done message
     )
 
 
+
 # Chat function
 def conversation(
     session: dict,
@@ -258,8 +257,7 @@ def conversation(
     model_id: str
 ):
     """
-    Pulls CURRENT_VDB from module global, runs text+image retrieval,
-    calls the HF endpoint, and returns updated chat history.
+    Uses the global CURRENT_VDB (set by extract_data_from_pdfs) to answer.
     """
     global CURRENT_VDB
     if not session.get("processed") or CURRENT_VDB is None:
@@ -272,7 +270,7 @@ def conversation(
         huggingfacehub_api_token=HF_TOKEN
     )
 
-    # Retrieve top‐k text & images
+    # 1) Text retrieval
     text_col = CURRENT_VDB.get_collection("text_db")
     docs = text_col.query(
         query_texts=[question],
@@ -280,6 +278,7 @@ def conversation(
         include=["documents"]
     )["documents"][0]
 
+    # 2) Image retrieval
     img_col = CURRENT_VDB.get_collection("image_db")
     img_q = img_col.query(
         query_texts=[question],
@@ -296,7 +295,7 @@ def conversation(
         pass
     img_desc = "\n".join(img_descs)
 
-    # Build and call prompt
+    # 3) Build prompt & call LLM
     prompt = PromptTemplate(
         template="""
 Context:
@@ -336,6 +335,7 @@ Answer:
 
 
 
+
 # ─────────────────────────────────────────────────────────────────────────────
 # Gradio UI
 CSS = """
@@ -357,14 +357,13 @@ MODEL_OPTIONS = [
 ]
 
 with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
-    vdb_state = gr.State()
+    # We no longer need vdb_state – we keep only session_state
     session_state = gr.State({})
 
     # ─── Welcome Screen ─────────────────────────────────────────────
     with gr.Column(visible=True) as welcome_col:
-
         gr.Markdown(
-            f"<div style='text-align: center'>\n{WELCOME_INTRO}\n</div>",
+            f"<div style='text-align: center'>\n{WELCOME_INTRO}\n</div>",
             elem_id="welcome_md"
         )
         start_btn = gr.Button("🚀 Start")
@@ -386,6 +385,11 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
             value="Exclude Images",
             label="Images"
         )
+        ocr_radio = gr.Radio(
+            ["Get Text With OCR", "Get Available Text Only"],
+            value="Get Available Text Only",
+            label="OCR"
+        )
         ocr_dd = gr.Dropdown(
             choices=[
                 "db_resnet50 + crnn_mobilenet_v3_large",
@@ -405,28 +409,23 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
         extract_btn = gr.Button("Extract")
         preview_text = gr.Textbox(lines=10, label="Sample Text", interactive=False)
         preview_img = gr.Gallery(label="Sample Images", rows=2, value=[])
+        preview_html = gr.HTML()  # for the “Done!” message
 
         extract_btn.click(
-            extract_data_from_pdfs,
+            fn=extract_data_from_pdfs,
             inputs=[
                 docs,
                 session_state,
                 include_dd,
-                gr.Radio(
-                    ["Get Text With OCR", "Get Available Text Only"],
-                    value="Get Available Text Only",
-                    label="OCR"
-                ),
+                ocr_radio,
                 ocr_dd,
                 vlm_dd
             ],
             outputs=[
-                vdb_state,
-                session_state,
-                gr.Row(visible=False),
-                preview_text,
-                preview_img,
-                gr.HTML()
+                session_state,  # session “processed” flag
+                preview_text,   # preview text
+                preview_img,    # preview images
+                preview_html    # done HTML
             ]
         )
 
@@ -446,15 +445,15 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
            value=MODEL_OPTIONS[0],
            label="Choose Chat Model"
         )
-        num_ctx = gr.Slider(1,20,value=3,label="Text Contexts")
-        img_ctx = gr.Slider(1,10,value=2,label="Image Contexts")
-        temp = gr.Slider(0.1,1.0,step=0.1,value=0.4,label="Temperature")
-        max_tok = gr.Slider(10,1000,step=10,value=200,label="Max Tokens")
+        num_ctx = gr.Slider(1, 20, value=3, label="Text Contexts")
+        img_ctx = gr.Slider(1, 10, value=2, label="Image Contexts")
+        temp = gr.Slider(0.1, 1.0, step=0.1, value=0.4, label="Temperature")
+        max_tok = gr.Slider(10, 1000, step=10, value=200, label="Max Tokens")
 
         send.click(
-            conversation,
+            fn=conversation,
             inputs=[
-                vdb_state,
+                session_state,  # now drives conversation
                 msg,
                 num_ctx,
                 img_ctx,
@@ -465,18 +464,18 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
             ],
             outputs=[
                 chat,
-                gr.Dataframe(),
+                gr.Dataframe(),  # returned docs
                 gr.Gallery(label="Relevant Images", rows=2, value=[])
             ]
         )
 
-    # Footer inside app_col
     gr.HTML("<center>Made with ❤️ by Zamal</center>")
 
     # ─── Wire the Start button ───────────────────────────────────────
     start_btn.click(
         fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
-        inputs=[], outputs=[welcome_col, app_col]
+        inputs=[],
+        outputs=[welcome_col, app_col]
     )
 
 if __name__ == "__main__":
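
The rewired `extract_btn.click` relies on Gradio's basic callback contract: the function must return exactly one value per component listed in `outputs`, in the same order, and a `gr.State` listed there is replaced by the value returned for it, while anything non-picklable (here the vector DB held in `CURRENT_VDB`) stays in a module-level global instead of passing through state. Below is a minimal, self-contained sketch of that contract only; the names (`fake_extract`, `done_html`, `run_btn`) are illustrative and not taken from app.py.

```python
import gradio as gr

def fake_extract(session):
    # Pretend we processed some PDFs; keep only picklable values in the return.
    session["processed"] = True
    preview = "First 2000 characters of the extracted text..."
    return session, preview, "<h3>Done!</h3>"  # 3 return values -> 3 outputs below

with gr.Blocks() as demo:
    session_state = gr.State({})            # survives between callbacks
    preview_text = gr.Textbox(label="Sample Text")
    done_html = gr.HTML()
    run_btn = gr.Button("Extract")

    run_btn.click(
        fn=fake_extract,
        inputs=[session_state],
        outputs=[session_state, preview_text, done_html],  # 1:1 with the return tuple
    )

if __name__ == "__main__":
    demo.launch()
```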