Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -58,12 +58,12 @@ CURRENT_VDB = None
 @spaces.GPU()
 def get_image_description(image: Image.Image) -> str:
     """
-    Lazy-loads the Llava processor + model
+    Lazy-loads the Llava processor + model inside the GPU worker,
     runs captioning, and returns a one-sentence description.
     """
     global processor, vision_model
 
-    #
+    # On first call, instantiate + move to CUDA
     if processor is None or vision_model is None:
         processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
         vision_model = LlavaNextForConditionalGeneration.from_pretrained(
@@ -72,9 +72,9 @@ def get_image_description(image: Image.Image) -> str:
             low_cpu_mem_usage=True
         ).to("cuda")
 
-    # clear and run
     torch.cuda.empty_cache()
     gc.collect()
+
     prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
     inputs = processor(prompt, image, return_tensors="pt").to("cuda")
     output = vision_model.generate(**inputs, max_new_tokens=100)
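(For reference, a self-contained sketch of the lazy-load pattern this hunk relies on: the processor and model live in module-level globals and are instantiated on the first call inside the `@spaces.GPU()` worker. The checkpoint argument to `from_pretrained` and the final decode step are assumptions here, since the diff cuts them off.)

import gc

import spaces
import torch
from PIL import Image
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

processor = None
vision_model = None

@spaces.GPU()
def get_image_description(image: Image.Image) -> str:
    """Lazy-load LLaVA on the first call inside the GPU worker, then caption."""
    global processor, vision_model
    if processor is None or vision_model is None:
        processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
        vision_model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",  # assumed: same checkpoint as the processor
            low_cpu_mem_usage=True
        ).to("cuda")

    # free cached memory before running a generation pass
    torch.cuda.empty_cache()
    gc.collect()

    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    inputs = processor(prompt, image, return_tensors="pt").to("cuda")
    output = vision_model.generate(**inputs, max_new_tokens=100)
    # assumed: decode the generated ids back to text (not shown in the diff)
    return processor.decode(output[0], skip_special_tokens=True)

On ZeroGPU Spaces a CUDA device is only attached while a `@spaces.GPU()` function is running, which is why the model is created inside the function rather than at import time.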
@@ -175,21 +175,21 @@ def extract_data_from_pdfs(
 ):
     """
     1) (Optional) OCR setup
-    2)
-    3) Extract text
-    4) Build and
+    2) Vision+Lang model setup & monkey-patch get_image_description
+    3) Extract text & images
+    4) Build and stash vector DB in CURRENT_VDB
     """
     if not docs:
         raise gr.Error("No documents to process")
 
-    # 1) OCR
+    # 1) OCR pipeline if requested
     if do_ocr == "Get Text With OCR":
         db_m, crnn_m = OCR_CHOICES[ocr_choice]
         local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
     else:
         local_ocr = None
 
-    # 2) Vision–language model
+    # 2) Vision–language model
     proc = LlavaNextProcessor.from_pretrained(vlm_choice)
     vis = (
         LlavaNextForConditionalGeneration
@@ -197,9 +197,10 @@ def extract_data_from_pdfs(
         .to("cuda")
     )
 
-    # Monkey-patch
+    # Monkey-patch our pipeline for image captions
     def describe(img: Image.Image) -> str:
-        torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
+        gc.collect()
         prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
         inputs = proc(prompt, img, return_tensors="pt").to("cuda")
         output = vis.generate(**inputs, max_new_tokens=100)
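(A minimal sketch of the monkey-patch above, written as a factory for clarity: the module-level captioner is rebound to a closure over the user-selected processor and model. `make_describe` and the decode line are illustrative additions, not names from the app.)

import gc

import torch
from PIL import Image

def make_describe(proc, vis):
    """Build a captioner bound to the user-selected processor + model."""
    def describe(img: Image.Image) -> str:
        # free cached memory before each caption
        torch.cuda.empty_cache()
        gc.collect()
        prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
        inputs = proc(prompt, img, return_tensors="pt").to("cuda")
        output = vis.generate(**inputs, max_new_tokens=100)
        return proc.decode(output[0], skip_special_tokens=True)
    return describe

# inside extract_data_from_pdfs the module-level hook is then rebound:
#     global get_image_description
#     get_image_description = make_describe(proc, vis)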
@@ -208,13 +209,12 @@ def extract_data_from_pdfs(
     global get_image_description, CURRENT_VDB
     get_image_description = describe
 
-    # 3) Extract text
+    # 3) Extract text + images
    progress(0.2, "Extracting text and images…")
     all_text = ""
     images, names = [], []
 
     for path in docs:
-        # text
         if local_ocr:
             pdf = DocumentFile.from_pdf(path)
             res = local_ocr(pdf)
@@ -223,29 +223,28 @@ def extract_data_from_pdfs(
         txt = PdfReader(path).pages[0].extract_text() or ""
         all_text += txt + "\n\n"
 
-        # images
         if include_images == "Include Images":
             imgs = extract_images([path])
             images.extend(imgs)
             names.extend([os.path.basename(path)] * len(imgs))
 
-    # 4) Build
+    # 4) Build + store the vector DB
     progress(0.6, "Indexing in vector DB…")
     CURRENT_VDB = get_vectordb(all_text, images, names)
 
-    # mark done & return only picklable outputs
     session["processed"] = True
     sample_imgs = images[:4] if include_images == "Include Images" else []
 
+    # ─── return *exactly four* picklable outputs ───
     return (
-        session,
-
-
-
-        "<h3>Done!</h3>"
+        session,                  # gr.State: so UI knows we're ready
+        all_text[:2000] + "...",  # preview text
+        sample_imgs,              # preview images
+        "<h3>Done!</h3>"          # Done message
     )
 
 
+
 # Chat function
 def conversation(
     session: dict,
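(For reference, a runnable sketch of the pattern this hunk switches to: the non-picklable vector DB stays in a module-level global, and the click handler returns exactly four picklable values that map one-to-one onto its outputs list. Component names and function bodies here are placeholders, not the app's real ones.)

import gradio as gr

CURRENT_VDB = None  # heavy, non-picklable object kept at module level

def extract(docs, session):
    global CURRENT_VDB
    CURRENT_VDB = object()          # stand-in for the real vector DB
    session["processed"] = True
    preview = "first 2000 characters of the extracted text..."
    sample_imgs = []                # e.g. a few PIL images or file paths
    return session, preview, sample_imgs, "<h3>Done!</h3>"

with gr.Blocks() as demo:
    session_state = gr.State({})
    docs = gr.File(file_count="multiple")
    extract_btn = gr.Button("Extract")
    preview_text = gr.Textbox(lines=10, label="Sample Text")
    preview_img = gr.Gallery(label="Sample Images")
    preview_html = gr.HTML()

    # four return values -> four output components, in order
    extract_btn.click(
        extract,
        inputs=[docs, session_state],
        outputs=[session_state, preview_text, preview_img, preview_html],
    )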
@@ -258,8 +257,7 @@ def conversation(
     model_id: str
 ):
     """
-
-    calls the HF endpoint, and returns updated chat history.
+    Uses the global CURRENT_VDB (set by extract_data_from_pdfs) to answer.
     """
     global CURRENT_VDB
     if not session.get("processed") or CURRENT_VDB is None:
@@ -272,7 +270,7 @@ def conversation(
         huggingfacehub_api_token=HF_TOKEN
     )
 
-    #
+    # 1) Text retrieval
     text_col = CURRENT_VDB.get_collection("text_db")
     docs = text_col.query(
         query_texts=[question],
@@ -280,6 +278,7 @@ def conversation(
         include=["documents"]
     )["documents"][0]
 
+    # 2) Image retrieval
     img_col = CURRENT_VDB.get_collection("image_db")
     img_q = img_col.query(
         query_texts=[question],
@@ -296,7 +295,7 @@ def conversation(
         pass
     img_desc = "\n".join(img_descs)
 
-    # Build
+    # 3) Build prompt & call LLM
     prompt = PromptTemplate(
         template="""
 Context:
@@ -336,6 +335,7 @@ Answer:
 
 
 
+
 # ─────────────────────────────────────────────────────────────────────────────
 # Gradio UI
 CSS = """
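(A sketch of the two-collection lookup conversation() performs above, assuming CURRENT_VDB is a chromadb-style client exposing get_collection/query. The retrieve wrapper and the n_results defaults are illustrative.)

import chromadb

client = chromadb.Client()
text_col = client.get_or_create_collection("text_db")
img_col = client.get_or_create_collection("image_db")

def retrieve(question: str, num_ctx: int = 3, img_ctx: int = 2):
    """Return the top text chunks and image records for a question."""
    docs = text_col.query(
        query_texts=[question],
        n_results=num_ctx,
        include=["documents"],
    )["documents"][0]
    img_hits = img_col.query(
        query_texts=[question],
        n_results=img_ctx,
        include=["metadatas"],
    )
    return docs, img_hits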
@@ -357,14 +357,13 @@ MODEL_OPTIONS = [
 ]
 
 with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
-    vdb_state
+    # We no longer need vdb_state – we keep only session_state
     session_state = gr.State({})
 
     # ─── Welcome Screen ─────────────────────────────────────────────
     with gr.Column(visible=True) as welcome_col:
-
         gr.Markdown(
-
+            f"<div style='text-align: center'>\n{WELCOME_INTRO}\n</div>",
             elem_id="welcome_md"
         )
         start_btn = gr.Button("🚀 Start")
@@ -386,6 +385,11 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
             value="Exclude Images",
             label="Images"
         )
+        ocr_radio = gr.Radio(
+            ["Get Text With OCR", "Get Available Text Only"],
+            value="Get Available Text Only",
+            label="OCR"
+        )
         ocr_dd = gr.Dropdown(
             choices=[
                 "db_resnet50 + crnn_mobilenet_v3_large",
@@ -405,28 +409,23 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
         extract_btn = gr.Button("Extract")
         preview_text = gr.Textbox(lines=10, label="Sample Text", interactive=False)
         preview_img = gr.Gallery(label="Sample Images", rows=2, value=[])
+        preview_html = gr.HTML()  # for the “Done!” message
 
         extract_btn.click(
-            extract_data_from_pdfs,
+            fn=extract_data_from_pdfs,
             inputs=[
                 docs,
                 session_state,
                 include_dd,
-                gr.Radio(
-                    ["Get Text With OCR", "Get Available Text Only"],
-                    value="Get Available Text Only",
-                    label="OCR"
-                ),
+                ocr_radio,
                 ocr_dd,
                 vlm_dd
             ],
             outputs=[
-
-
-
-
-                preview_img,
-                gr.HTML()
+                session_state,   # session “processed” flag
+                preview_text,    # preview text
+                preview_img,     # preview images
+                preview_html     # done HTML
             ]
         )
 
@@ -446,15 +445,15 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
             value=MODEL_OPTIONS[0],
             label="Choose Chat Model"
         )
-        num_ctx = gr.Slider(1,20,value=3,label="Text Contexts")
-        img_ctx = gr.Slider(1,10,value=2,label="Image Contexts")
-        temp = gr.Slider(0.1,1.0,step=0.1,value=0.4,label="Temperature")
-        max_tok = gr.Slider(10,1000,step=10,value=200,label="Max Tokens")
+        num_ctx = gr.Slider(1, 20, value=3, label="Text Contexts")
+        img_ctx = gr.Slider(1, 10, value=2, label="Image Contexts")
+        temp = gr.Slider(0.1, 1.0, step=0.1, value=0.4, label="Temperature")
+        max_tok = gr.Slider(10, 1000, step=10, value=200, label="Max Tokens")
 
         send.click(
-            conversation,
+            fn=conversation,
             inputs=[
-
+                session_state,   # now drives conversation
                 msg,
                 num_ctx,
                 img_ctx,
@@ -465,18 +464,18 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
             ],
             outputs=[
                 chat,
-                gr.Dataframe(),
+                gr.Dataframe(),   # returned docs
                 gr.Gallery(label="Relevant Images", rows=2, value=[])
             ]
        )
 
-    # Footer inside app_col
     gr.HTML("<center>Made with ❤️ by Zamal</center>")
 
     # ─── Wire the Start button ───────────────────────────────────────
     start_btn.click(
         fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
-        inputs=[],
+        inputs=[],
+        outputs=[welcome_col, app_col]
     )
 
 if __name__ == "__main__":
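(For reference, a runnable sketch of the welcome-to-app visibility toggle wired at the end of the diff; the column contents are placeholders.)

import gradio as gr

with gr.Blocks() as demo:
    with gr.Column(visible=True) as welcome_col:
        gr.Markdown("Welcome")
        start_btn = gr.Button("🚀 Start")

    with gr.Column(visible=False) as app_col:
        gr.Markdown("Main app goes here")

    # Hide the welcome column and reveal the app column on click
    start_btn.click(
        fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
        inputs=[],
        outputs=[welcome_col, app_col],
    )

if __name__ == "__main__":
    demo.launch()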