zamal committed
Commit cd8c42c · verified · 1 Parent(s): 5580220

Update app.py

Files changed (1)
  1. app.py +44 -22
app.py CHANGED
@@ -37,7 +37,8 @@ from utils import *
 # Load .env
 load_dotenv()
 HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
-
+processor = None
+vision_model = None
 # OCR + multimodal image description setup
 ocr_model = ocr_predictor(
     "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
@@ -52,9 +53,20 @@ vision_model = LlavaNextForConditionalGeneration.from_pretrained(
 
 @spaces.GPU()
 def get_image_description(image: Image.Image) -> str:
-    """Generate a one-sentence description via LlavaNext."""
+    global processor, vision_model
+
+    # on first call, load & move to cuda
+    if processor is None or vision_model is None:
+        processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+        vision_model = LlavaNextForConditionalGeneration.from_pretrained(
+            "llava-hf/llava-v1.6-mistral-7b-hf",
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True
+        ).to("cuda")
+
     torch.cuda.empty_cache()
     gc.collect()
+
     prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
     inputs = processor(prompt, image, return_tensors="pt").to("cuda")
     output = vision_model.generate(**inputs, max_new_tokens=100)
@@ -143,42 +155,45 @@ OCR_CHOICES = {
     "db_resnet50 + crnn_resnet31": ("db_resnet50", "crnn_resnet31"),
 }
 
+@spaces.GPU()
 def extract_data_from_pdfs(
-    docs,
-    session,
-    include_images,        # "Include Images" or "Exclude Images"
-    do_ocr,                # "Get Text With OCR" or "Get Available Text Only"
-    ocr_choice,            # key into OCR_CHOICES
-    vlm_choice,            # HF repo ID for LlavaNext
+    docs: list[str],
+    session: dict,
+    include_images: str,   # "Include Images" or "Exclude Images"
+    do_ocr: str,           # "Get Text With OCR" or "Get Available Text Only"
+    ocr_choice: str,       # key into OCR_CHOICES
+    vlm_choice: str,       # HF repo ID for LlavaNext
     progress=gr.Progress()
 ):
     """
     1) Dynamically instantiate the chosen OCR pipeline (if any)
     2) Dynamically instantiate the chosen vision‐language model
-    3) Override the global get_image_description to use that model for captions
+    3) Monkey‐patch get_image_description to use that VL model
     4) Extract text & images, index into ChromaDB
     """
     if not docs:
         raise gr.Error("No documents to process")
 
-    # ——— 1) Set up OCR if requested ————————————————
+    # ——— 1) OCR setup (if requested) —————————————————————
     if do_ocr == "Get Text With OCR":
         db_m, crnn_m = OCR_CHOICES[ocr_choice]
         local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
     else:
         local_ocr = None
 
-    # ——— 2) Set up vision‐language model —————————————
+    # ——— 2) Vision‐language model setup ——————————————————
+    # Load processor + model *inside* the GPU worker
     proc = LlavaNextProcessor.from_pretrained(vlm_choice)
-    vis = LlavaNextForConditionalGeneration.from_pretrained(
-        vlm_choice,
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True
-    ).to("cuda")
+    vis = (
+        LlavaNextForConditionalGeneration
+        .from_pretrained(vlm_choice, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+        .to("cuda")
+    )
 
-    # ——— 3) Monkey‐patch global get_image_description ————
+    # ——— 3) Monkey‐patch get_image_description —————————————————
     def describe(img: Image.Image) -> str:
-        torch.cuda.empty_cache(); gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
         prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
         inputs = proc(prompt, img, return_tensors="pt").to("cuda")
         output = vis.generate(**inputs, max_new_tokens=100)
@@ -187,29 +202,35 @@ def extract_data_from_pdfs(
     global get_image_description
     get_image_description = describe
 
-    # ——— 4) Extract text & images ————————————————
+    # ——— 4) Extract text & images —————————————————————
     progress(0.2, "Extracting text and images…")
-    all_text, images, names = "", [], []
+    all_text = ""
+    images, names = [], []
+
     for path in docs:
+        # text extraction
        if local_ocr:
             pdf = DocumentFile.from_pdf(path)
             res = local_ocr(pdf)
             all_text += result_to_text(res, as_text=True) + "\n\n"
         else:
             txt = PdfReader(path).pages[0].extract_text() or ""
-            all_text += "\n\n" + txt + "\n\n"
+            all_text += txt + "\n\n"
 
+        # image extraction
         if include_images == "Include Images":
             imgs = extract_images([path])
             images.extend(imgs)
             names.extend([os.path.basename(path)] * len(imgs))
 
-    # ——— 5) Index into vector DB ————————————————
+    # ——— 5) Index into ChromaDB —————————————————————
     progress(0.6, "Indexing in vector DB…")
     vdb = get_vectordb(all_text, images, names)
 
+    # mark session done & prepare outputs
     session["processed"] = True
     sample_imgs = images[:4] if include_images == "Include Images" else []
+
     return (
         vdb,
         session,
@@ -218,6 +239,7 @@ def extract_data_from_pdfs(
         sample_imgs,
         "<h3>Done!</h3>"
     )
+
 # Chat function
 def conversation(
     vdb, question: str, num_ctx, img_ctx,
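
Taken together, the hunks apply one ZeroGPU-friendly pattern: heavy models are no longer built at import time but are created lazily inside functions decorated with @spaces.GPU(), so only the GPU worker ever touches CUDA. Below is a minimal sketch of that pattern in isolation. It mirrors names from the diff; VLM_ID is just a local alias for the repo ID used above, and the final decode/return line is an assumption based on standard LlavaNext usage, since the diff is truncated before the function's return.

import gc

import spaces
import torch
from PIL import Image
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

VLM_ID = "llava-hf/llava-v1.6-mistral-7b-hf"  # repo ID from the diff

# Module-level placeholders: nothing touches CUDA at import time.
processor = None
vision_model = None


@spaces.GPU()
def get_image_description(image: Image.Image) -> str:
    """Caption one image, loading the model lazily inside the GPU worker."""
    global processor, vision_model
    if processor is None or vision_model is None:
        # First call only: load the processor and move the fp16 model to the GPU.
        processor = LlavaNextProcessor.from_pretrained(VLM_ID)
        vision_model = LlavaNextForConditionalGeneration.from_pretrained(
            VLM_ID, torch_dtype=torch.float16, low_cpu_mem_usage=True
        ).to("cuda")
    torch.cuda.empty_cache()
    gc.collect()
    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    inputs = processor(prompt, image, return_tensors="pt").to("cuda")
    output = vision_model.generate(**inputs, max_new_tokens=100)
    # Assumed ending: decode the generated tokens to plain text.
    return processor.decode(output[0], skip_special_tokens=True)

Keeping from_pretrained out of module scope, together with the processor = None / vision_model = None placeholders and the @spaces.GPU() decorator on extract_data_from_pdfs, presumably keeps the Space's main process CPU-only and defers all CUDA work to on-demand GPU workers.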