banao-tech committed
Commit 27aaaa5 · verified · 1 Parent(s): e15d2e5

Update utils.py

Files changed (1):
  1. utils.py +9 -2
utils.py CHANGED
@@ -91,14 +91,17 @@ def get_yolo_model(model_path):
 
 @torch.inference_mode()
 def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=32):
-    # Now batch_size defaults to 32 if not provided
+    # Ensure batch_size is an integer
+    if batch_size is None:
+        batch_size = 32
+
     to_pil = ToPILImage()
     if starting_idx:
         non_ocr_boxes = filtered_boxes[starting_idx:]
     else:
         non_ocr_boxes = filtered_boxes
     cropped_pil_images = []
-    for i, coord in enumerate(non_ocr_boxes):
+    for coord in non_ocr_boxes:
         xmin, xmax = int(coord[0] * image_source.shape[1]), int(coord[2] * image_source.shape[1])
         ymin, ymax = int(coord[1] * image_source.shape[0]), int(coord[3] * image_source.shape[0])
         cropped_image = image_source[ymin:ymax, xmin:xmax, :]
@@ -144,6 +147,7 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
 
 
 
+
 def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
     """
     Generates parsed textual content for detected icons using the phi3_v model variant.
@@ -507,6 +511,9 @@ def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD=0.01, output_coord_in
     starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
     filtered_boxes_tensor = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
 
+    if batch_size is None:
+        batch_size = 32
+
     # Generate parsed icon semantics if required
     if use_local_semantics:
         caption_model = caption_model_processor['model']
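The guard added in the first and third hunks covers callers that pass batch_size=None explicitly, which the keyword default batch_size=32 alone cannot intercept. A minimal sketch of that failure mode, using a hypothetical describe_icons stand-in for get_parsed_content_icon:

def describe_icons(crops, batch_size=32):
    # Hypothetical stand-in; only the batching skeleton is shown.
    # Without the guard, an explicit None reaches range() below and
    # raises TypeError despite the keyword default.
    if batch_size is None:
        batch_size = 32
    return [crops[i:i + batch_size] for i in range(0, len(crops), batch_size)]

print(len(describe_icons(list(range(100)))))        # 4 batches of <= 32
print(len(describe_icons(list(range(100)), None)))  # still 4, no TypeError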
 
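For context on the loop the first hunk touches: it scales normalized (xmin, ymin, xmax, ymax) boxes by the image's width and height to slice pixel crops out of an HxWxC array, and the enumerate index the diff drops was apparently unused. A self-contained sketch of that cropping step; crop_boxes is an illustrative wrapper, not a function in utils.py, and PIL's Image.fromarray stands in for torchvision's ToPILImage:

import numpy as np
from PIL import Image

def crop_boxes(image_source, boxes):
    # image_source: HxWxC uint8 array; boxes: normalized xyxy coordinates.
    h, w = image_source.shape[0], image_source.shape[1]
    crops = []
    for coord in boxes:
        xmin, xmax = int(coord[0] * w), int(coord[2] * w)
        ymin, ymax = int(coord[1] * h), int(coord[3] * h)
        crops.append(Image.fromarray(image_source[ymin:ymax, xmin:xmax, :]))
    return crops

crops = crop_boxes(np.zeros((480, 640, 3), dtype=np.uint8), [(0.1, 0.1, 0.5, 0.5)])
print(crops[0].size)  # (256, 192): width 0.4 * 640, height 0.4 * 480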