zamal committed on
Commit 1e770e5 · verified · 1 Parent(s): 0a3438b

Update app.py

Files changed (1)
  1. app.py +90 -238
app.py CHANGED
@@ -5,273 +5,141 @@ import gc
  from huggingface_hub.utils import HfHubHTTPError
  from langchain_core.prompts import PromptTemplate
  from langchain_huggingface import HuggingFaceEndpoint
- import io, base64
- from PIL import Image
- import torch
- import gradio as gr
- import spaces
- import numpy as np
- import pandas as pd
- import pymupdf
- from PIL import Image
- from pypdf import PdfReader
- from dotenv import load_dotenv
- import shutil
- from chromadb.config import Settings, DEFAULT_TENANT, DEFAULT_DATABASE
- from welcome_text import WELCOME_INTRO
-
  from doctr.io import DocumentFile
  from doctr.models import ocr_predictor
- from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
-
  import chromadb
  from chromadb.utils import embedding_functions
- from chromadb.utils.data_loaders import ImageLoader
-
- from langchain_core.prompts import PromptTemplate
  from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_huggingface import HuggingFaceEndpoint
-
- from utils import extract_pdfs, extract_images, clean_text, image_to_bytes
- from utils import *

  # ─────────────────────────────────────────────────────────────────────────────
- # Load .env
- load_dotenv()
- HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
  processor = None
  vision_model = None
- # hold the in-memory vectordb
- CURRENT_VDB = None

- # OCR + multimodal image description setup
- ocr_model = ocr_predictor(
-     "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
  )
- processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
- vision_model = LlavaNextForConditionalGeneration.from_pretrained(
-     "llava-hf/llava-v1.6-mistral-7b-hf",
-     torch_dtype=torch.float16,
-     low_cpu_mem_usage=True
- ).to("cuda")
-
-
- # Add at the top of your module, alongside your other globals
- PERSIST_DIR = "./chroma_db"
- if os.path.exists(PERSIST_DIR):
-     shutil.rmtree(PERSIST_DIR)
-
- @spaces.GPU()
- def get_image_description(image: Image.Image) -> str:
-     """
-     Lazy-loads the Llava processor + model inside the GPU worker,
-     runs captioning, and returns a one-sentence description.
-     """
-     global processor, vision_model

-     # On first call, instantiate + move to CUDA
      if processor is None or vision_model is None:
-         processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
          vision_model = LlavaNextForConditionalGeneration.from_pretrained(
-             "llava-hf/llava-v1.6-mistral-7b-hf",
-             torch_dtype=torch.float16,
-             low_cpu_mem_usage=True
          ).to("cuda")
-
-     torch.cuda.empty_cache()
-     gc.collect()
-
      prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
-     inputs = processor(prompt, image, return_tensors="pt").to("cuda")
-     output = vision_model.generate(**inputs, max_new_tokens=100)
-     return processor.decode(output[0], skip_special_tokens=True)
-
- # Vector DB setup
- # at top of file, alongside your other imports
- from chromadb.utils import embedding_functions
- from chromadb.utils.data_loaders import ImageLoader
- import chromadb
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from utils import image_to_bytes # your helper
-
- # 1) Create one shared embedding function (defaulting to All-MiniLM-L6-v2, 384-dim)
- SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
-     model_name="all-MiniLM-L6-v2"
- )
-
- def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
-     client = chromadb.EphemeralClient()
-     # wipe old
-     for name in ("text_db", "image_db"):
-         if name in [c.name for c in client.list_collections()]:
-             client.delete_collection(name)
-
-     text_col = client.get_or_create_collection("text_db", embedding_function=SHARED_EMB_FN)
-     img_col = client.get_or_create_collection(
-         "image_db",
-         embedding_function=SHARED_EMB_FN,
-         metadata={"hnsw:space": "cosine"}
-     )
-
-     # add images
-     if images:
-         descs, metas = [], []
-         for i, img in enumerate(images):
-             try:
-                 cap = get_image_description(img)
-             except:
-                 cap = "⚠️ could not describe image"
-             descs.append(f"{img_names[i]}: {cap}")
-             metas.append({"image": image_to_bytes(img)})
-         img_col.add(ids=[str(i) for i in range(len(images))],
-                     documents=descs,
-                     metadatas=metas)
-
-     # chunk + add text
-     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
-     docs = splitter.create_documents([text])
-     text_col.add(ids=[str(i) for i in range(len(docs))],
-                  documents=[d.page_content for d in docs])
-
-     return client
-
-
- # Text extraction
- def result_to_text(result, as_text=False):
-     pages = []
-     for pg in result.pages:
-         txt = " ".join(w.value for block in pg.blocks for line in block.lines for w in line.words)
-         pages.append(clean_text(txt))
-     return "\n\n".join(pages) if as_text else pages
-
- OCR_CHOICES = {
-     "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
-     "db_resnet50 + crnn_resnet31": ("db_resnet50", "crnn_resnet31"),
- }
-
- @spaces.GPU()
  def extract_data_from_pdfs(
-     docs: list[str],
-     session: dict,
-     include_images: str,
-     do_ocr: str,
-     ocr_choice: str,
-     vlm_choice: str,
-     progress=gr.Progress()
  ):
      if not docs:
          raise gr.Error("No documents to process")

-     # 1) OCR pipeline if requested
      if do_ocr == "Get Text With OCR":
          db_m, crnn_m = OCR_CHOICES[ocr_choice]
          local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
-     else:
-         local_ocr = None

-     # 2) Vision–language model
      proc = LlavaNextProcessor.from_pretrained(vlm_choice)
-     vis = (
-         LlavaNextForConditionalGeneration
-         .from_pretrained(vlm_choice, torch_dtype=torch.float16, low_cpu_mem_usage=True)
-         .to("cuda")
-     )

-     # 3) Monkey-patch caption fn
      def describe(img: Image.Image) -> str:
          torch.cuda.empty_cache(); gc.collect()
          prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
          inp = proc(prompt, img, return_tensors="pt").to("cuda")
          out = vis.generate(**inp, max_new_tokens=100)
          return proc.decode(out[0], skip_special_tokens=True)
-
-     global get_image_description
      get_image_description = describe

-     # 4) Extract text & images
      progress(0.2, "Extracting text and images…")
-     all_text = ""
-     images, names = [], []
-     for path in docs:
          if local_ocr:
-             pdf = DocumentFile.from_pdf(path)
              res = local_ocr(pdf)
-             all_text += result_to_text(res, as_text=True) + "\n\n"
          else:
-             all_text += (PdfReader(path).pages[0].extract_text() or "") + "\n\n"

          if include_images == "Include Images":
-             imgs = extract_images([path])
              images.extend(imgs)
-             names.extend([os.path.basename(path)] * len(imgs))

-     # 5) Build the in-memory vector DB once
      progress(0.6, "Indexing in vector DB…")
-     global CURRENT_VDB
-     CURRENT_VDB = get_vectordb(all_text, images, names)
-
-     # 6) Mark session and return UI outputs
-     session["processed"] = True
-     sample = images[:4] if include_images == "Include Images" else []
-     return (
-         session,
-         all_text[:2000] + "...",
-         sample,
-         "<h3>Done!</h3>"
-     )

- # Chat function
- def conversation(
-     session: dict,
-     question: str,
-     num_ctx: int,
-     img_ctx: int,
-     history: list,
-     temp: float,
-     max_tok: int,
-     model_id: str
- ):
-     """
-     Uses the in-memory CURRENT_VDB (set by extract_data_from_pdfs) to answer the user.
-     """
      global CURRENT_VDB
-
-     # 0) Guard: make sure we've extracted at least once
      if not session.get("processed") or CURRENT_VDB is None:
          raise gr.Error("Please extract data first")

-     # 1) Retrieve top-k text chunks
-     text_col = CURRENT_VDB.get_collection("text_db")
-     docs = text_col.query(
-         query_texts=[question],
-         n_results=int(num_ctx),
-         include=["documents"]
-     )["documents"][0]
-
-     # 2) Retrieve top-k images
-     img_col = CURRENT_VDB.get_collection("image_db")
-     img_q = img_col.query(
-         query_texts=[question],
-         n_results=int(img_ctx),
-         include=["metadatas", "documents"]
-     )
      img_descs = img_q["documents"][0] or ["No images found"]
      images = []
-     for meta in img_q["metadatas"][0]:
-         b64 = meta.get("image", "")
-         try:
-             images.append(Image.open(io.BytesIO(base64.b64decode(b64))))
-         except:
-             pass
      img_desc = "\n".join(img_descs)

-     # 3) Build the prompt
      prompt = PromptTemplate(
          template="""
  Context:
@@ -284,39 +152,23 @@ Question:
  {q}

  Answer:
- """,
-         input_variables=["text", "img_desc", "q"],
-     )
-     user_input = prompt.format(
-         text="\n\n".join(docs),
-         img_desc=img_desc,
-         q=question
-     )

-     # 4) Call the LLM
      llm = HuggingFaceEndpoint(
-         repo_id=model_id,
-         task="text-generation",
-         temperature=temp,
-         max_new_tokens=max_tok,
-         # the client will pick up HUGGINGFACEHUB_API_TOKEN from env automatically
      )
-     try:
-         answer = llm.invoke(user_input)
      except HfHubHTTPError as e:
-         if e.response.status_code == 404:
-             answer = f"❌ Model `{model_id}` not hosted on HF Inference API."
-         else:
-             answer = f"⚠️ HF API error: {e}"
      except Exception as e:
-         answer = f"⚠️ Unexpected error: {e}"
-
-     # 5) Append to chat history and return
-     new_history = history + [
-         {"role": "user", "content": question},
-         {"role": "assistant", "content": answer}
-     ]
-     return new_history, docs, images
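Both the removed get_vectordb helper above and the inlined indexing code in the new version below build the same structure: one shared SentenceTransformer embedding function, a text_db and an image_db collection on an in-memory EphemeralClient, and query_texts-based retrieval. A minimal standalone sketch of that pattern, assuming chromadb and sentence-transformers are installed; the sample documents and the query string are illustrative only:

    import chromadb
    from chromadb.utils import embedding_functions

    # Shared embedding function (all-MiniLM-L6-v2), as in app.py
    emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="all-MiniLM-L6-v2"
    )

    # In-memory client: nothing is persisted to disk
    client = chromadb.EphemeralClient()
    text_col = client.get_or_create_collection("text_db", embedding_function=emb_fn)
    img_col = client.get_or_create_collection(
        "image_db", embedding_function=emb_fn, metadata={"hnsw:space": "cosine"}
    )

    # Index a few illustrative text chunks and image captions
    text_col.add(ids=["0", "1"], documents=["First text chunk.", "Second text chunk."])
    img_col.add(ids=["0"], documents=["page1.png: a bar chart of monthly sales"])

    # Retrieval mirrors conversation(): query_texts + n_results + include
    hits = text_col.query(query_texts=["monthly sales"], n_results=2, include=["documents"])
    print(hits["documents"][0])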
 
  from huggingface_hub.utils import HfHubHTTPError
  from langchain_core.prompts import PromptTemplate
  from langchain_huggingface import HuggingFaceEndpoint
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
  from doctr.io import DocumentFile
  from doctr.models import ocr_predictor
+ from pypdf import PdfReader
+ from PIL import Image
  import chromadb
  from chromadb.utils import embedding_functions
  from langchain.text_splitter import RecursiveCharacterTextSplitter
+ import gradio as gr

  # ─────────────────────────────────────────────────────────────────────────────
+ # Globals
+ CURRENT_VDB = None
  processor = None
  vision_model = None

+ # OCR & V+L defaults
+ OCR_CHOICES = {
+     "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
+     "db_resnet50 + crnn_resnet31": ("db_resnet50", "crnn_resnet31"),
+ }
+ SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
+     model_name="all-MiniLM-L6-v2"
  )

+ def get_image_description(img: Image.Image) -> str:
+     global processor, vision_model
      if processor is None or vision_model is None:
+         # use the same default V+L model everywhere
+         vlm = "llava-hf/llava-v1.6-mistral-7b-hf"
+         processor = LlavaNextProcessor.from_pretrained(vlm)
          vision_model = LlavaNextForConditionalGeneration.from_pretrained(
+             vlm, torch_dtype=torch.float16, low_cpu_mem_usage=True
          ).to("cuda")
+     torch.cuda.empty_cache(); gc.collect()
      prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
+     inputs = processor(prompt, img, return_tensors="pt").to("cuda")
+     out = vision_model.generate(**inputs, max_new_tokens=100)
+     return processor.decode(out[0], skip_special_tokens=True)

  def extract_data_from_pdfs(
+     docs, session, include_images, do_ocr, ocr_choice, vlm_choice, progress=gr.Progress()
  ):
      if not docs:
          raise gr.Error("No documents to process")

+     # 1) Optional OCR
+     local_ocr = None
      if do_ocr == "Get Text With OCR":
          db_m, crnn_m = OCR_CHOICES[ocr_choice]
          local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)

+     # 2) Prepare V+L
      proc = LlavaNextProcessor.from_pretrained(vlm_choice)
+     vis = LlavaNextForConditionalGeneration.from_pretrained(
+         vlm_choice, torch_dtype=torch.float16, low_cpu_mem_usage=True
+     ).to("cuda")

+     # 3) Patch get_image_description to use this choice
      def describe(img: Image.Image) -> str:
          torch.cuda.empty_cache(); gc.collect()
          prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
          inp = proc(prompt, img, return_tensors="pt").to("cuda")
          out = vis.generate(**inp, max_new_tokens=100)
          return proc.decode(out[0], skip_special_tokens=True)
+     global get_image_description, CURRENT_VDB
      get_image_description = describe

+     # 4) Pull text + images
      progress(0.2, "Extracting text and images…")
+     full_text, images, names = "", [], []
+     for p in docs:
          if local_ocr:
+             pdf = DocumentFile.from_pdf(p)
              res = local_ocr(pdf)
+             full_text += " ".join(w.value for blk in res.pages for line in blk.lines for w in line.words) + "\n\n"
          else:
+             full_text += (PdfReader(p).pages[0].extract_text() or "") + "\n\n"

          if include_images == "Include Images":
+             imgs = extract_images([p])
              images.extend(imgs)
+             names.extend([os.path.basename(p)] * len(imgs))

+     # 5) Build in-memory Chroma
      progress(0.6, "Indexing in vector DB…")
+     client = chromadb.EphemeralClient()
+     for col in ("text_db", "image_db"):
+         if col in [c.name for c in client.list_collections()]:
+             client.delete_collection(col)
+     text_col = client.get_or_create_collection("text_db", embedding_function=SHARED_EMB_FN)
+     img_col = client.get_or_create_collection("image_db", embedding_function=SHARED_EMB_FN,
+                                               metadata={"hnsw:space":"cosine"})

+     if images:
+         descs, metas = [], []
+         for i, im in enumerate(images):
+             cap = get_image_description(im)
+             descs.append(f"{names[i]}: {cap}")
+             metas.append({"image": image_to_bytes(im)})
+         img_col.add(ids=[str(i) for i in range(len(images))],
+                     documents=descs, metadatas=metas)

+     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+     docs_ = splitter.create_documents([full_text])
+     text_col.add(ids=[str(i) for i in range(len(docs_))],
+                  documents=[d.page_content for d in docs_])

+     CURRENT_VDB = client
+     session["processed"] = True
+     sample = images[:4] if include_images=="Include Images" else []
+     return session, full_text[:2000]+"...", sample, "<h3>Done!</h3>"

+ def conversation(session, question, num_ctx, img_ctx, history, temp, max_tok, model_id):
      global CURRENT_VDB
      if not session.get("processed") or CURRENT_VDB is None:
          raise gr.Error("Please extract data first")

+     # a) text retrieval
+     docs = CURRENT_VDB.get_collection("text_db")\
+         .query(query_texts=[question], n_results=int(num_ctx), include=["documents"])["documents"][0]
+
+     # b) image retrieval
+     img_q = CURRENT_VDB.get_collection("image_db")\
+         .query(query_texts=[question], n_results=int(img_ctx),
+                include=["metadatas","documents"])
      img_descs = img_q["documents"][0] or ["No images found"]
      images = []
+     for m in img_q["metadatas"][0]:
+         b = m.get("image","")
+         try: images.append(Image.open(io.BytesIO(base64.b64decode(b))))
+         except: pass
      img_desc = "\n".join(img_descs)

+     # c) prompt & LLM
      prompt = PromptTemplate(
          template="""
  Context:

  {q}

  Answer:
+ """, input_variables=["text","img_desc","q"])
+     inp = prompt.format(text="\n\n".join(docs), img_desc=img_desc, q=question)

      llm = HuggingFaceEndpoint(
+         repo_id=model_id, task="text-generation",
+         temperature=temp, max_new_tokens=max_tok,
+         huggingfacehub_api_token=HF_TOKEN
      )
+     try: ans = llm.invoke(inp)
      except HfHubHTTPError as e:
+         ans = f"❌ Model `{model_id}` not hosted." if e.response.status_code==404 else f"⚠️ HF API error: {e}"
      except Exception as e:
+         ans = f"⚠️ Unexpected error: {e}"
+
+     new_hist = history + [{"role":"user","content":question},
+                           {"role":"assistant","content":ans}]
+     return new_hist, docs, images
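The updated conversation() passes huggingfacehub_api_token=HF_TOKEN to HuggingFaceEndpoint, so HF_TOKEN still has to be defined above the lines shown in this diff; the previous revision did that with load_dotenv. A minimal sketch of that setup with the same HuggingFaceEndpoint call shape, assuming python-dotenv and langchain_huggingface are installed; the repo_id and prompt below are illustrative, not part of the commit:

    import os
    from dotenv import load_dotenv
    from langchain_huggingface import HuggingFaceEndpoint

    # Mirrors the removed .env block: expects HUGGINGFACE_TOKEN in .env or the environment
    load_dotenv()
    HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

    # Same argument shape as the llm = HuggingFaceEndpoint(...) call in conversation()
    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",  # illustrative model id
        task="text-generation",
        temperature=0.1,
        max_new_tokens=256,
        huggingfacehub_api_token=HF_TOKEN,
    )
    print(llm.invoke("Answer in one sentence: what is OCR?"))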