banao-tech committed
Commit e15d2e5 · verified · 1 parent: 712f1db

Update utils.py

Files changed (1)
  utils.py +5 -17
utils.py CHANGED
@@ -91,32 +91,19 @@ def get_yolo_model(model_path):
 
 @torch.inference_mode()
 def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=32):
-    """
-    Generates parsed textual content for detected icons from the image.
-
-    Args:
-        filtered_boxes: Tensor of bounding boxes.
-        starting_idx: Starting index for non-OCR boxes.
-        image_source: Original image as a NumPy array.
-        caption_model_processor: Dictionary with keys 'model' and 'processor'.
-        prompt: Optional prompt text.
-        batch_size: Batch size for processing.
-
-    Returns:
-        List of generated texts.
-    """
+    # Now batch_size defaults to 32 if not provided
     to_pil = ToPILImage()
     if starting_idx:
         non_ocr_boxes = filtered_boxes[starting_idx:]
     else:
         non_ocr_boxes = filtered_boxes
     cropped_pil_images = []
-    for coord in non_ocr_boxes:
+    for i, coord in enumerate(non_ocr_boxes):
         xmin, xmax = int(coord[0] * image_source.shape[1]), int(coord[2] * image_source.shape[1])
         ymin, ymax = int(coord[1] * image_source.shape[0]), int(coord[3] * image_source.shape[0])
         cropped_image = image_source[ymin:ymax, xmin:xmax, :]
         cropped_pil_images.append(to_pil(cropped_image))
-
+
     model, processor = caption_model_processor['model'], caption_model_processor['processor']
     if not prompt:
         if 'florence' in model.config.name_or_path:
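
Hunk 1 replaces the docstring with a one-line comment and switches the crop loop to enumerate (the new index i is unused in the body shown). The cropping itself scales normalized xyxy boxes by the image width and height before slicing the array. A minimal standalone sketch of that step, assuming boxes are (xmin, ymin, xmax, ymax) fractions in [0, 1] and image_source is an HxWxC uint8 NumPy array; crop_boxes is an illustrative helper, not part of the commit:

import numpy as np
from torchvision.transforms import ToPILImage

def crop_boxes(image_source, boxes):
    # boxes: iterable of (xmin, ymin, xmax, ymax) fractions in [0, 1]
    to_pil = ToPILImage()
    h, w = image_source.shape[0], image_source.shape[1]
    crops = []
    for coord in boxes:
        # Scale normalized coordinates to pixel indices, as in the diff above.
        xmin, xmax = int(coord[0] * w), int(coord[2] * w)
        ymin, ymax = int(coord[1] * h), int(coord[3] * h)
        crops.append(to_pil(image_source[ymin:ymax, xmin:xmax, :]))
    return crops

# One box covering the centre of a dummy 100x200 image:
dummy = np.zeros((100, 200, 3), dtype=np.uint8)
print(crop_boxes(dummy, [(0.25, 0.25, 0.75, 0.75)])[0].size)  # (100, 50); PIL reports (width, height)
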
@@ -127,7 +114,7 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
     generated_texts = []
     device = model.device
     for i in range(0, len(cropped_pil_images), batch_size):
-        batch = cropped_pil_images[i:i+batch_size]
+        batch = cropped_pil_images[i:i + batch_size]
         if model.device.type == 'cuda':
             inputs = processor(images=batch, text=[prompt] * len(batch), return_tensors="pt").to(device=device, dtype=torch.float16)
         else:
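
Hunk 2 only adds spacing around the + in the batch slice; behaviour is unchanged. The enclosing loop walks cropped_pil_images in strides of batch_size, so the final batch may be shorter than batch_size, and inputs are cast to float16 only on CUDA. The slicing pattern in isolation (iter_batches is illustrative, not from the commit):

def iter_batches(items, batch_size):
    # Yield consecutive slices of at most batch_size items each.
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

for batch in iter_batches(list(range(10)), 4):
    print(batch)  # [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]: the last batch is short
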
@@ -156,6 +143,7 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
     return generated_texts
 
 
+
 def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
     """
     Generates parsed textual content for detected icons using the phi3_v model variant.
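
Hunk 3 adds a blank line before get_parsed_content_icon_phi3v. For orientation, a hypothetical call site for the updated function: caption_model_processor must be a dict with 'model' and 'processor' keys, image_source is the full image as a NumPy array, and filtered_boxes holds normalized boxes with OCR entries first and icon boxes from starting_idx onward. The captioning model is assumed to be loaded elsewhere (the 'florence' check above suggests a Florence-style model); nothing below is defined by this commit:

import numpy as np
import torch

# Hypothetical loader; obtaining the captioner is outside this commit. With
# transformers, a Florence-2 checkpoint would typically come from
# AutoProcessor.from_pretrained(...) and AutoModelForCausalLM.from_pretrained(...).
model, processor = load_caption_model()

boxes = torch.tensor([
    [0.10, 0.10, 0.30, 0.20],  # normalized xyxy
    [0.55, 0.40, 0.90, 0.60],
])
screenshot = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in image

texts = get_parsed_content_icon(
    filtered_boxes=boxes,
    starting_idx=0,  # falsy: every box is treated as an icon crop
    image_source=screenshot,
    caption_model_processor={'model': model, 'processor': processor},
    batch_size=32,  # the default the new comment documents
)
print(texts)  # one caption per icon box
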
 