banao-tech committed
Commit 99cd6de · verified · 1 Parent(s): 00bda1b

Update utils.py

Files changed (1)
  1. utils.py +242 -331
utils.py CHANGED
@@ -1,58 +1,49 @@
1
- """
2
- utils.py
3
-
4
- This module contains utility functions for:
5
- - Loading and processing images
6
- - Object detection with YOLO
7
- - OCR with EasyOCR / PaddleOCR
8
- - Image annotation and bounding box manipulation
9
- - Captioning / semantic parsing of detected icons
10
- """
11
-
12
  import os
13
  import io
14
  import base64
15
  import time
 
16
  import json
17
- import sys
18
- import re
19
- from typing import Tuple, List
20
 
21
- import torch
22
- import numpy as np
 
 
23
  import cv2
24
- from PIL import Image, ImageDraw, ImageFont
 
25
  from matplotlib import pyplot as plt
26
-
27
  import easyocr
28
  from paddleocr import PaddleOCR
29
- import supervision as sv
30
- import torchvision.transforms as T
31
- from torchvision.transforms import ToPILImage
32
- from torchvision.ops import box_convert
33
-
34
- # Optional: import AzureOpenAI if used
35
- from openai import AzureOpenAI
36
-
37
- # Initialize OCR readers
38
  reader = easyocr.Reader(['en'])
39
  paddle_ocr = PaddleOCR(
40
- lang='en', # other languages available
41
  use_angle_cls=False,
42
- use_gpu=False, # using cuda might conflict with PyTorch in the same process
43
  show_log=False,
44
  max_batch_size=1024,
45
  use_dilation=True, # improves accuracy
46
  det_db_score_mode='slow', # improves accuracy
47
- rec_batch_num=1024
48
- )
49
 
50
 
51
  def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
52
- """
53
- Loads the captioning model and processor.
54
- Supports either BLIP2 or Florence-2 models.
55
- """
56
  if not device:
57
  device = "cuda" if torch.cuda.is_available() else "cpu"
58
  if model_name == "blip2":
@@ -60,53 +51,45 @@ def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2
60
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
61
  if device == 'cpu':
62
  model = Blip2ForConditionalGeneration.from_pretrained(
63
- model_name_or_path, device_map=None, torch_dtype=torch.float32
64
- )
65
  else:
66
  model = Blip2ForConditionalGeneration.from_pretrained(
67
- model_name_or_path, device_map=None, torch_dtype=torch.float16
68
- ).to(device)
69
  elif model_name == "florence2":
70
- from transformers import AutoProcessor, AutoModelForCausalLM
71
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
72
  if device == 'cpu':
73
- model = AutoModelForCausalLM.from_pretrained(
74
- model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True
75
- )
76
  else:
77
- model = AutoModelForCausalLM.from_pretrained(
78
- model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True
79
- ).to(device)
80
  return {'model': model.to(device), 'processor': processor}
81
 
82
 
83
  def get_yolo_model(model_path):
84
- """
85
- Loads a YOLO model from a given model_path using ultralytics.
86
- """
87
  from ultralytics import YOLO
 
88
  model = YOLO(model_path)
89
  return model
90
 
91
 
92
  @torch.inference_mode()
93
  def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=32):
94
- # Ensure batch_size is an integer
95
- if batch_size is None:
96
- batch_size = 32
97
 
98
  to_pil = ToPILImage()
99
  if starting_idx:
100
  non_ocr_boxes = filtered_boxes[starting_idx:]
101
  else:
102
  non_ocr_boxes = filtered_boxes
103
- cropped_pil_images = []
104
- for coord in non_ocr_boxes:
105
- xmin, xmax = int(coord[0] * image_source.shape[1]), int(coord[2] * image_source.shape[1])
106
- ymin, ymax = int(coord[1] * image_source.shape[0]), int(coord[3] * image_source.shape[0])
107
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
108
- cropped_pil_images.append(to_pil(cropped_image))
109
-
110
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
111
  if not prompt:
112
  if 'florence' in model.config.name_or_path:
@@ -116,29 +99,17 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
116
 
117
  generated_texts = []
118
  device = model.device
119
- for i in range(0, len(cropped_pil_images), batch_size):
120
- batch = cropped_pil_images[i:i + batch_size]
 
121
  if model.device.type == 'cuda':
122
- inputs = processor(images=batch, text=[prompt] * len(batch), return_tensors="pt").to(device=device, dtype=torch.float16)
123
  else:
124
- inputs = processor(images=batch, text=[prompt] * len(batch), return_tensors="pt").to(device=device)
125
  if 'florence' in model.config.name_or_path:
126
- generated_ids = model.generate(
127
- input_ids=inputs["input_ids"],
128
- pixel_values=inputs["pixel_values"],
129
- max_new_tokens=100,
130
- num_beams=3,
131
- do_sample=False
132
- )
133
  else:
134
- generated_ids = model.generate(
135
- **inputs,
136
- max_length=100,
137
- num_beams=5,
138
- no_repeat_ngram_size=2,
139
- early_stopping=True,
140
- num_return_sequences=1
141
- )
142
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
143
  generated_text = [gen.strip() for gen in generated_text]
144
  generated_texts.extend(generated_text)
@@ -147,57 +118,51 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
147
 
148
 
149
 
150
-
151
  def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
152
- """
153
- Generates parsed textual content for detected icons using the phi3_v model variant.
154
- """
155
  to_pil = ToPILImage()
156
  if ocr_bbox:
157
  non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
158
  else:
159
  non_ocr_boxes = filtered_boxes
160
- cropped_pil_images = []
161
- for coord in non_ocr_boxes:
162
- xmin, xmax = int(coord[0] * image_source.shape[1]), int(coord[2] * image_source.shape[1])
163
- ymin, ymax = int(coord[1] * image_source.shape[0]), int(coord[3] * image_source.shape[0])
164
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
165
- cropped_pil_images.append(to_pil(cropped_image))
166
 
167
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
168
  device = model.device
169
- messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
170
  prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
171
 
172
  batch_size = 5 # Number of samples per batch
173
  generated_texts = []
174
 
175
- for i in range(0, len(cropped_pil_images), batch_size):
176
- images = cropped_pil_images[i:i+batch_size]
177
  image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
178
- inputs = {'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
179
  texts = [prompt] * len(images)
180
- for idx, txt in enumerate(texts):
181
- inp = processor._convert_images_texts_to_inputs(image_inputs[idx], txt, return_tensors="pt")
182
- inputs['input_ids'].append(inp['input_ids'])
183
- inputs['attention_mask'].append(inp['attention_mask'])
184
- inputs['pixel_values'].append(inp['pixel_values'])
185
- inputs['image_sizes'].append(inp['image_sizes'])
186
- max_len = max(x.shape[1] for x in inputs['input_ids'])
187
- for idx, v in enumerate(inputs['input_ids']):
188
- pad_tensor = processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long)
189
- inputs['input_ids'][idx] = torch.cat([pad_tensor, v], dim=1)
190
- pad_att = torch.zeros(1, max_len - v.shape[1], dtype=torch.long)
191
- inputs['attention_mask'][idx] = torch.cat([pad_att, inputs['attention_mask'][idx]], dim=1)
192
  inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
193
 
194
- generation_args = {
195
- "max_new_tokens": 25,
196
- "temperature": 0.01,
197
- "do_sample": False,
198
- }
199
- generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
200
- # Remove input tokens from the generated sequence
201
  generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
202
  response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
203
  response = [res.strip('\n').strip() for res in response]
@@ -205,19 +170,7 @@ def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, captio
205
 
206
  return generated_texts
207
 
208
-
209
  def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
210
- """
211
- Removes overlapping bounding boxes based on IoU and optionally considers OCR boxes.
212
-
213
- Args:
214
- boxes: Tensor of bounding boxes (in xyxy format).
215
- iou_threshold: IoU threshold to determine overlaps.
216
- ocr_bbox: Optional list of OCR bounding boxes.
217
-
218
- Returns:
219
- Filtered boxes as a torch.Tensor.
220
- """
221
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
222
 
223
  def box_area(box):
@@ -231,30 +184,39 @@ def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
231
  return max(0, x2 - x1) * max(0, y2 - y1)
232
 
233
  def IoU(box1, box2):
234
- inter = intersection_area(box1, box2)
235
- union = box_area(box1) + box_area(box2) - inter + 1e-6
236
- ratio1 = inter / box_area(box1) if box_area(box1) > 0 else 0
237
- ratio2 = inter / box_area(box2) if box_area(box2) > 0 else 0
238
- return max(inter / union, ratio1, ratio2)
239
 
240
  def is_inside(box1, box2):
241
- inter = intersection_area(box1, box2)
242
- return (inter / box_area(box1)) > 0.95
 
 
243
 
244
  boxes = boxes.tolist()
245
  filtered_boxes = []
246
  if ocr_bbox:
247
  filtered_boxes.extend(ocr_bbox)
 
248
  for i, box1 in enumerate(boxes):
 
249
  is_valid_box = True
250
  for j, box2 in enumerate(boxes):
 
251
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
252
  is_valid_box = False
253
  break
254
  if is_valid_box:
 
255
  if ocr_bbox:
256
- # Only add the box if it does not overlap with any OCR box
257
- if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for box3 in ocr_bbox):
258
  filtered_boxes.append(box1)
259
  else:
260
  filtered_boxes.append(box1)
@@ -262,17 +224,11 @@ def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
262
 
263
 
264
  def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
265
- """
266
- Removes overlapping boxes with OCR priority.
267
-
268
- Args:
269
- boxes: List of dictionaries, each with keys: 'type', 'bbox', 'interactivity', 'content'.
270
- iou_threshold: IoU threshold for removal.
271
- ocr_bbox: List of OCR box dictionaries.
272
-
273
- Returns:
274
- A list of filtered box dictionaries.
275
- """
276
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
277
 
278
  def box_area(box):
@@ -286,130 +242,132 @@ def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
286
  return max(0, x2 - x1) * max(0, y2 - y1)
287
 
288
  def IoU(box1, box2):
289
- inter = intersection_area(box1, box2)
290
- union = box_area(box1) + box_area(box2) - inter + 1e-6
291
- ratio1 = inter / box_area(box1) if box_area(box1) > 0 else 0
292
- ratio2 = inter / box_area(box2) if box_area(box2) > 0 else 0
293
- return max(inter / union, ratio1, ratio2)
294
 
295
  def is_inside(box1, box2):
296
- inter = intersection_area(box1, box2)
297
- return (inter / box_area(box1)) > 0.80
 
 
298
 
 
299
  filtered_boxes = []
300
  if ocr_bbox:
301
  filtered_boxes.extend(ocr_bbox)
 
302
  for i, box1_elem in enumerate(boxes):
303
  box1 = box1_elem['bbox']
304
  is_valid_box = True
305
  for j, box2_elem in enumerate(boxes):
 
306
  box2 = box2_elem['bbox']
307
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
308
  is_valid_box = False
309
  break
310
  if is_valid_box:
 
311
  if ocr_bbox:
 
312
  box_added = False
313
  for box3_elem in ocr_bbox:
314
- box3 = box3_elem['bbox']
315
- if is_inside(box3, box1):
316
- try:
317
- filtered_boxes.append({
318
- 'type': 'text',
319
- 'bbox': box1_elem['bbox'],
320
- 'interactivity': True,
321
- 'content': box3_elem['content']
322
- })
323
- filtered_boxes.remove(box3_elem)
324
- except Exception:
325
  continue
326
- elif is_inside(box1, box3):
327
- box_added = True
328
- break
329
  if not box_added:
330
- filtered_boxes.append({
331
- 'type': 'icon',
332
- 'bbox': box1_elem['bbox'],
333
- 'interactivity': True,
334
- 'content': None
335
- })
336
  else:
337
  filtered_boxes.append(box1)
338
- return filtered_boxes # Optionally, you could return torch.tensor(filtered_boxes) if needed
339
 
340
 
341
  def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
342
- """
343
- Loads an image and applies transformations.
344
-
345
- Returns:
346
- image: Original image as a NumPy array.
347
- image_transformed: Transformed tensor.
348
- """
349
- transform = T.Compose([
350
- T.RandomResize([800], max_size=1333),
351
- T.ToTensor(),
352
- T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
353
- ])
354
  image_source = Image.open(image_path).convert("RGB")
355
  image = np.asarray(image_source)
356
  image_transformed, _ = transform(image_source, None)
357
  return image, image_transformed
358
 
359
 
360
- def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str],
361
- text_scale: float, text_padding=5, text_thickness=2, thickness=3) -> Tuple[np.ndarray, dict]:
362
- """
363
- Annotates an image with bounding boxes and labels.
364
- """
365
- # Validate phrases input
366
- phrases = [str(phrase) if not isinstance(phrase, str) else phrase for phrase in phrases]
367
368
  h, w, _ = image_source.shape
369
  boxes = boxes * torch.Tensor([w, h, w, h])
370
  xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
371
  xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
372
  detections = sv.Detections(xyxy=xyxy)
373
 
374
- labels = [f"{phrase}" for phrase in phrases]
375
 
376
- from util.box_annotator import BoxAnnotator
377
- box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,
378
- text_thickness=text_thickness, thickness=thickness)
379
  annotated_frame = image_source.copy()
380
- annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w, h))
381
 
382
  label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
383
  return annotated_frame, label_coordinates
384
 
385
 
386
-
387
  def predict(model, image, caption, box_threshold, text_threshold):
 
388
  """
389
- Uses a Hugging Face model to perform grounded object detection.
390
-
391
- Args:
392
- model: Dictionary with 'model' and 'processor'.
393
- image: Input PIL image.
394
- caption: Caption text.
395
- box_threshold: Confidence threshold for boxes.
396
- text_threshold: Threshold for text detection.
397
-
398
- Returns:
399
- boxes, logits, phrases from the detection.
400
- """
401
- model_obj, processor = model['model'], model['processor']
402
- device = model_obj.device
403
 
404
  inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
405
  with torch.no_grad():
406
- outputs = model_obj(**inputs)
407
 
408
  results = processor.post_process_grounded_object_detection(
409
  outputs,
410
  inputs.input_ids,
411
- box_threshold=box_threshold,
412
- text_threshold=text_threshold,
413
  target_sizes=[image.size[::-1]]
414
  )[0]
415
  boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
@@ -417,109 +375,78 @@ def predict(model, image, caption, box_threshold, text_threshold):
417
 
418
 
419
  def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
 
420
  """
421
- Uses a YOLO model for object detection.
422
-
423
- Args:
424
- model: YOLO model instance.
425
- image_path: Path to the image.
426
- box_threshold: Confidence threshold.
427
- imgsz: Image size for scaling (if scale_img is True).
428
- scale_img: Boolean flag to scale the image.
429
- iou_threshold: IoU threshold for non-max suppression.
430
-
431
- Returns:
432
- Bounding boxes, confidence scores, and placeholder phrases.
433
- """
434
- kwargs = {
435
- 'conf': box_threshold, # Confidence threshold
436
- 'iou': iou_threshold, # IoU threshold
437
- 'verbose': False
438
- }
439
  if scale_img:
440
- kwargs['imgsz'] = imgsz
441
 
442
- results = model.predict(image_path, **kwargs)
443
- boxes = results[0].boxes.xyxy
444
- conf = results[0].boxes.conf
445
- return boxes, conf, [str(i) for i in range(len(boxes))]
446
 
447
 
448
- def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None,
449
- text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None,
450
- ocr_text=[], use_local_semantics=True, iou_threshold=0.9, prompt=None, scale_img=False,
451
- imgsz=None, batch_size=None):
452
- """
453
- Processes an image to generate semantic (SOM) labels.
454
-
455
- Args:
456
- img_path: Path to the image.
457
- model: YOLO model for detection.
458
- BOX_TRESHOLD: Confidence threshold for box prediction.
459
- output_coord_in_ratio: If True, output coordinates in ratio.
460
- ocr_bbox: OCR bounding boxes.
461
- text_scale, text_padding: Parameters for drawing annotations.
462
- draw_bbox_config: Custom configuration for bounding box drawing.
463
- caption_model_processor: Dictionary with caption model and processor.
464
- ocr_text: List of OCR-detected texts.
465
- use_local_semantics: Whether to use local semantic processing.
466
- iou_threshold: IoU threshold for filtering overlaps.
467
- prompt: Optional caption prompt.
468
- scale_img: Whether to scale the image.
469
- imgsz: Image size for YOLO.
470
- batch_size: Batch size for captioning.
471
-
472
- Returns:
473
- Encoded annotated image, label coordinates, and filtered boxes.
474
  """
475
  image_source = Image.open(img_path).convert("RGB")
476
  w, h = image_source.size
477
  if not imgsz:
478
  imgsz = (h, w)
479
- # Run YOLO detection
480
- xyxy, logits, phrases = predict_yolo(
481
- model=model, image_path=img_path, box_threshold=BOX_TRESHOLD,
482
- imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1
483
- )
484
  xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
485
- image_source_np = np.asarray(image_source)
486
  phrases = [str(i) for i in range(len(phrases))]
487
 
488
- # Process OCR bounding boxes (if any)
 
489
  if ocr_bbox:
490
  ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
491
- ocr_bbox = ocr_bbox.tolist()
492
  else:
493
  print('no ocr bbox!!!')
494
  ocr_bbox = None
495
 
496
- ocr_bbox_elem = [{'type': 'text', 'bbox': box, 'interactivity': False, 'content': txt}
497
- for box, txt in zip(ocr_bbox, ocr_text)]
498
- xyxy_elem = [{'type': 'icon', 'bbox': box, 'interactivity': True, 'content': None}
499
- for box in xyxy.tolist()]
500
  filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
501
 
502
- # Sort filtered boxes so that boxes with 'content' == None are at the end
503
  filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
 
504
  starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
505
- filtered_boxes_tensor = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
506
 
507
- if batch_size is None:
508
- batch_size = 32
509
-
510
- # Generate parsed icon semantics if required
511
  if use_local_semantics:
512
  caption_model = caption_model_processor['model']
513
- if 'phi3_v' in caption_model.config.model_type:
514
- parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes_tensor, ocr_bbox, image_source_np, caption_model_processor)
515
  else:
516
- parsed_content_icon = get_parsed_content_icon(filtered_boxes_tensor, starting_idx, image_source_np, caption_model_processor, prompt=prompt, batch_size=batch_size)
517
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
518
  icon_start = len(ocr_text)
519
  parsed_content_icon_ls = []
520
- # Fill boxes with no OCR content with parsed icon content
521
- for box in filtered_boxes_elem:
522
- if box['content'] is None and parsed_content_icon:
523
  box['content'] = parsed_content_icon.pop(0)
524
  for i, txt in enumerate(parsed_content_icon):
525
  parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
@@ -528,72 +455,51 @@ def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD=0.01, output_coord_in
528
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
529
  parsed_content_merged = ocr_text
530
 
531
- filtered_boxes_cxcywh = box_convert(boxes=filtered_boxes_tensor, in_fmt="xyxy", out_fmt="cxcywh")
532
- phrases = [i for i in range(len(filtered_boxes_cxcywh))]
 
533
 
534
- # Annotate image with bounding boxes and labels
535
  if draw_bbox_config:
536
- annotated_frame, label_coordinates = annotate(
537
- image_source=image_source_np, boxes=filtered_boxes_cxcywh, logits=logits, phrases=phrases, **draw_bbox_config
538
- )
539
  else:
540
- annotated_frame, label_coordinates = annotate(
541
- image_source=image_source_np, boxes=filtered_boxes_cxcywh, logits=logits, phrases=phrases,
542
- text_scale=text_scale, text_padding=text_padding
543
- )
544
 
545
  pil_img = Image.fromarray(annotated_frame)
546
  buffered = io.BytesIO()
547
  pil_img.save(buffered, format="PNG")
548
  encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
549
-
550
  if output_coord_in_ratio:
551
- label_coordinates = {k: [v[0] / w, v[1] / h, v[2] / w, v[3] / h] for k, v in label_coordinates.items()}
 
552
  assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
553
 
554
  return encoded_image, label_coordinates, filtered_boxes_elem
555
 
556
 
557
  def get_xywh(input):
558
- """
559
- Converts a bounding box from a list of two points into (x, y, width, height).
560
- """
561
- x, y = input[0][0], input[0][1]
562
- w = input[2][0] - input[0][0]
563
- h = input[2][1] - input[0][1]
564
- return int(x), int(y), int(w), int(h)
565
-
566
 
567
  def get_xyxy(input):
568
- """
569
- Converts a bounding box from a list of two points into (x, y, x2, y2).
570
- """
571
- x, y = input[0][0], input[0][1]
572
- x2, y2 = input[2][0], input[2][1]
573
- return int(x), int(y), int(x2), int(y2)
574
-
575
 
576
  def get_xywh_yolo(input):
577
- """
578
- Converts a YOLO-style bounding box (x1, y1, x2, y2) into (x, y, width, height).
579
- """
580
- x, y = input[0], input[1]
581
- w = input[2] - input[0]
582
- h = input[3] - input[1]
583
- return int(x), int(y), int(w), int(h)
584
 
585
 
586
- def check_ocr_box(image_path, display_img=True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
587
- """
588
- Runs OCR on the given image using PaddleOCR or EasyOCR and optionally displays annotated results.
589
-
590
- Returns:
591
- A tuple containing:
592
- - A tuple (text, bounding boxes)
593
- - The goal_filtering parameter (unchanged)
594
- """
595
  if use_paddleocr:
596
- text_threshold = 0.5 if easyocr_args is None else easyocr_args.get('text_threshold', 0.5)
597
  result = paddle_ocr.ocr(image_path, cls=False)[0]
598
  conf = [item[1] for item in result]
599
  coord = [item[0] for item in result if item[1][1] > text_threshold]
@@ -602,21 +508,26 @@ def check_ocr_box(image_path, display_img=True, output_bb_format='xywh', goal_fi
602
  if easyocr_args is None:
603
  easyocr_args = {}
604
  result = reader.readtext(image_path, **easyocr_args)
 
605
  coord = [item[0] for item in result]
606
  text = [item[1] for item in result]
607
-
608
  if display_img:
609
  opencv_img = cv2.imread(image_path)
610
  opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
611
  bb = []
612
  for item in coord:
613
  x, y, a, b = get_xywh(item)
 
614
  bb.append((x, y, a, b))
615
- cv2.rectangle(opencv_img, (x, y), (x + a, y + b), (0, 255, 0), 2)
 
 
616
  plt.imshow(opencv_img)
617
  else:
618
  if output_bb_format == 'xywh':
619
  bb = [get_xywh(item) for item in coord]
620
  elif output_bb_format == 'xyxy':
621
  bb = [get_xyxy(item) for item in coord]
622
- return (text, bb), goal_filtering
 
 
1
+ # from ultralytics import YOLO
2
  import os
3
  import io
4
  import base64
5
  import time
6
+ from PIL import Image, ImageDraw, ImageFont
7
  import json
8
+ import requests
9
+ # utility function
10
+ import os
11
 
12
+
13
+ import json
14
+ import sys
15
+ import os
16
  import cv2
17
+ import numpy as np
18
+ # %matplotlib inline
19
  from matplotlib import pyplot as plt
 
20
  import easyocr
21
  from paddleocr import PaddleOCR
22
  reader = easyocr.Reader(['en'])
23
  paddle_ocr = PaddleOCR(
24
+ lang='en', # other lang also available
25
  use_angle_cls=False,
26
+ use_gpu=False, # using cuda will conflict with pytorch in the same process
27
  show_log=False,
28
  max_batch_size=1024,
29
  use_dilation=True, # improves accuracy
30
  det_db_score_mode='slow', # improves accuracy
31
+ rec_batch_num=1024)
32
+ import time
33
+ import base64
34
+
35
+ import os
36
+ import ast
37
+ import torch
38
+ from typing import Tuple, List
39
+ from torchvision.ops import box_convert
40
+ import re
41
+ from torchvision.transforms import ToPILImage
42
+ import supervision as sv
43
+ import torchvision.transforms as T
44
 
45
 
46
  def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
47
  if not device:
48
  device = "cuda" if torch.cuda.is_available() else "cpu"
49
  if model_name == "blip2":
 
51
  processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
52
  if device == 'cpu':
53
  model = Blip2ForConditionalGeneration.from_pretrained(
54
+ model_name_or_path, device_map=None, torch_dtype=torch.float32
55
+ )
56
  else:
57
  model = Blip2ForConditionalGeneration.from_pretrained(
58
+ model_name_or_path, device_map=None, torch_dtype=torch.float16
59
+ ).to(device)
60
  elif model_name == "florence2":
61
+ from transformers import AutoProcessor, AutoModelForCausalLM
62
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
63
  if device == 'cpu':
64
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
 
 
65
  else:
66
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
 
 
67
  return {'model': model.to(device), 'processor': processor}
68
 
69
 
70
  def get_yolo_model(model_path):
71
  from ultralytics import YOLO
72
+ # Load the model.
73
  model = YOLO(model_path)
74
  return model
75
 
76
 
77
  @torch.inference_mode()
78
  def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=32):
79
+ # Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
 
 
80
 
81
  to_pil = ToPILImage()
82
  if starting_idx:
83
  non_ocr_boxes = filtered_boxes[starting_idx:]
84
  else:
85
  non_ocr_boxes = filtered_boxes
86
+ croped_pil_image = []
87
+ for i, coord in enumerate(non_ocr_boxes):
88
+ xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
89
+ ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
90
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
91
+ croped_pil_image.append(to_pil(cropped_image))
92
+
93
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
94
  if not prompt:
95
  if 'florence' in model.config.name_or_path:
 
99
 
100
  generated_texts = []
101
  device = model.device
102
+ for i in range(0, len(croped_pil_image), batch_size):
103
+ start = time.time()
104
+ batch = croped_pil_image[i:i+batch_size]
105
  if model.device.type == 'cuda':
106
+ inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device, dtype=torch.float16)
107
  else:
108
+ inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
109
  if 'florence' in model.config.name_or_path:
110
+ generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=100,num_beams=3, do_sample=False)
111
  else:
112
+ generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
113
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
114
  generated_text = [gen.strip() for gen in generated_text]
115
  generated_texts.extend(generated_text)
 
118
 
119
 
120
 
 
121
  def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
122
  to_pil = ToPILImage()
123
  if ocr_bbox:
124
  non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
125
  else:
126
  non_ocr_boxes = filtered_boxes
127
+ croped_pil_image = []
128
+ for i, coord in enumerate(non_ocr_boxes):
129
+ xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
130
+ ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
131
  cropped_image = image_source[ymin:ymax, xmin:xmax, :]
132
+ croped_pil_image.append(to_pil(cropped_image))
133
 
134
  model, processor = caption_model_processor['model'], caption_model_processor['processor']
135
  device = model.device
136
+ messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
137
  prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
138
 
139
  batch_size = 5 # Number of samples per batch
140
  generated_texts = []
141
 
142
+ for i in range(0, len(croped_pil_image), batch_size):
143
+ images = croped_pil_image[i:i+batch_size]
144
  image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
145
+ inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
146
  texts = [prompt] * len(images)
147
+ for i, txt in enumerate(texts):
148
+ input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
149
+ inputs['input_ids'].append(input['input_ids'])
150
+ inputs['attention_mask'].append(input['attention_mask'])
151
+ inputs['pixel_values'].append(input['pixel_values'])
152
+ inputs['image_sizes'].append(input['image_sizes'])
153
+ max_len = max([x.shape[1] for x in inputs['input_ids']])
154
+ for i, v in enumerate(inputs['input_ids']):
155
+ inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
156
+ inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
 
 
157
  inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
158
 
159
+ generation_args = {
160
+ "max_new_tokens": 25,
161
+ "temperature": 0.01,
162
+ "do_sample": False,
163
+ }
164
+ generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
165
+ # # remove input tokens
166
  generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
167
  response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
168
  response = [res.strip('\n').strip() for res in response]
 
170
 
171
  return generated_texts
172
 
 
173
  def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
174
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
175
 
176
  def box_area(box):
 
184
  return max(0, x2 - x1) * max(0, y2 - y1)
185
 
186
  def IoU(box1, box2):
187
+ intersection = intersection_area(box1, box2)
188
+ union = box_area(box1) + box_area(box2) - intersection + 1e-6
189
+ if box_area(box1) > 0 and box_area(box2) > 0:
190
+ ratio1 = intersection / box_area(box1)
191
+ ratio2 = intersection / box_area(box2)
192
+ else:
193
+ ratio1, ratio2 = 0, 0
194
+ return max(intersection / union, ratio1, ratio2)
195
 
196
  def is_inside(box1, box2):
197
+ # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
198
+ intersection = intersection_area(box1, box2)
199
+ ratio1 = intersection / box_area(box1)
200
+ return ratio1 > 0.95
201
 
202
  boxes = boxes.tolist()
203
  filtered_boxes = []
204
  if ocr_bbox:
205
  filtered_boxes.extend(ocr_bbox)
206
+ # print('ocr_bbox!!!', ocr_bbox)
207
  for i, box1 in enumerate(boxes):
208
+ # if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
209
  is_valid_box = True
210
  for j, box2 in enumerate(boxes):
211
+ # keep the smaller box
212
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
213
  is_valid_box = False
214
  break
215
  if is_valid_box:
216
+ # add the following 2 lines to include ocr bbox
217
  if ocr_bbox:
218
+ # only add the box if it does not overlap with any ocr bbox
219
+ if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
220
  filtered_boxes.append(box1)
221
  else:
222
  filtered_boxes.append(box1)
 
224
 
225
 
226
  def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
227
+ '''
228
+ ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
229
+ boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
230
+
231
+ '''
232
  assert ocr_bbox is None or isinstance(ocr_bbox, List)
233
 
234
  def box_area(box):
 
242
  return max(0, x2 - x1) * max(0, y2 - y1)
243
 
244
  def IoU(box1, box2):
245
+ intersection = intersection_area(box1, box2)
246
+ union = box_area(box1) + box_area(box2) - intersection + 1e-6
247
+ if box_area(box1) > 0 and box_area(box2) > 0:
248
+ ratio1 = intersection / box_area(box1)
249
+ ratio2 = intersection / box_area(box2)
250
+ else:
251
+ ratio1, ratio2 = 0, 0
252
+ return max(intersection / union, ratio1, ratio2)
253
 
254
  def is_inside(box1, box2):
255
+ # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
256
+ intersection = intersection_area(box1, box2)
257
+ ratio1 = intersection / box_area(box1)
258
+ return ratio1 > 0.80
259
 
260
+ # boxes = boxes.tolist()
261
  filtered_boxes = []
262
  if ocr_bbox:
263
  filtered_boxes.extend(ocr_bbox)
264
+ # print('ocr_bbox!!!', ocr_bbox)
265
  for i, box1_elem in enumerate(boxes):
266
  box1 = box1_elem['bbox']
267
  is_valid_box = True
268
  for j, box2_elem in enumerate(boxes):
269
+ # keep the smaller box
270
  box2 = box2_elem['bbox']
271
  if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
272
  is_valid_box = False
273
  break
274
  if is_valid_box:
275
+ # add the following 2 lines to include ocr bbox
276
  if ocr_bbox:
277
+ # keep yolo boxes + prioritize ocr label
278
  box_added = False
279
  for box3_elem in ocr_bbox:
280
+ if not box_added:
281
+ box3 = box3_elem['bbox']
282
+ if is_inside(box3, box1): # ocr inside icon
283
+ # box_added = True
284
+ # delete the box3_elem from ocr_bbox
285
+ try:
286
+ filtered_boxes.append({'type': 'text', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': box3_elem['content']})
287
+ filtered_boxes.remove(box3_elem)
288
+ # print('remove ocr bbox:', box3_elem)
289
+ except:
290
+ continue
291
+ # break
292
+ elif is_inside(box1, box3): # icon inside ocr
293
+ box_added = True
294
+ # try:
295
+ # filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
296
+ # filtered_boxes.remove(box3_elem)
297
+ # except:
298
+ # continue
299
+ break
300
+ else:
301
  continue
302
  if not box_added:
303
+ filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
304
+
305
  else:
306
  filtered_boxes.append(box1)
307
+ return filtered_boxes # torch.tensor(filtered_boxes)
308
 
309
 
310
  def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
311
+ transform = T.Compose(
312
+ [
313
+ T.RandomResize([800], max_size=1333),
314
+ T.ToTensor(),
315
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
316
+ ]
317
+ )
318
  image_source = Image.open(image_path).convert("RGB")
319
  image = np.asarray(image_source)
320
  image_transformed, _ = transform(image_source, None)
321
  return image, image_transformed
322
 
323
 
324
+ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
325
+ text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
326
+ """
327
+ This function annotates an image with bounding boxes and labels.
328
 
329
+ Parameters:
330
+ image_source (np.ndarray): The source image to be annotated.
331
+ boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
332
+ logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
333
+ phrases (List[str]): A list of labels for each bounding box.
334
+ text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
335
+
336
+ Returns:
337
+ np.ndarray: The annotated image.
338
+ """
339
  h, w, _ = image_source.shape
340
  boxes = boxes * torch.Tensor([w, h, w, h])
341
  xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
342
  xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
343
  detections = sv.Detections(xyxy=xyxy)
344
 
345
+ labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
346
 
347
+ from util.box_annotator import BoxAnnotator
348
+ box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
 
349
  annotated_frame = image_source.copy()
350
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
351
 
352
  label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
353
  return annotated_frame, label_coordinates
354
 
355
 
 
356
  def predict(model, image, caption, box_threshold, text_threshold):
357
+ """ Use huggingface model to replace the original model
358
  """
359
+ model, processor = model['model'], model['processor']
360
+ device = model.device
361
 
362
  inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
363
  with torch.no_grad():
364
+ outputs = model(**inputs)
365
 
366
  results = processor.post_process_grounded_object_detection(
367
  outputs,
368
  inputs.input_ids,
369
+ box_threshold=box_threshold, # 0.4,
370
+ text_threshold=text_threshold, # 0.3,
371
  target_sizes=[image.size[::-1]]
372
  )[0]
373
  boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
 
375
 
376
 
377
  def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
378
+ """ Use huggingface model to replace the original model
379
  """
380
+ # model = model['model']
381
  if scale_img:
382
+ result = model.predict(
383
+ source=image_path,
384
+ conf=box_threshold,
385
+ imgsz=imgsz,
386
+ iou=iou_threshold, # default 0.7
387
+ )
388
+ else:
389
+ result = model.predict(
390
+ source=image_path,
391
+ conf=box_threshold,
392
+ iou=iou_threshold, # default 0.7
393
+ )
394
+ boxes = result[0].boxes.xyxy#.tolist() # in pixel space
395
+ conf = result[0].boxes.conf
396
+ phrases = [str(i) for i in range(len(boxes))]
397
 
398
+ return boxes, conf, phrases
399
 
400
 
401
+ def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=None):
402
+ """ ocr_bbox: list of xyxy format bbox
403
  """
404
  image_source = Image.open(img_path).convert("RGB")
405
  w, h = image_source.size
406
  if not imgsz:
407
  imgsz = (h, w)
408
+ # print('image size:', w, h)
409
+ xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
410
  xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
411
+ image_source = np.asarray(image_source)
412
  phrases = [str(i) for i in range(len(phrases))]
413
 
414
+ # annotate the image with labels
415
+ h, w, _ = image_source.shape
416
  if ocr_bbox:
417
  ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
418
+ ocr_bbox=ocr_bbox.tolist()
419
  else:
420
  print('no ocr bbox!!!')
421
  ocr_bbox = None
422
+ # filtered_boxes = remove_overlap(boxes=xyxy, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox)
423
+ # starting_idx = len(ocr_bbox)
424
+ # print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
425
 
426
+ ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt} for box, txt in zip(ocr_bbox, ocr_text)]
427
+ xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist()]
 
 
428
  filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
429
 
430
+ # sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
431
  filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
432
+ # get the index of the first 'content': None
433
  starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
434
+ filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
435
 
436
+
437
+ # get parsed icon local semantics
 
 
438
  if use_local_semantics:
439
  caption_model = caption_model_processor['model']
440
+ if 'phi3_v' in caption_model.config.model_type:
441
+ parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
442
  else:
443
+ parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
444
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
445
  icon_start = len(ocr_text)
446
  parsed_content_icon_ls = []
447
+ # fill the filtered_boxes_elem None content with parsed_content_icon in order
448
+ for i, box in enumerate(filtered_boxes_elem):
449
+ if box['content'] is None:
450
  box['content'] = parsed_content_icon.pop(0)
451
  for i, txt in enumerate(parsed_content_icon):
452
  parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
 
455
  ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
456
  parsed_content_merged = ocr_text
457
 
458
+ filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
459
+
460
+ phrases = [i for i in range(len(filtered_boxes))]
461
 
462
+ # draw boxes
463
  if draw_bbox_config:
464
+ annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
 
 
465
  else:
466
+ annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
467
 
468
  pil_img = Image.fromarray(annotated_frame)
469
  buffered = io.BytesIO()
470
  pil_img.save(buffered, format="PNG")
471
  encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
 
472
  if output_coord_in_ratio:
473
+ # h, w, _ = image_source.shape
474
+ label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
475
  assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
476
 
477
  return encoded_image, label_coordinates, filtered_boxes_elem
478
 
479
 
480
  def get_xywh(input):
481
+ x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
482
+ x, y, w, h = int(x), int(y), int(w), int(h)
483
+ return x, y, w, h
484
 
485
  def get_xyxy(input):
486
+ x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
487
+ x, y, xp, yp = int(x), int(y), int(xp), int(yp)
488
+ return x, y, xp, yp
489
 
490
  def get_xywh_yolo(input):
491
+ x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
492
+ x, y, w, h = int(x), int(y), int(w), int(h)
493
+ return x, y, w, h
494
+
495
 
496
 
497
+ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
498
  if use_paddleocr:
499
+ if easyocr_args is None:
500
+ text_threshold = 0.5
501
+ else:
502
+ text_threshold = easyocr_args['text_threshold']
503
  result = paddle_ocr.ocr(image_path, cls=False)[0]
504
  conf = [item[1] for item in result]
505
  coord = [item[0] for item in result if item[1][1] > text_threshold]
 
508
  if easyocr_args is None:
509
  easyocr_args = {}
510
  result = reader.readtext(image_path, **easyocr_args)
511
+ # print('goal filtering pred:', result[-5:])
512
  coord = [item[0] for item in result]
513
  text = [item[1] for item in result]
514
+ # read the image using cv2
515
  if display_img:
516
  opencv_img = cv2.imread(image_path)
517
  opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
518
  bb = []
519
  for item in coord:
520
  x, y, a, b = get_xywh(item)
521
+ # print(x, y, a, b)
522
  bb.append((x, y, a, b))
523
+ cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
524
+
525
+ # Display the image
526
  plt.imshow(opencv_img)
527
  else:
528
  if output_bb_format == 'xywh':
529
  bb = [get_xywh(item) for item in coord]
530
  elif output_bb_format == 'xyxy':
531
  bb = [get_xyxy(item) for item in coord]
532
+ # print('bounding box!!!', bb)
533
+ return (text, bb), goal_filtering
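
For reference, a minimal usage sketch of the two entry points defined in the updated utils.py above, check_ocr_box and get_som_labeled_img. This is an assumed workflow, not part of the commit: the input image, YOLO weight path, and caption model choice are placeholders.

```python
# Minimal usage sketch (assumptions: the image path, weight path, and caption
# model below are placeholders, not defined by this commit).
from utils import (check_ocr_box, get_yolo_model,
                   get_caption_model_processor, get_som_labeled_img)

image_path = "screenshot.png"                               # hypothetical input image
som_model = get_yolo_model("weights/icon_detect/best.pt")   # hypothetical YOLO weights
caption_model_processor = get_caption_model_processor(
    "florence2", model_name_or_path="microsoft/Florence-2-base")  # assumed caption model

# OCR pass: returns (texts, boxes) plus the goal_filtering value unchanged.
(ocr_text, ocr_bbox), _ = check_ocr_box(
    image_path, display_img=False, output_bb_format='xyxy', use_paddleocr=True)

# Detection + captioning pass: returns a base64-encoded annotated PNG, label
# coordinates, and the filtered box dicts ('type', 'bbox', 'interactivity', 'content').
encoded_image, label_coordinates, parsed_boxes = get_som_labeled_img(
    image_path, model=som_model, BOX_TRESHOLD=0.05,
    ocr_bbox=ocr_bbox, ocr_text=ocr_text,
    caption_model_processor=caption_model_processor,
    use_local_semantics=True, iou_threshold=0.9,
    batch_size=32)  # pass batch_size explicitly: the default None is fed straight
                    # into range() inside get_parsed_content_icon
```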