banao-tech commited on
Commit
8cc07e5
·
verified ·
1 Parent(s): b31b4ab

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +12 -12
utils.py CHANGED
@@ -375,22 +375,22 @@ def predict(model, image, caption, box_threshold, text_threshold):
375
 
376
 
377
  def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
378
- """ Use huggingface model to replace the original model
379
- """
380
- # model = model['model']
381
-
382
  kwargs = {
383
- 'conf': box_threshold,
384
- 'iou': iou_threshold,
 
385
  }
 
386
  if scale_img:
387
  kwargs['imgsz'] = imgsz
388
- result = model(image_path, **kwargs)
389
- boxes = result[0].boxes.xyxy#.tolist() # in pixel space
390
- conf = result[0].boxes.conf
391
- phrases = [str(i) for i in range(len(boxes))]
392
-
393
- return boxes, conf, phrases
394
 
395
 
396
  def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=None):
 
375
 
376
 
377
  def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
378
+ """Use YOLO model for object detection with correct parameters"""
379
+
380
+ # Should be:
 
381
  kwargs = {
382
+ 'conf_thres': box_threshold, # Correct parameter name
383
+ 'iou_thres': iou_threshold,
384
+ 'verbose': False
385
  }
386
+
387
  if scale_img:
388
  kwargs['imgsz'] = imgsz
389
+
390
+ results = model.predict(image_path, **kwargs)
391
+ boxes = results[0].boxes.xyxy
392
+ conf = results[0].boxes.conf
393
+ return boxes, conf, [str(i) for i in range(len(boxes))]
 
394
 
395
 
396
  def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=None):