prithivMLmods committed
Commit 96119c1 · verified · 1 Parent(s): 9ef55f2

Update app.py

Files changed (1):
  1. app.py +253 -156
app.py CHANGED
@@ -54,7 +54,6 @@ model_k = VisionEncoderDecoderModel.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-#------------------------------------------------#
 # Load SmolDocling-256M-preview
 MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
 processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
@@ -63,7 +62,6 @@ model_x = AutoModelForVision2Seq.from_pretrained(
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to(device).eval()
-#------------------------------------------------#
 
 # Load MonkeyOCR
 MODEL_ID_G = "echo840/MonkeyOCR"
@@ -126,6 +124,104 @@ def downsample_video(video_path):
     vidcap.release()
     return frames
 
+# Dolphin-specific functions
+def model_chat(prompt, image):
+    """Use Dolphin model for inference."""
+    processor = processor_k
+    model = model_k
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    inputs = processor(image, return_tensors="pt").to(device)
+    pixel_values = inputs.pixel_values.half()
+    prompt_inputs = processor.tokenizer(
+        f"<s>{prompt} <Answer/>",
+        add_special_tokens=False,
+        return_tensors="pt"
+    ).to(device)
+    outputs = model.generate(
+        pixel_values=pixel_values,
+        decoder_input_ids=prompt_inputs.input_ids,
+        decoder_attention_mask=prompt_inputs.attention_mask,
+        min_length=1,
+        max_length=4096,
+        pad_token_id=processor.tokenizer.pad_token_id,
+        eos_token_id=processor.tokenizer.eos_token_id,
+        use_cache=True,
+        bad_words_ids=[[processor.tokenizer.unk_token_id]],
+        return_dict_in_generate=True,
+        do_sample=False,
+        num_beams=1,
+        repetition_penalty=1.1
+    )
+    sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
+    cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
+    return cleaned
+
+def process_elements(layout_results, image):
+    """Parse layout results and extract elements from the image."""
+    # Placeholder parsing logic based on expected Dolphin output,
+    # assuming layout_results is a string like "[((x1, y1, x2, y2), label), ...]"
+    try:
+        elements = ast.literal_eval(layout_results)
+    except:
+        elements = []  # Fallback if parsing fails
+
+    recognition_results = []
+    reading_order = 0
+
+    for bbox, label in elements:
+        try:
+            x1, y1, x2, y2 = map(int, bbox)
+            cropped = image.crop((x1, y1, x2, y2))
+            if cropped.size[0] > 0 and cropped.size[1] > 0:
+                if label == "text":
+                    text = model_chat("Read text in the image.", cropped)
+                    recognition_results.append({
+                        "label": label,
+                        "bbox": [x1, y1, x2, y2],
+                        "text": text.strip(),
+                        "reading_order": reading_order
+                    })
+                elif label == "table":
+                    table_text = model_chat("Parse the table in the image.", cropped)
+                    recognition_results.append({
+                        "label": label,
+                        "bbox": [x1, y1, x2, y2],
+                        "text": table_text.strip(),
+                        "reading_order": reading_order
+                    })
+                elif label == "figure":
+                    recognition_results.append({
+                        "label": label,
+                        "bbox": [x1, y1, x2, y2],
+                        "text": "[Figure]",  # Placeholder for figure content
+                        "reading_order": reading_order
+                    })
+                reading_order += 1
+        except Exception as e:
+            print(f"Error processing element: {e}")
+            continue
+
+    return recognition_results
+
+def generate_markdown(recognition_results):
+    """Generate markdown from extracted elements."""
+    markdown = ""
+    for element in sorted(recognition_results, key=lambda x: x["reading_order"]):
+        if element["label"] == "text":
+            markdown += f"{element['text']}\n\n"
+        elif element["label"] == "table":
+            markdown += f"**Table:**\n{element['text']}\n\n"
+        elif element["label"] == "figure":
+            markdown += f"{element['text']}\n\n"
+    return markdown.strip()
+
+def process_image_with_dolphin(image):
+    """Process a single image with Dolphin model."""
+    layout_output = model_chat("Parse the reading order of this document.", image)
+    elements = process_elements(layout_output, image)
+    markdown_content = generate_markdown(elements)
+    return markdown_content
+
 @spaces.GPU
 def generate_image(model_name: str, text: str, image: Image.Image,
                    max_new_tokens: int = 1024,
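
Note: the Dolphin path added above is a two-stage pipeline: one model_chat call asks the model for the page layout and reading order, then each detected element is cropped and recognized individually before being stitched into markdown. A minimal driver sketch, assuming the helpers above (and the app's model_k/processor_k) are in scope; "sample_page.png" is a hypothetical placeholder path:

# Hypothetical driver for the Dolphin helpers above; not part of the commit.
from PIL import Image

page = Image.open("sample_page.png").convert("RGB")  # placeholder input

# Stage 1: ask Dolphin for the document layout / reading order.
layout = model_chat("Parse the reading order of this document.", page)

# Stage 2: crop and recognize each element, then emit markdown in order.
elements = process_elements(layout, page)
print(generate_markdown(elements))  # same result as process_image_with_dolphin(page)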
@@ -134,84 +230,82 @@ def generate_image(model_name: str, text: str, image: Image.Image,
                    top_k: int = 50,
                    repetition_penalty: float = 1.2):
     """Generate responses for image input using the selected model."""
-    # Model selection
-    if model_name == "Nanonets-OCR-s":
-        processor = processor_m
-        model = model_m
-    elif model_name == "MonkeyOCR-Recognition":
-        processor = processor_g
-        model = model_g
-    elif model_name == "SmolDocling-256M-preview":
-        processor = processor_x
-        model = model_x
-    elif model_name == "ByteDance-s-Dolphin":
-        processor = processor_k
-        model = model_k
+    if model_name == "ByteDance-s-Dolphin":
+        if image is None:
+            yield "Please upload an image."
+            return
+        markdown_content = process_image_with_dolphin(image)
+        yield markdown_content
     else:
-        yield "Invalid model selected."
-        return
-
-    if image is None:
-        yield "Please upload an image."
-        return
-
-    # Prepare images as a list (single image for image inference)
-    images = [image]
-
-    # SmolDocling-256M specific preprocessing
-    if model_name == "SmolDocling-256M-preview":
-        if "OTSL" in text or "code" in text:
-            images = [add_random_padding(img) for img in images]
-        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
-            text = normalize_values(text, target_max=500)
-
-    # Unified message structure for all models
-    messages = [
-        {
-            "role": "user",
-            "content": [{"type": "image"} for _ in images] + [
-                {"type": "text", "text": text}
-            ]
-        }
-    ]
-    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
-
-    # Generation with streaming
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {
-        **inputs,
-        "streamer": streamer,
-        "max_new_tokens": max_new_tokens,
-        "temperature": temperature,
-        "top_p": top_p,
-        "top_k": top_k,
-        "repetition_penalty": repetition_penalty,
-    }
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    # Stream output and collect full response
-    buffer = ""
-    full_output = ""
-    for new_text in streamer:
-        full_output += new_text
-        buffer += new_text.replace("<|im_end|>", "")
-        yield buffer
-
-    # SmolDocling-256M specific postprocessing
-    if model_name == "SmolDocling-256M-preview":
-        cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
-        if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
-            if "<chart>" in cleaned_output:
-                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
-            cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
-            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
-            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
-            markdown_output = doc.export_to_markdown()
-            yield f"**MD Output:**\n\n{markdown_output}"
+        # Existing logic for other models
+        if model_name == "Nanonets-OCR-s":
+            processor = processor_m
+            model = model_m
+        elif model_name == "MonkeyOCR-Recognition":
+            processor = processor_g
+            model = model_g
+        elif model_name == "SmolDocling-256M-preview":
+            processor = processor_x
+            model = model_x
         else:
-            yield cleaned_output
+            yield "Invalid model selected."
+            return
+
+        if image is None:
+            yield "Please upload an image."
+            return
+
+        images = [image]
+
+        if model_name == "SmolDocling-256M-preview":
+            if "OTSL" in text or "code" in text:
+                images = [add_random_padding(img) for img in images]
+            if "OCR at text at" in text or "Identify element" in text or "formula" in text:
+                text = normalize_values(text, target_max=500)
+
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "image"} for _ in images] + [
+                    {"type": "text", "text": text}
+                ]
+            }
+        ]
+        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+        inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        buffer = ""
+        full_output = ""
+        for new_text in streamer:
+            full_output += new_text
+            buffer += new_text.replace("<|im_end|>", "")
+            yield buffer
+
+        if model_name == "SmolDocling-256M-preview":
+            cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
+            if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
+                if "<chart>" in cleaned_output:
+                    cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
+                cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
+                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
+                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
+                markdown_output = doc.export_to_markdown()
+                yield f"**MD Output:**\n\n{markdown_output}"
+            else:
+                yield cleaned_output
 
 @spaces.GPU
 def generate_video(model_name: str, text: str, video_path: str,
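
Note: the non-Dolphin branch keeps the threaded streaming pattern: model.generate blocks, so it runs on a worker thread while transformers' TextIteratorStreamer hands decoded fragments back to the generator, which yields a growing buffer to Gradio. A self-contained sketch of the same pattern; gpt2 is a stand-in checkpoint here, not one of the app's models:

# Stand-alone sketch of the threaded streaming used in generate_image/_video.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")   # stand-in checkpoint
lm = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("OCR is", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a worker thread; the streamer is a
# thread-safe iterator over decoded text fragments.
Thread(target=lm.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 32}).start()

buffer = ""
for fragment in streamer:
    buffer += fragment  # app.py yields this growing buffer to Gradio
print(buffer)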
@@ -221,85 +315,88 @@ def generate_video(model_name: str, text: str, video_path: str,
                    top_k: int = 50,
                    repetition_penalty: float = 1.2):
     """Generate responses for video input using the selected model."""
-    # Model selection
-    if model_name == "Nanonets-OCR-s":
-        processor = processor_m
-        model = model_m
-    elif model_name == "MonkeyOCR-Recognition":
-        processor = processor_g
-        model = model_g
-    elif model_name == "SmolDocling-256M-preview":
-        processor = processor_x
-        model = model_x
-    elif model_name == "ByteDance-s-Dolphin":
-        processor = processor_k
-        model = model_k
+    if model_name == "ByteDance-s-Dolphin":
+        if video_path is None:
+            yield "Please upload a video."
+            return
+        frames = downsample_video(video_path)
+        markdown_contents = []
+        for frame, _ in frames:
+            markdown_content = process_image_with_dolphin(frame)
+            markdown_contents.append(markdown_content)
+        combined_markdown = "\n\n".join(markdown_contents)
+        yield combined_markdown
     else:
-        yield "Invalid model selected."
-        return
-
-    if video_path is None:
-        yield "Please upload a video."
-        return
-
-    # Extract frames from video
-    frames = downsample_video(video_path)
-    images = [frame for frame, _ in frames]
-
-    # SmolDocling-256M specific preprocessing
-    if model_name == "SmolDocling-256M-preview":
-        if "OTSL" in text or "code" in text:
-            images = [add_random_padding(img) for img in images]
-        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
-            text = normalize_values(text, target_max=500)
-
-    # Unified message structure for all models
-    messages = [
-        {
-            "role": "user",
-            "content": [{"type": "image"} for _ in images] + [
-                {"type": "text", "text": text}
-            ]
-        }
-    ]
-    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
-
-    # Generation with streaming
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {
-        **inputs,
-        "streamer": streamer,
-        "max_new_tokens": max_new_tokens,
-        "temperature": temperature,
-        "top_p": top_p,
-        "top_k": top_k,
-        "repetition_penalty": repetition_penalty,
-    }
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    # Stream output and collect full response
-    buffer = ""
-    full_output = ""
-    for new_text in streamer:
-        full_output += new_text
-        buffer += new_text.replace("<|im_end|>", "")
-        yield buffer
-
-    # SmolDocling-256M specific postprocessing
-    if model_name == "SmolDocling-256M-preview":
-        cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
-        if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
-            if "<chart>" in cleaned_output:
-                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
-            cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
-            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
-            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
-            markdown_output = doc.export_to_markdown()
-            yield f"**MD Output:**\n\n{markdown_output}"
+        # Existing logic for other models
+        if model_name == "Nanonets-OCR-s":
+            processor = processor_m
+            model = model_m
+        elif model_name == "MonkeyOCR-Recognition":
+            processor = processor_g
+            model = model_g
+        elif model_name == "SmolDocling-256M-preview":
+            processor = processor_x
+            model = model_x
         else:
-            yield cleaned_output
+            yield "Invalid model selected."
+            return
+
+        if video_path is None:
+            yield "Please upload a video."
+            return
+
+        frames = downsample_video(video_path)
+        images = [frame for frame, _ in frames]
+
+        if model_name == "SmolDocling-256M-preview":
+            if "OTSL" in text or "code" in text:
+                images = [add_random_padding(img) for img in images]
+            if "OCR at text at" in text or "Identify element" in text or "formula" in text:
+                text = normalize_values(text, target_max=500)
+
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "image"} for _ in images] + [
+                    {"type": "text", "text": text}
+                ]
+            }
+        ]
+        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+        inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        buffer = ""
+        full_output = ""
+        for new_text in streamer:
+            full_output += new_text
+            buffer += new_text.replace("<|im_end|>", "")
+            yield buffer
+
+        if model_name == "SmolDocling-256M-preview":
+            cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
+            if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
+                if "<chart>" in cleaned_output:
+                    cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
+                cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
+                doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
+                doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
+                markdown_output = doc.export_to_markdown()
+                yield f"**MD Output:**\n\n{markdown_output}"
+            else:
+                yield cleaned_output
 
 # Define examples for image and video inference
 image_examples = [
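
Note: both generators finish with the same SmolDocling postprocessing: the DocTags markup is paired with the source images and converted to markdown through docling-core, via the DocTagsDocument / DoclingDocument calls shown above. A minimal sketch of just that conversion, using a toy DocTags string rather than real model output (import paths follow docling-core's published SmolDocling example):

# Sketch of the DocTags -> Markdown step; the doctags string is a toy example.
from docling_core.types.doc import DoclingDocument
from docling_core.types.doc.document import DocTagsDocument
from PIL import Image

doctags = "<doctag><text><loc_10><loc_10><loc_490><loc_30>Hello OCR</text></doctag>"
page = Image.new("RGB", (500, 500), "white")  # stand-in for the uploaded image

doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [page])
doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
print(doc.export_to_markdown())  # expected: "Hello OCR"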
@@ -325,7 +422,7 @@ css = """
 
 # Create the Gradio Interface
 with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-    gr.Markdown("# **[OCRNet 4x 🤗](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
+    gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
     with gr.Row():
         with gr.Column():
             with gr.Tabs():
 