danielritchie commited on
Commit
accfa62
·
verified ·
1 Parent(s): 9ba69a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -7,20 +7,28 @@ from ultralytics import YOLO
7
  duck_model = YOLO('https://huggingface.co/brainwavecollective/yolo8n-rubber-duck-detector/resolve/main/yolov8n_rubberducks.pt')
8
  standard_model = YOLO('yolov8n.pt')
9
 
10
- def process_image(image, model):
11
  results = model(image)
12
  processed_image = image.copy()
13
 
14
  for r in results:
15
  boxes = r.boxes
16
  for box in boxes:
17
- x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
18
- conf = float(box.conf[0])
19
  cls = int(box.cls[0])
20
  class_name = model.names[cls]
21
 
 
 
 
 
 
 
 
 
 
 
22
  cv2.rectangle(processed_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
23
- label = f"{class_name} ({conf:.2f})"
24
  cv2.putText(processed_image, label, (x1, y1-10),
25
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
26
 
@@ -29,8 +37,8 @@ def process_image(image, model):
29
  def compare_models(input_image):
30
  image = np.array(input_image)
31
 
32
- standard_image = process_image(image, standard_model)
33
- duck_image = process_image(image, duck_model)
34
 
35
  height, width = image.shape[:2]
36
  gap = 20 # Add a 20-pixel gap between images
 
7
  duck_model = YOLO('https://huggingface.co/brainwavecollective/yolo8n-rubber-duck-detector/resolve/main/yolov8n_rubberducks.pt')
8
  standard_model = YOLO('yolov8n.pt')
9
 
10
+ def process_image(image, model, is_standard_model=True):
11
  results = model(image)
12
  processed_image = image.copy()
13
 
14
  for r in results:
15
  boxes = r.boxes
16
  for box in boxes:
 
 
17
  cls = int(box.cls[0])
18
  class_name = model.names[cls]
19
 
20
+ # For standard model, only show teddy bears
21
+ if is_standard_model and class_name != "teddy bear":
22
+ continue
23
+
24
+ x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
25
+ conf = float(box.conf[0])
26
+
27
+ # Rename class to "rubber duck"
28
+ display_name = "rubber duck" if not is_standard_model else class_name
29
+
30
  cv2.rectangle(processed_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
31
+ label = f"{display_name} ({conf:.2f})"
32
  cv2.putText(processed_image, label, (x1, y1-10),
33
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
34
 
 
37
  def compare_models(input_image):
38
  image = np.array(input_image)
39
 
40
+ standard_image = process_image(image, standard_model, is_standard_model=True)
41
+ duck_image = process_image(image, duck_model, is_standard_model=False)
42
 
43
  height, width = image.shape[:2]
44
  gap = 20 # Add a 20-pixel gap between images