zamalali committed on
Commit 15067e5 · Parent: 2f03b05

Initial push without .env

.gitignore ADDED (binary file, 14 Bytes)
__pycache__/utils.cpython-310.pyc ADDED (binary file, 1.67 kB)
__pycache__/welcome_text.cpython-310.pyc ADDED (binary file, 997 Bytes)
app.py CHANGED
@@ -1,483 +1,452 @@
  import base64
- import chromadb
  import gc
  import gradio as gr
- import io
  import numpy as np
- import os
  import pandas as pd
  import pymupdf
- from pypdf import PdfReader
- import spaces
- import torch
  from PIL import Image
- from chromadb.utils import embedding_functions
- from chromadb.utils.data_loaders import ImageLoader
  from doctr.io import DocumentFile
  from doctr.models import ocr_predictor
- from gradio.themes.utils import sizes
- from langchain import PromptTemplate
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.llms import HuggingFaceEndpoint
- from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
- from utils import *
-
25
 
26
- def result_to_text(result, as_text=False) -> str or list:
27
- full_doc = []
28
- for _, page in enumerate(result.pages, start=1):
29
- text = ""
30
- for block in page.blocks:
31
- text += "\n\t"
32
- for line in block.lines:
33
- for word in line.words:
34
- text += word.value + " "
35
 
36
- full_doc.append(clean_text(text) + "\n\n")
 
 
37
 
38
- return "\n".join(full_doc) if as_text else full_doc
 
39
 
 
 
 
 
40
 
 
41
  ocr_model = ocr_predictor(
42
- "db_resnet50",
43
- "crnn_mobilenet_v3_large",
44
- pretrained=True,
45
- assume_straight_pages=True,
46
  )
47
-
48
-
49
-
50
  processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
51
  vision_model = LlavaNextForConditionalGeneration.from_pretrained(
52
  "llava-hf/llava-v1.6-mistral-7b-hf",
53
  torch_dtype=torch.float16,
54
- low_cpu_mem_usage=True,
55
- load_in_4bit=True,
56
- )
57
- vision_model.to("cuda:0")
58
 
59
 
60
- @spaces.GPU
61
- def get_image_description(image):
 
62
  torch.cuda.empty_cache()
63
  gc.collect()
64
-
65
- # n = len(prompt)
66
  prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
67
-
68
- inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
69
  output = vision_model.generate(**inputs, max_new_tokens=100)
70
- return (processor.decode(output[0], skip_special_tokens=True))
71
 
 
 
 
 
 
 
 
72
 
73
- CSS = """
74
- #table_col {background-color: rgb(3, 100, 4);}
75
- footer {visibility: hidden;}
76
- """
77
-
78
 
79
- # def get_vectordb(text, images, tables):
80
- def get_vectordb(text, images, img_doc_files):
 
 
 
 
 
 
81
  client = chromadb.EphemeralClient()
82
- loader = ImageLoader()
83
- sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
84
- model_name="multi-qa-mpnet-base-dot-v1"
85
- )
86
- if "text_db" in [i.name for i in client.list_collections()]:
87
- client.delete_collection("text_db")
88
- if "image_db" in [i.name for i in client.list_collections()]:
89
- client.delete_collection("image_db")
90
 
91
- text_collection = client.get_or_create_collection(
 
92
  name="text_db",
93
- embedding_function=sentence_transformer_ef,
94
- data_loader=loader,
95
  )
96
- image_collection = client.get_or_create_collection(
97
  name="image_db",
98
- embedding_function=sentence_transformer_ef,
99
- data_loader=loader,
100
  metadata={"hnsw:space": "cosine"},
 
101
  )
102
- descs = []
103
- for i in range(len(images)):
104
- try:
105
- descs.append(img_doc_files[i]+"\n"+get_image_description(images[i]))
106
- except:
107
- descs.append("Could not generate image description due to some error")
108
- gr.Error("Could not generate image descriptions. Your GPU limit may have been exhausted. Please try again after an hour.")
109
- print(descs[-1])
110
- print()
111
-
112
- # image_descriptions = get_image_descriptions(images)
113
- image_dict = [{"image": image_to_bytes(img)} for img in images]
114
 
115
- if len(images) > 0:
116
- image_collection.add(
 
 
 
 
 
 
 
 
 
 
 
 
117
  ids=[str(i) for i in range(len(images))],
118
  documents=descs,
119
- metadatas=image_dict,
120
  )
121
 
122
- splitter = RecursiveCharacterTextSplitter(
123
- chunk_size=500,
124
- chunk_overlap=10,
 
 
 
125
  )
126
 
127
- if len(text.replace(" ", "").replace("\n", "")) == 0:
128
- gr.Error("No text found in documents")
129
- else:
130
- docs = splitter.create_documents([text])
131
- doc_texts = [i.page_content for i in docs]
132
- text_collection.add(
133
- ids=[str(i) for i in list(range(len(doc_texts)))], documents=doc_texts
134
- )
135
  return client
136
 
137
 
138
- def extract_only_text(reader):
139
- text = ""
140
- for _, page in enumerate(reader.pages):
141
- text = page.extract_text()
142
- return text.strip()
143
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  def extract_data_from_pdfs(
146
- docs, session, include_images, do_ocr, progress=gr.Progress()
 
 
 
 
 
 
147
  ):
148
- if len(docs) == 0:
 
 
 
 
 
 
149
  raise gr.Error("No documents to process")
150
- progress(0, "Extracting Images")
151
-
152
- # images = extract_images(docs)
153
-
154
- progress(0.25, "Extracting Text")
155
-
156
- all_text = ""
157
 
158
- images = []
159
- img_docs=[]
160
- for doc in docs:
161
- if do_ocr == "Get Text With OCR":
162
- pdf_doc = DocumentFile.from_pdf(doc)
163
- result = ocr_model(pdf_doc)
164
- all_text += result_to_text(result, as_text=True) + "\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  else:
166
- reader = PdfReader(doc)
167
- all_text += extract_only_text(reader) + "\n\n"
168
 
169
  if include_images == "Include Images":
170
- imgs = extract_images([doc])
171
  images.extend(imgs)
172
- img_docs.extend([doc.split("/")[-1] for _ in range(len(imgs))])
173
 
174
- progress(
175
- 0.6, "Generating image descriptions and inserting everything into vectorDB"
176
- )
177
- vectordb = get_vectordb(all_text, images, img_docs)
178
 
179
- progress(1, "Completed")
180
  session["processed"] = True
 
181
  return (
182
- vectordb,
183
  session,
184
  gr.Row(visible=True),
185
  all_text[:2000] + "...",
186
- # display,
187
- images[:2],
188
- "<h1 style='text-align: center'>Completed<h1>",
189
- # image_descriptions
190
  )
191
-
192
-
193
- sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
194
- model_name="multi-qa-mpnet-base-dot-v1"
195
- )
196
-
197
-
198
  def conversation(
199
- vectordb_client,
200
- msg,
201
- num_context,
202
- img_context,
203
- history,
204
- temperature,
205
- max_new_tokens,
206
- hf_token,
207
- model_path,
208
  ):
209
- if hf_token.strip() != "" and model_path.strip() != "":
210
- llm = HuggingFaceEndpoint(
211
- repo_id=model_path,
212
- temperature=temperature,
213
- max_new_tokens=max_new_tokens,
214
- huggingfacehub_api_token=hf_token,
215
- )
216
- else:
217
- llm = HuggingFaceEndpoint(
218
- repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
219
- temperature=temperature,
220
- max_new_tokens=max_new_tokens,
221
- huggingfacehub_api_token=os.getenv("P_HF_TOKEN", "None"),
222
- )
223
-
224
- text_collection = vectordb_client.get_collection(
225
- "text_db", embedding_function=sentence_transformer_ef
226
- )
227
- image_collection = vectordb_client.get_collection(
228
- "image_db", embedding_function=sentence_transformer_ef
229
  )
230
 
231
- results = text_collection.query(
232
- query_texts=[msg], include=["documents"], n_results=num_context
 
 
 
 
233
  )["documents"][0]
234
- similar_images = image_collection.query(
235
- query_texts=[msg],
236
- include=["metadatas", "distances", "documents"],
237
- n_results=img_context,
238
- )
239
- img_links = [i["image"] for i in similar_images["metadatas"][0]]
240
 
241
- images_and_locs = [
242
- Image.open(io.BytesIO(base64.b64decode(i[1])))
243
- for i in zip(similar_images["distances"][0], img_links)
244
- ]
245
- img_desc = "\n".join(similar_images["documents"][0])
246
- if len(img_links) == 0:
247
- img_desc = "No Images Are Provided"
248
- template = """
249
- Context:
250
- {context}
 
 
 
 
 
 
251
 
252
- Included Images:
253
- {images}
254
-
255
- Question:
256
- {question}
257
 
258
- Answer:
 
259
 
260
- """
261
- prompt = PromptTemplate(template=template, input_variables=["context", "question"])
262
- context = "\n\n".join(results)
263
- # references = [gr.Textbox(i, visible=True, interactive=False) for i in results]
264
- response = llm(prompt.format(context=context, question=msg, images=img_desc))
265
- return history + [(msg, response)], results, images_and_locs
267
 
268
- def check_validity_and_llm(session_states):
269
- if session_states.get("processed", False) == True:
270
- return gr.Tabs(selected=2)
271
- raise gr.Error("Please extract data first")
272
 
273
 
274
 
275
- with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
276
- vectordb = gr.State()
277
- doc_collection = gr.State(value=[])
278
- session_states = gr.State(value={})
279
- references = gr.State(value=[])
280
 
281
- gr.Markdown(
282
- """<h2><center>Chat PDF Multimodal 💬 </center></h2>
283
- <h3><center><b>Interact With Your PDF Documents</b></center></h3>"""
284
- )
285
- gr.Markdown(
286
- """<center><h3><b>Note: </b> This application leverages advanced Retrieval-Augmented Generation (RAG) techniques to provide context-aware responses from your PDF documents</center><h3><br>
287
- <center>Utilizing multimodal capabilities, this chatbot can understand and answer queries based on both textual and visual information within your PDFs.</center>"""
288
- )
289
- gr.Markdown(
290
- """
291
- <center><b>Warning: </b> Extracting text and images from your document and generating embeddings may take some time due to the use of OCR and multimodal LLMs for image description<center>
292
- """
293
- )
294
- with gr.Tabs() as tabs:
295
- with gr.TabItem("Upload PDFs", id=0) as pdf_tab:
296
- with gr.Row():
297
- with gr.Column():
298
- documents = gr.File(
299
- file_count="multiple",
300
- file_types=["pdf"],
301
- interactive=True,
302
- label="Upload your PDF file/s",
303
- )
304
- pdf_btn = gr.Button(value="Next", elem_id="button1")
305
-
306
- with gr.TabItem("Extract Data", id=1) as preprocess:
307
- with gr.Row():
308
- with gr.Column():
309
- back_p1 = gr.Button(value="Back")
310
- with gr.Column():
311
- embed = gr.Button(value="Extract Data")
312
- with gr.Column():
313
- next_p1 = gr.Button(value="Next")
314
- with gr.Row():
315
- include_images = gr.Radio(
 
 
 
 
 
316
  ["Include Images", "Exclude Images"],
317
- value="Include Images",
318
- label="Include/ Exclude Images",
319
- interactive=True,
320
  )
321
- do_ocr = gr.Radio(
322
- ["Get Text With OCR", "Get Available Text Only"],
323
- value="Get Text With OCR",
324
- label="OCR/ No OCR",
325
- interactive=True,
 
 
326
  )
327
-
328
- with gr.Row(equal_height=True, variant="panel") as row:
329
- selected = gr.Dataframe(
330
- interactive=False,
331
- col_count=(1, "fixed"),
332
- headers=["Selected Files"],
 
333
  )
334
- prog = gr.HTML(
335
- value="<h1 style='text-align: center'>Click the 'Extract' button to extract data from PDFs<h1>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  )
337
 
338
- with gr.Accordion("See Parts of Extracted Data", open=False):
339
- with gr.Column(visible=True) as sample_data:
340
- with gr.Row():
341
- with gr.Column():
342
- ext_text = gr.Textbox(
343
- label="Sample Extracted Text", lines=15
344
- )
345
- with gr.Column():
346
- images = gr.Gallery(
347
- label="Sample Extracted Images", columns=1, rows=2
348
- )
349
-
350
- with gr.TabItem("Chat", id=2) as chat_tab:
351
- with gr.Accordion("Config (Advanced) (Optional)", open=False):
352
- with gr.Row(variant="panel", equal_height=True):
353
- choice = gr.Radio(
354
- ["chromaDB"],
355
- value="chromaDB",
356
- label="Vector Database",
357
- interactive=True,
358
- )
359
- with gr.Accordion("Use your own model (optional)", open=False):
360
- hf_token = gr.Textbox(
361
- label="HuggingFace Token", interactive=True
362
  )
363
- model_path = gr.Textbox(label="Model Path", interactive=True)
364
- with gr.Row(variant="panel", equal_height=True):
365
- num_context = gr.Slider(
366
- label="Number of text context elements",
367
- minimum=1,
368
- maximum=20,
369
- step=1,
370
- interactive=True,
371
- value=3,
372
- )
373
- img_context = gr.Slider(
374
- label="Number of image context elements",
375
- minimum=1,
376
- maximum=10,
377
- step=1,
378
- interactive=True,
379
- value=2,
380
- )
381
- with gr.Row(variant="panel", equal_height=True):
382
- temp = gr.Slider(
383
- label="Temperature",
384
- minimum=0.1,
385
- maximum=1,
386
- step=0.1,
387
- interactive=True,
388
- value=0.4,
389
- )
390
- max_tokens = gr.Slider(
391
- label="Max Tokens",
392
- minimum=10,
393
- maximum=2000,
394
- step=10,
395
- interactive=True,
396
- value=500,
397
- )
398
- with gr.Row():
399
- with gr.Column():
400
- ret_images = gr.Gallery("Similar Images", columns=1, rows=2)
401
- with gr.Column():
402
- chatbot = gr.Chatbot(height=400)
403
- with gr.Accordion("Text References", open=False):
404
- # text_context = gr.Row()
405
-
406
- @gr.render(inputs=references)
407
- def gen_refs(references):
408
- # print(references)
409
- n = len(references)
410
- for i in range(n):
411
- gr.Textbox(
412
- label=f"Reference-{i+1}", value=references[i], lines=3
413
  )
414
-
415
- with gr.Row():
416
- msg = gr.Textbox(
417
- placeholder="Type your question here (e.g. 'What is this document about?')",
418
- interactive=True,
419
- container=True,
420
  )
421
- with gr.Row():
422
- submit_btn = gr.Button("Submit message")
423
- clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
424
-
425
- pdf_btn.click(
426
- fn=extract_pdfs,
427
- inputs=[documents, doc_collection],
428
- outputs=[doc_collection, tabs, selected],
429
- )
430
- embed.click(
431
- extract_data_from_pdfs,
432
- inputs=[doc_collection, session_states, include_images, do_ocr],
433
- outputs=[
434
- vectordb,
435
- session_states,
436
- sample_data,
437
- ext_text,
438
- images,
439
- prog,
440
- ],
441
- )
442
 
443
- submit_btn.click(
444
- conversation,
445
- [
446
- vectordb,
447
- msg,
448
- num_context,
449
- img_context,
450
- chatbot,
451
- temp,
452
- max_tokens,
453
- hf_token,
454
- model_path,
455
- ],
456
- [chatbot, references, ret_images],
457
- )
458
- msg.submit(
459
- conversation,
460
- [
461
- vectordb,
462
- msg,
463
- num_context,
464
- img_context,
465
- chatbot,
466
- temp,
467
- max_tokens,
468
- hf_token,
469
- model_path,
470
- ],
471
- [chatbot, references, ret_images],
472
- )
473
-
474
- documents.change(lambda: "<h1 style='text-align: center'>Click the 'Extract' button to extract data from PDFs<h1>", None, prog)
475
 
476
- back_p1.click(lambda: gr.Tabs(selected=0), None, tabs)
477
-
478
- next_p1.click(check_validity_and_llm, session_states, tabs)
479
- gr.HTML('<div style="text-align: center; margin-top: 20px;">Made with ❤️</div>')
 
480
 
481
  if __name__ == "__main__":
482
  demo.launch()
483
-
 
+ import os
+ import io
  import base64
  import gc
+ from huggingface_hub.utils import HfHubHTTPError
+ from langchain_core.prompts import PromptTemplate
+ from langchain_huggingface import HuggingFaceEndpoint
+ import io, base64
+ from PIL import Image
+ import torch
  import gradio as gr
+ import spaces
  import numpy as np
  import pandas as pd
  import pymupdf
  from PIL import Image
+ from pypdf import PdfReader
+ from dotenv import load_dotenv
+ from welcome_text import WELCOME_INTRO
+
  from doctr.io import DocumentFile
  from doctr.models import ocr_predictor
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
+
+ import chromadb
+ from chromadb.utils import embedding_functions
+ from chromadb.utils.data_loaders import ImageLoader
+
+ from langchain_core.prompts import PromptTemplate
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_huggingface import HuggingFaceEndpoint
+
+ from utils import extract_pdfs, extract_images, clean_text, image_to_bytes
+ from utils import *
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Load .env
+ load_dotenv()
+ HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
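The commit message notes that .env itself is not pushed; a minimal sketch of the file that load_dotenv() above expects, where only the variable name HUGGINGFACE_TOKEN is taken from the code and the token value is a placeholder:

# .env  (kept out of version control)
HUGGINGFACE_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx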
 
+ # OCR + multimodal image description setup
  ocr_model = ocr_predictor(
+     "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
  )
  processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
  vision_model = LlavaNextForConditionalGeneration.from_pretrained(
      "llava-hf/llava-v1.6-mistral-7b-hf",
      torch_dtype=torch.float16,
+     low_cpu_mem_usage=True
+ ).to("cuda")
 
 
+ @spaces.GPU()
+ def get_image_description(image: Image.Image) -> str:
+     """Generate a one-sentence description via LlavaNext."""
      torch.cuda.empty_cache()
      gc.collect()
      prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
+     inputs = processor(prompt, image, return_tensors="pt").to("cuda")
      output = vision_model.generate(**inputs, max_new_tokens=100)
+     return processor.decode(output[0], skip_special_tokens=True)
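A minimal usage sketch for the captioning helper above; the image path is hypothetical and the call assumes a CUDA device is available, as in the code:

from PIL import Image

img = Image.open("page_figure.png").convert("RGB")  # hypothetical local image
print(get_image_description(img))                   # one-sentence LlavaNext caption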
 
+ # Vector DB setup
+ # at top of file, alongside your other imports
+ from chromadb.utils import embedding_functions
+ from chromadb.utils.data_loaders import ImageLoader
+ import chromadb
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from utils import image_to_bytes  # your helper
+
+ # 1) Create one shared embedding function (defaulting to All-MiniLM-L6-v2, 384-dim)
+ SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
+     model_name="all-MiniLM-L6-v2"
+ )
 
+ def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
+     """
+     Build an in-memory ChromaDB instance with two collections:
+       • text_db  (chunks of the PDF text)
+       • image_db (image descriptions + raw image bytes)
+     Returns the Chroma client for later querying.
+     """
+     # ——— 1) Init & wipe old ————————————————
      client = chromadb.EphemeralClient()
+     for col in ("text_db", "image_db"):
+         if col in [c.name for c in client.list_collections()]:
+             client.delete_collection(col)
 
+     # ——— 2) Create fresh collections —————————
+     text_col = client.get_or_create_collection(
          name="text_db",
+         embedding_function=SHARED_EMB_FN,
+         data_loader=ImageLoader(),  # loader only matters for images, benign here
      )
+     img_col = client.get_or_create_collection(
          name="image_db",
+         embedding_function=SHARED_EMB_FN,
          metadata={"hnsw:space": "cosine"},
+         data_loader=ImageLoader(),
      )
 
+     # ——— 3) Add images if any ———————————————
+     if images:
+         descs = []
+         metas = []
+         for idx, img in enumerate(images):
+             # build one-line caption (or fallback)
+             try:
+                 caption = get_image_description(img)
+             except Exception:
+                 caption = "⚠️ could not describe image"
+             descs.append(f"{img_names[idx]}: {caption}")
+             metas.append({"image": image_to_bytes(img)})
+
+         img_col.add(
              ids=[str(i) for i in range(len(images))],
              documents=descs,
+             metadatas=metas,
          )
 
+     # ——— 4) Chunk & add text ———————————————
+     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+     docs = splitter.create_documents([text])
+     text_col.add(
+         ids=[str(i) for i in range(len(docs))],
+         documents=[d.page_content for d in docs],
      )
 
      return client
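A minimal sketch of how the client returned by get_vectordb could be queried later; the collection names and SHARED_EMB_FN come from the code above, while the variable names (pdf_text, pdf_images, pdf_names) and query strings are illustrative:

client = get_vectordb(pdf_text, pdf_images, pdf_names)

text_db = client.get_collection("text_db", embedding_function=SHARED_EMB_FN)
hits = text_db.query(query_texts=["What is this document about?"],
                     n_results=3, include=["documents"])
print(hits["documents"][0])          # top matching text chunks

image_db = client.get_collection("image_db", embedding_function=SHARED_EMB_FN)
img_hits = image_db.query(query_texts=["architecture diagram"],
                          n_results=2, include=["documents", "metadatas"])
print(img_hits["documents"][0])      # captions of the closest images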
 
 
+ # Text extraction
+ def result_to_text(result, as_text=False):
+     pages = []
+     for pg in result.pages:
+         txt = " ".join(w.value for block in pg.blocks for line in block.lines for w in line.words)
+         pages.append(clean_text(txt))
+     return "\n\n".join(pages) if as_text else pages
+
+ OCR_CHOICES = {
+     "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
+     "db_resnet50 + crnn_resnet31": ("db_resnet50", "crnn_resnet31"),
+ }
 
  def extract_data_from_pdfs(
+     docs,
+     session,
+     include_images,   # "Include Images" or "Exclude Images"
+     do_ocr,           # "Get Text With OCR" or "Get Available Text Only"
+     ocr_choice,       # key into OCR_CHOICES
+     vlm_choice,       # HF repo ID for LlavaNext
+     progress=gr.Progress()
  ):
+     """
+     1) Dynamically instantiate the chosen OCR pipeline (if any)
+     2) Dynamically instantiate the chosen vision-language model
+     3) Override the global get_image_description to use that model for captions
+     4) Extract text & images, index into ChromaDB
+     """
+     if not docs:
          raise gr.Error("No documents to process")
+
+     # ——— 1) Set up OCR if requested ————————————————
+     if do_ocr == "Get Text With OCR":
+         db_m, crnn_m = OCR_CHOICES[ocr_choice]
+         local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
+     else:
+         local_ocr = None
+
+     # ——— 2) Set up vision-language model —————————————
+     proc = LlavaNextProcessor.from_pretrained(vlm_choice)
+     vis = LlavaNextForConditionalGeneration.from_pretrained(
+         vlm_choice,
+         torch_dtype=torch.float16,
+         low_cpu_mem_usage=True
+     ).to("cuda")
+
+     # ——— 3) Monkey-patch global get_image_description ————
+     def describe(img: Image.Image) -> str:
+         torch.cuda.empty_cache(); gc.collect()
+         prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
+         inputs = proc(prompt, img, return_tensors="pt").to("cuda")
+         output = vis.generate(**inputs, max_new_tokens=100)
+         return proc.decode(output[0], skip_special_tokens=True)
+
+     global get_image_description
+     get_image_description = describe
+
+     # ——— 4) Extract text & images ————————————————
+     progress(0.2, "Extracting text and images…")
+     all_text, images, names = "", [], []
+     for path in docs:
+         if local_ocr:
+             pdf = DocumentFile.from_pdf(path)
+             res = local_ocr(pdf)
+             all_text += result_to_text(res, as_text=True) + "\n\n"
          else:
+             txt = PdfReader(path).pages[0].extract_text() or ""
+             all_text += "\n\n" + txt + "\n\n"
 
          if include_images == "Include Images":
+             imgs = extract_images([path])
              images.extend(imgs)
+             names.extend([os.path.basename(path)] * len(imgs))
 
+     # ——— 5) Index into vector DB ————————————————
+     progress(0.6, "Indexing in vector DB…")
+     vdb = get_vectordb(all_text, images, names)
 
      session["processed"] = True
+     sample_imgs = images[:4] if include_images == "Include Images" else []
      return (
+         vdb,
          session,
          gr.Row(visible=True),
          all_text[:2000] + "...",
+         sample_imgs,
+         "<h3>Done!</h3>"
      )
+ # Chat function
 
 
 
 
 
 
222
  def conversation(
223
+ vdb, question: str, num_ctx, img_ctx,
224
+ history: list, temp: float, max_tok: int, model_id: str
 
 
 
 
 
 
 
225
  ):
226
+ # 0) Cast the context sliders to ints
227
+ num_ctx = int(num_ctx)
228
+ img_ctx = int(img_ctx)
229
+
230
+ # 1) Guard: must have extracted first
231
+ if vdb is None:
232
+ raise gr.Error("Please extract data first")
233
+
234
+ # 2) Instantiate the chosen HF endpoint
235
+ llm = HuggingFaceEndpoint(
236
+ repo_id=model_id,
237
+ temperature=temp,
238
+ max_new_tokens=max_tok,
239
+ huggingfacehub_api_token=HF_TOKEN
 
 
 
 
 
 
240
  )
241
 
242
+ # 3) Query text collection
243
+ text_col = vdb.get_collection("text_db")
244
+ docs = text_col.query(
245
+ query_texts=[question],
246
+ n_results=num_ctx, # now an int
247
+ include=["documents"]
248
  )["documents"][0]
 
 
 
 
 
 
249
 
250
+ # 4) Query image collection
251
+ img_col = vdb.get_collection("image_db")
252
+ img_q = img_col.query(
253
+ query_texts=[question],
254
+ n_results=img_ctx, # now an int
255
+ include=["metadatas", "documents"]
256
+ )
257
+ # rest unchanged …
258
+ images, img_descs = [], img_q["documents"][0] or ["No images found"]
259
+ for meta in img_q["metadatas"][0]:
260
+ b64 = meta.get("image", "")
261
+ try:
262
+ images.append(Image.open(io.BytesIO(base64.b64decode(b64))))
263
+ except:
264
+ pass
265
+ img_desc = "\n".join(img_descs)
266
 
267
+ # 5) Build prompt
268
+ prompt = PromptTemplate(
269
+ template="""
270
+ Context:
271
+ {text}
272
 
273
+ Included Images:
274
+ {img_desc}
275
 
276
+ Question:
277
+ {q}
 
 
 
 
278
 
279
+ Answer:
280
+ """,
281
+ input_variables=["text", "img_desc", "q"],
282
+ )
283
+ context = "\n\n".join(docs)
284
+ user_input = prompt.format(text=context, img_desc=img_desc, q=question)
285
+
286
+ # 6) Call the model with error handling
287
+ try:
288
+ answer = llm.invoke(user_input)
289
+ except HfHubHTTPError as e:
290
+ if e.response.status_code == 404:
291
+ answer = f"❌ Model `{model_id}` not hosted on HF Inference API."
292
+ else:
293
+ answer = f"⚠️ HF API error: {e}"
294
+ except Exception as e:
295
+ answer = f"⚠️ Unexpected error: {e}"
296
+
297
+ # 7) Append to history
298
+ new_history = history + [
299
+ {"role":"user", "content": question},
300
+ {"role":"assistant","content": answer}
301
+ ]
302
 
303
+ # 8) Return updated history, docs, images
304
+ return new_history, docs, images
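Because the chat tab below creates gr.Chatbot(type="messages"), the history built above is a flat list of role/content dicts. A minimal sketch of one turn, assuming a vector DB has already been built by extract_data_from_pdfs (vdb is that assumed client):

history, refs, ref_images = conversation(
    vdb,                                  # client returned by get_vectordb
    "What is this document about?",       # question
    3, 2,                                 # text / image context counts
    [],                                   # empty chat history
    0.4, 200,                             # temperature, max tokens
    "HuggingFaceH4/zephyr-7b-beta",       # one of MODEL_OPTIONS
)
# history == [
#     {"role": "user", "content": "What is this document about?"},
#     {"role": "assistant", "content": "..."},
# ]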
 
 
305
 
306
 
307
 
308
+ # ─────────────────────────────────────────────────────────────────────────────
309
+ # Gradio UI
310
+ CSS = """
311
+ footer {visibility:hidden;}
312
+ """
313
 
314
+ MODEL_OPTIONS = [
315
+ "HuggingFaceH4/zephyr-7b-beta",
316
+ "mistralai/Mistral-7B-Instruct-v0.2",
317
+ "openchat/openchat-3.5-0106",
318
+ "google/gemma-7b-it",
319
+ "deepseek-ai/deepseek-llm-7b-chat",
320
+ "microsoft/Phi-3-mini-4k-instruct",
321
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
322
+ "Qwen/Qwen1.5-7B-Chat",
323
+ "tiiuae/falcon-7b-instruct", # Falcon 7B Instruct
324
+ "bigscience/bloomz-7b1", # BLOOMZ 7B
325
+ "facebook/opt-2.7b",
326
+ ]
327
+
328
+ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
329
+ vdb_state = gr.State()
330
+ session_state = gr.State({})
331
+
332
+ # ─── Welcome Screen ─────────────────────────────────────────────
333
+ with gr.Column(visible=True) as welcome_col:
334
+
335
+ gr.Markdown(
336
+ f"<div style='text-align: center'>\n{WELCOME_INTRO}\n</div>",
337
+ elem_id="welcome_md"
338
+ )
339
+ start_btn = gr.Button("🚀 Start")
340
+
341
+ # ─── Main App (hidden until Start is clicked) ───────────────────
342
+ with gr.Column(visible=False) as app_col:
343
+ gr.Markdown("## 📚 Multimodal Chat-PDF Playground")
344
+
345
+ with gr.Tabs():
346
+ # Tab 1: Upload & Extract
347
+ with gr.TabItem("1. Upload & Extract"):
348
+ docs = gr.File(
349
+ file_count="multiple",
350
+ file_types=[".pdf"],
351
+ label="Upload PDFs"
352
+ )
353
+ include_dd = gr.Radio(
354
  ["Include Images", "Exclude Images"],
355
+ value="Exclude Images",
356
+ label="Images"
 
357
  )
358
+ ocr_dd = gr.Dropdown(
359
+ choices=[
360
+ "db_resnet50 + crnn_mobilenet_v3_large",
361
+ "db_resnet50 + crnn_resnet31"
362
+ ],
363
+ value="db_resnet50 + crnn_mobilenet_v3_large",
364
+ label="OCR Model"
365
  )
366
+ vlm_dd = gr.Dropdown(
367
+ choices=[
368
+ "llava-hf/llava-v1.6-mistral-7b-hf",
369
+ "llava-hf/llava-v1.5-mistral-7b"
370
+ ],
371
+ value="llava-hf/llava-v1.6-mistral-7b-hf",
372
+ label="Vision-Language Model"
373
  )
374
+ extract_btn = gr.Button("Extract")
375
+ preview_text = gr.Textbox(lines=10, label="Sample Text", interactive=False)
376
+ preview_img = gr.Gallery(label="Sample Images", rows=2, value=[])
377
+
378
+ extract_btn.click(
379
+ extract_data_from_pdfs,
380
+ inputs=[
381
+ docs,
382
+ session_state,
383
+ include_dd,
384
+ gr.Radio(
385
+ ["Get Text With OCR", "Get Available Text Only"],
386
+ value="Get Available Text Only",
387
+ label="OCR"
388
+ ),
389
+ ocr_dd,
390
+ vlm_dd
391
+ ],
392
+ outputs=[
393
+ vdb_state,
394
+ session_state,
395
+ gr.Row(visible=False),
396
+ preview_text,
397
+ preview_img,
398
+ gr.HTML()
399
+ ]
400
  )
401
 
402
+ # Tab 2: Chat
403
+ with gr.TabItem("2. Chat"):
404
+ with gr.Row():
405
+ with gr.Column(scale=3):
406
+ chat = gr.Chatbot(type="messages", label="Chat")
407
+ msg = gr.Textbox(
408
+ placeholder="Ask about your PDF...",
409
+ label="Your question"
410
  )
411
+ send = gr.Button("Send")
412
+ with gr.Column(scale=1):
413
+ model_dd = gr.Dropdown(
414
+ MODEL_OPTIONS,
415
+ value=MODEL_OPTIONS[0],
416
+ label="Choose Chat Model"
417
  )
418
+ num_ctx = gr.Slider(1,20,value=3,label="Text Contexts")
419
+ img_ctx = gr.Slider(1,10,value=2,label="Image Contexts")
420
+ temp = gr.Slider(0.1,1.0,step=0.1,value=0.4,label="Temperature")
421
+ max_tok = gr.Slider(10,1000,step=10,value=200,label="Max Tokens")
422
+
423
+ send.click(
424
+ conversation,
425
+ inputs=[
426
+ vdb_state,
427
+ msg,
428
+ num_ctx,
429
+ img_ctx,
430
+ chat,
431
+ temp,
432
+ max_tok,
433
+ model_dd
434
+ ],
435
+ outputs=[
436
+ chat,
437
+ gr.Dataframe(),
438
+ gr.Gallery(label="Relevant Images", rows=2, value=[])
439
+ ]
440
  )
441
 
442
+ # Footer inside app_col
443
+ gr.HTML("<center>Made with ❤️ by Zamal</center>")
444
 
445
+ # ─── Wire the Start button ───────────────────────────────────────
446
+ start_btn.click(
447
+ fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
448
+ inputs=[], outputs=[welcome_col, app_col]
449
+ )
450
 
451
  if __name__ == "__main__":
452
  demo.launch()
 
test.py ADDED
@@ -0,0 +1,452 @@
1
+ import os
2
+ import io
3
+ import base64
4
+ import gc
5
+ from huggingface_hub.utils import HfHubHTTPError
6
+ from langchain_core.prompts import PromptTemplate
7
+ from langchain_huggingface import HuggingFaceEndpoint
8
+ import io, base64
9
+ from PIL import Image
10
+ import gradio as gr
11
+ import torch
12
+ import gradio as gr
13
+ import numpy as np
14
+ import pandas as pd
15
+ import pymupdf
16
+ from PIL import Image
17
+ from pypdf import PdfReader
18
+ from dotenv import load_dotenv
19
+ from welcome_text import WELCOME_INTRO
20
+
21
+ from doctr.io import DocumentFile
22
+ from doctr.models import ocr_predictor
23
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
24
+
25
+ import chromadb
26
+ from chromadb.utils import embedding_functions
27
+ from chromadb.utils.data_loaders import ImageLoader
28
+
29
+ from langchain_core.prompts import PromptTemplate
30
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
31
+ from langchain_huggingface import HuggingFaceEndpoint
32
+
33
+ from utils import extract_pdfs, extract_images, clean_text, image_to_bytes
34
+ from utils import *
35
+
36
+ # ─────────────────────────────────────────────────────────────────────────────
37
+ # Load .env
38
+ load_dotenv()
39
+ HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
40
+
41
+ # OCR + multimodal image description setup
42
+ ocr_model = ocr_predictor(
43
+ "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
44
+ )
45
+ processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
46
+ vision_model = LlavaNextForConditionalGeneration.from_pretrained(
47
+ "llava-hf/llava-v1.6-mistral-7b-hf",
48
+ torch_dtype=torch.float16,
49
+ low_cpu_mem_usage=True
50
+ ).to("cpu")
51
+
52
+
53
+
54
+ def get_image_description(image: Image.Image) -> str:
55
+ """Generate a one-sentence description via LlavaNext."""
56
+ torch.cuda.empty_cache()
57
+ gc.collect()
58
+ prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
59
+ inputs = processor(prompt, image, return_tensors="pt").to("cpu")
60
+ output = vision_model.generate(**inputs, max_new_tokens=100)
61
+ return processor.decode(output[0], skip_special_tokens=True)
62
+
63
+ # Vector DB setup
64
+ # at top of file, alongside your other imports
65
+ from chromadb.utils import embedding_functions
66
+ from chromadb.utils.data_loaders import ImageLoader
67
+ import chromadb
68
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
69
+ from utils import image_to_bytes # your helper
70
+
71
+ # 1) Create one shared embedding function (defaulting to All-MiniLM-L6-v2, 384-dim)
72
+ SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
73
+ model_name="all-MiniLM-L6-v2"
74
+ )
75
+
76
+ def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
77
+ """
78
+ Build an in-memory ChromaDB instance with two collections:
79
+ • text_db (chunks of the PDF text)
80
+ • image_db (image descriptions + raw image bytes)
81
+ Returns the Chroma client for later querying.
82
+ """
83
+ # ——— 1) Init & wipe old ————————————————
84
+ client = chromadb.EphemeralClient()
85
+ for col in ("text_db", "image_db"):
86
+ if col in [c.name for c in client.list_collections()]:
87
+ client.delete_collection(col)
88
+
89
+ # ——— 2) Create fresh collections —————————
90
+ text_col = client.get_or_create_collection(
91
+ name="text_db",
92
+ embedding_function=SHARED_EMB_FN,
93
+ data_loader=ImageLoader(), # loader only matters for images, benign here
94
+ )
95
+ img_col = client.get_or_create_collection(
96
+ name="image_db",
97
+ embedding_function=SHARED_EMB_FN,
98
+ metadata={"hnsw:space": "cosine"},
99
+ data_loader=ImageLoader(),
100
+ )
101
+
102
+ # ——— 3) Add images if any ———————————————
103
+ if images:
104
+ descs = []
105
+ metas = []
106
+ for idx, img in enumerate(images):
107
+ # build one-line caption (or fallback)
108
+ try:
109
+ caption = get_image_description(img)
110
+ except Exception:
111
+ caption = "⚠️ could not describe image"
112
+ descs.append(f"{img_names[idx]}: {caption}")
113
+ metas.append({"image": image_to_bytes(img)})
114
+
115
+ img_col.add(
116
+ ids=[str(i) for i in range(len(images))],
117
+ documents=descs,
118
+ metadatas=metas,
119
+ )
120
+
121
+ # ——— 4) Chunk & add text ———————————————
122
+ splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
123
+ docs = splitter.create_documents([text])
124
+ text_col.add(
125
+ ids=[str(i) for i in range(len(docs))],
126
+ documents=[d.page_content for d in docs],
127
+ )
128
+
129
+ return client
130
+
131
+
132
+
133
+ # Text extraction
134
+ def result_to_text(result, as_text=False):
135
+ pages = []
136
+ for pg in result.pages:
137
+ txt = " ".join(w.value for block in pg.blocks for line in block.lines for w in line.words)
138
+ pages.append(clean_text(txt))
139
+ return "\n\n".join(pages) if as_text else pages
140
+
141
+ OCR_CHOICES = {
142
+ "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
143
+ "db_resnet50 + crnn_resnet31": ("db_resnet50", "crnn_resnet31"),
144
+ }
145
+
146
+ def extract_data_from_pdfs(
147
+ docs,
148
+ session,
149
+ include_images, # "Include Images" or "Exclude Images"
150
+ do_ocr, # "Get Text With OCR" or "Get Available Text Only"
151
+ ocr_choice, # key into OCR_CHOICES
152
+ vlm_choice, # HF repo ID for LlavaNext
153
+ progress=gr.Progress()
154
+ ):
155
+ """
156
+ 1) Dynamically instantiate the chosen OCR pipeline (if any)
157
+ 2) Dynamically instantiate the chosen vision‐language model
158
+ 3) Override the global get_image_description to use that model for captions
159
+ 4) Extract text & images, index into ChromaDB
160
+ """
161
+ if not docs:
162
+ raise gr.Error("No documents to process")
163
+
164
+ # ——— 1) Set up OCR if requested ————————————————
165
+ if do_ocr == "Get Text With OCR":
166
+ db_m, crnn_m = OCR_CHOICES[ocr_choice]
167
+ local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
168
+ else:
169
+ local_ocr = None
170
+
171
+ # ——— 2) Set up vision‐language model —————————————
172
+ proc = LlavaNextProcessor.from_pretrained(vlm_choice)
173
+ vis = LlavaNextForConditionalGeneration.from_pretrained(
174
+ vlm_choice,
175
+ torch_dtype=torch.float16,
176
+ low_cpu_mem_usage=True
177
+ ).to("cpu")
178
+
179
+ # ——— 3) Monkey‐patch global get_image_description ————
180
+ def describe(img: Image.Image) -> str:
181
+ torch.cuda.empty_cache(); gc.collect()
182
+ prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
183
+ inputs = proc(prompt, img, return_tensors="pt").to("cpu")
184
+ output = vis.generate(**inputs, max_new_tokens=100)
185
+ return proc.decode(output[0], skip_special_tokens=True)
186
+
187
+ global get_image_description
188
+ get_image_description = describe
189
+
190
+ # ——— 4) Extract text & images ————————————————
191
+ progress(0.2, "Extracting text and images…")
192
+ all_text, images, names = "", [], []
193
+ for path in docs:
194
+ if local_ocr:
195
+ pdf = DocumentFile.from_pdf(path)
196
+ res = local_ocr(pdf)
197
+ all_text += result_to_text(res, as_text=True) + "\n\n"
198
+ else:
199
+ txt = PdfReader(path).pages[0].extract_text() or ""
200
+ all_text += "\n\n" + txt + "\n\n"
201
+
202
+ if include_images == "Include Images":
203
+ imgs = extract_images([path])
204
+ images.extend(imgs)
205
+ names.extend([os.path.basename(path)] * len(imgs))
206
+
207
+ # ——— 5) Index into vector DB ————————————————
208
+ progress(0.6, "Indexing in vector DB…")
209
+ vdb = get_vectordb(all_text, images, names)
210
+
211
+ session["processed"] = True
212
+ sample_imgs = images[:4] if include_images == "Include Images" else []
213
+ return (
214
+ vdb,
215
+ session,
216
+ gr.Row(visible=True),
217
+ all_text[:2000] + "...",
218
+ sample_imgs,
219
+ "<h3>Done!</h3>"
220
+ )
221
+ # Chat function
222
+ def conversation(
223
+ vdb, question: str, num_ctx, img_ctx,
224
+ history: list, temp: float, max_tok: int, model_id: str
225
+ ):
226
+ # 0) Cast the context sliders to ints
227
+ num_ctx = int(num_ctx)
228
+ img_ctx = int(img_ctx)
229
+
230
+ # 1) Guard: must have extracted first
231
+ if vdb is None:
232
+ raise gr.Error("Please extract data first")
233
+
234
+ # 2) Instantiate the chosen HF endpoint
235
+ llm = HuggingFaceEndpoint(
236
+ repo_id=model_id,
237
+ temperature=temp,
238
+ max_new_tokens=max_tok,
239
+ huggingfacehub_api_token=HF_TOKEN
240
+ )
241
+
242
+ # 3) Query text collection
243
+ text_col = vdb.get_collection("text_db")
244
+ docs = text_col.query(
245
+ query_texts=[question],
246
+ n_results=num_ctx, # now an int
247
+ include=["documents"]
248
+ )["documents"][0]
249
+
250
+ # 4) Query image collection
251
+ img_col = vdb.get_collection("image_db")
252
+ img_q = img_col.query(
253
+ query_texts=[question],
254
+ n_results=img_ctx, # now an int
255
+ include=["metadatas", "documents"]
256
+ )
257
+ # … rest unchanged …
258
+ images, img_descs = [], img_q["documents"][0] or ["No images found"]
259
+ for meta in img_q["metadatas"][0]:
260
+ b64 = meta.get("image", "")
261
+ try:
262
+ images.append(Image.open(io.BytesIO(base64.b64decode(b64))))
263
+ except:
264
+ pass
265
+ img_desc = "\n".join(img_descs)
266
+
267
+ # 5) Build prompt
268
+ prompt = PromptTemplate(
269
+ template="""
270
+ Context:
271
+ {text}
272
+
273
+ Included Images:
274
+ {img_desc}
275
+
276
+ Question:
277
+ {q}
278
+
279
+ Answer:
280
+ """,
281
+ input_variables=["text", "img_desc", "q"],
282
+ )
283
+ context = "\n\n".join(docs)
284
+ user_input = prompt.format(text=context, img_desc=img_desc, q=question)
285
+
286
+ # 6) Call the model with error handling
287
+ try:
288
+ answer = llm.invoke(user_input)
289
+ except HfHubHTTPError as e:
290
+ if e.response.status_code == 404:
291
+ answer = f"❌ Model `{model_id}` not hosted on HF Inference API."
292
+ else:
293
+ answer = f"⚠️ HF API error: {e}"
294
+ except Exception as e:
295
+ answer = f"⚠️ Unexpected error: {e}"
296
+
297
+ # 7) Append to history
298
+ new_history = history + [
299
+ {"role":"user", "content": question},
300
+ {"role":"assistant","content": answer}
301
+ ]
302
+
303
+ # 8) Return updated history, docs, images
304
+ return new_history, docs, images
305
+
306
+
307
+
308
+ # ─────────────────────────────────────────────────────────────────────────────
309
+ # Gradio UI
310
+ CSS = """
311
+ footer {visibility:hidden;}
312
+ """
313
+
314
+ MODEL_OPTIONS = [
315
+ "HuggingFaceH4/zephyr-7b-beta",
316
+ "mistralai/Mistral-7B-Instruct-v0.2",
317
+ "openchat/openchat-3.5-0106",
318
+ "google/gemma-7b-it",
319
+ "deepseek-ai/deepseek-llm-7b-chat",
320
+ "microsoft/Phi-3-mini-4k-instruct",
321
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
322
+ "Qwen/Qwen1.5-7B-Chat",
323
+ "tiiuae/falcon-7b-instruct", # Falcon 7B Instruct
324
+ "bigscience/bloomz-7b1", # BLOOMZ 7B
325
+ "facebook/opt-2.7b",
326
+ ]
327
+
328
+ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
329
+ vdb_state = gr.State()
330
+ session_state = gr.State({})
331
+
332
+ # ─── Welcome Screen ─────────────────────────────────────────────
333
+ with gr.Column(visible=True) as welcome_col:
334
+
335
+ gr.Markdown(
336
+ f"<div style='text-align: center'>\n{WELCOME_INTRO}\n</div>",
337
+ elem_id="welcome_md"
338
+ )
339
+ start_btn = gr.Button("🚀 Start")
340
+
341
+ # ─── Main App (hidden until Start is clicked) ───────────────────
342
+ with gr.Column(visible=False) as app_col:
343
+ gr.Markdown("## 📚 Multimodal Chat-PDF Playground")
344
+
345
+ with gr.Tabs():
346
+ # Tab 1: Upload & Extract
347
+ with gr.TabItem("1. Upload & Extract"):
348
+ docs = gr.File(
349
+ file_count="multiple",
350
+ file_types=[".pdf"],
351
+ label="Upload PDFs"
352
+ )
353
+ include_dd = gr.Radio(
354
+ ["Include Images", "Exclude Images"],
355
+ value="Exclude Images",
356
+ label="Images"
357
+ )
358
+ ocr_dd = gr.Dropdown(
359
+ choices=[
360
+ "db_resnet50 + crnn_mobilenet_v3_large",
361
+ "db_resnet50 + crnn_resnet31"
362
+ ],
363
+ value="db_resnet50 + crnn_mobilenet_v3_large",
364
+ label="OCR Model"
365
+ )
366
+ vlm_dd = gr.Dropdown(
367
+ choices=[
368
+ "llava-hf/llava-v1.6-mistral-7b-hf",
369
+ "llava-hf/llava-v1.5-mistral-7b"
370
+ ],
371
+ value="llava-hf/llava-v1.6-mistral-7b-hf",
372
+ label="Vision-Language Model"
373
+ )
374
+ extract_btn = gr.Button("Extract")
375
+ preview_text = gr.Textbox(lines=10, label="Sample Text", interactive=False)
376
+ preview_img = gr.Gallery(label="Sample Images", rows=2, value=[])
377
+
378
+ extract_btn.click(
379
+ extract_data_from_pdfs,
380
+ inputs=[
381
+ docs,
382
+ session_state,
383
+ include_dd,
384
+ gr.Radio(
385
+ ["Get Text With OCR", "Get Available Text Only"],
386
+ value="Get Available Text Only",
387
+ label="OCR"
388
+ ),
389
+ ocr_dd,
390
+ vlm_dd
391
+ ],
392
+ outputs=[
393
+ vdb_state,
394
+ session_state,
395
+ gr.Row(visible=False),
396
+ preview_text,
397
+ preview_img,
398
+ gr.HTML()
399
+ ]
400
+ )
401
+
402
+ # Tab 2: Chat
403
+ with gr.TabItem("2. Chat"):
404
+ with gr.Row():
405
+ with gr.Column(scale=3):
406
+ chat = gr.Chatbot(type="messages", label="Chat")
407
+ msg = gr.Textbox(
408
+ placeholder="Ask about your PDF...",
409
+ label="Your question"
410
+ )
411
+ send = gr.Button("Send")
412
+ with gr.Column(scale=1):
413
+ model_dd = gr.Dropdown(
414
+ MODEL_OPTIONS,
415
+ value=MODEL_OPTIONS[0],
416
+ label="Choose Chat Model"
417
+ )
418
+ num_ctx = gr.Slider(1,20,value=3,label="Text Contexts")
419
+ img_ctx = gr.Slider(1,10,value=2,label="Image Contexts")
420
+ temp = gr.Slider(0.1,1.0,step=0.1,value=0.4,label="Temperature")
421
+ max_tok = gr.Slider(10,1000,step=10,value=200,label="Max Tokens")
422
+
423
+ send.click(
424
+ conversation,
425
+ inputs=[
426
+ vdb_state,
427
+ msg,
428
+ num_ctx,
429
+ img_ctx,
430
+ chat,
431
+ temp,
432
+ max_tok,
433
+ model_dd
434
+ ],
435
+ outputs=[
436
+ chat,
437
+ gr.Dataframe(),
438
+ gr.Gallery(label="Relevant Images", rows=2, value=[])
439
+ ]
440
+ )
441
+
442
+ # Footer inside app_col
443
+ gr.HTML("<center>Made with ❤️ by Zamal</center>")
444
+
445
+ # ─── Wire the Start button ───────────────────────────────────────
446
+ start_btn.click(
447
+ fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
448
+ inputs=[], outputs=[welcome_col, app_col]
449
+ )
450
+
451
+ if __name__ == "__main__":
452
+ demo.launch()
welcome_text.py ADDED
@@ -0,0 +1,22 @@
+ # welcome_text.py
+
+ WELCOME_INTRO = """
+ # 📄 Welcome to the Multimodal Chat-PDF Playground
+
+ This space is designed to let you **chat with your PDFs** using both their text and images.
+
+ **Key features**
+ - 🔍 **OCR & Native Text** – choose whether to run OCR or use the PDF’s embedded text
+ - 🖼️ **Image Descriptions** – swap between different vision-language models to caption images
+ - 💬 **Chat with Context** – pick from a variety of open-access LLMs (Zephyr, Mistral, Falcon, etc.)
+
+ **Getting Started**
+ 1. Click **Start** below.
+ 2. Upload one or more PDFs.
+ 3. Choose your OCR & vision-language backends.
+ 4. Extract—wait a few seconds.
+ 5. Head over to the **Chat** tab and ask anything!
+
+ > Feel free to experiment: swap OCR models, try different Llava versions, or test Mistral vs. Falcon for chat.
+ > Have fun! 🚀
+ """