Víctor Sáez commited on
Commit
53e14a8
·
1 Parent(s): b09b10a

Change label

Browse files
Files changed (1) hide show
  1. app.py +20 -3
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from PIL import Image, ImageDraw
3
  from transformers import DetrImageProcessor, DetrForObjectDetection
4
  import torch
5
 
@@ -8,6 +8,9 @@ model_name = "facebook/detr-resnet-50"
8
  processor = DetrImageProcessor.from_pretrained(model_name)
9
  model = DetrForObjectDetection.from_pretrained(model_name)
10
 
 
 
 
11
  # Main function: takes an image and returns it with boxes and labels
12
  def detect_objects(image):
13
  inputs = processor(images=image, return_tensors="pt")
@@ -22,12 +25,26 @@ def detect_objects(image):
22
  # Draw bounding boxes and labels on a copy of the image
23
  image_with_boxes = image.copy()
24
  draw = ImageDraw.Draw(image_with_boxes)
25
-
26
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
27
  box = [round(x, 2) for x in box.tolist()]
28
  draw.rectangle(box, outline="red", width=3)
 
 
29
  label_text = f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}"
30
- draw.text((box[0], box[1]), label_text, fill="white")
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  return image_with_boxes
33
 
 
1
  import gradio as gr
2
+ from PIL import Image, ImageDraw, ImageFont
3
  from transformers import DetrImageProcessor, DetrForObjectDetection
4
  import torch
5
 
 
8
  processor = DetrImageProcessor.from_pretrained(model_name)
9
  model = DetrForObjectDetection.from_pretrained(model_name)
10
 
11
+ # Load default font
12
+ font = ImageFont.load_default()
13
+
14
  # Main function: takes an image and returns it with boxes and labels
15
  def detect_objects(image):
16
  inputs = processor(images=image, return_tensors="pt")
 
25
  # Draw bounding boxes and labels on a copy of the image
26
  image_with_boxes = image.copy()
27
  draw = ImageDraw.Draw(image_with_boxes)
28
+
29
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
30
  box = [round(x, 2) for x in box.tolist()]
31
  draw.rectangle(box, outline="red", width=3)
32
+
33
+ # Prepare label text
34
  label_text = f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}"
35
+
36
+ # Measure text size
37
+ text_bbox = draw.textbbox((0, 0), label_text, font=font)
38
+ text_width = text_bbox[2] - text_bbox[0]
39
+ text_height = text_bbox[3] - text_bbox[1]
40
+
41
+ # Set background rectangle for text
42
+ text_background = [
43
+ box[0], box[1] - text_height,
44
+ box[0] + text_width, box[1]
45
+ ]
46
+ draw.rectangle(text_background, fill="black") # Background
47
+ draw.text((box[0], box[1] - text_height), label_text, fill="white", font=font)
48
 
49
  return image_with_boxes
50