# 1. Importing the required libraries and packages
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor, AutoModelForObjectDetection

# 2. Setup preprocessing and helper functions
# The model path to load (from Hugging Face)
model_path = "Saint5/rt_detrv2_finetuned_trash_box_detector_v1"

# Loading the model and the processor
image_processor = AutoImageProcessor.from_pretrained(model_path)
model = AutoModelForObjectDetection.from_pretrained(model_path)
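# (The first run downloads the model weights and processor config from the
# Hugging Face Hub and caches them locally for subsequent runs.)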

# Set the target device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Get the id2label dictionary from the model
id2label = model.config.id2label
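# id2label maps integer class ids to class name strings, e.g. {0: "bin", 1: "trash", ...}
# (the exact mapping comes from the fine-tuned checkpoint's config, so treat this example as illustrative)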

# Setting up colour dictionary for plotting boxes with different colours
colour_dict = {
    "bin" : "green",
    "trash" : "blue",
    "hand" : "purple",
    "trash_arm" : "yellow",
    "not_trash" : "red",
    "not_bin" : "red",
    "not_bin" : "red",
}
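
# Optional safety net (an added sketch, not part of the original script): ensure every
# class the checkpoint can predict has a colour so that drawing never raises a KeyError.
for class_name in id2label.values():
  colour_dict.setdefault(class_name, "white")  # fall back to white for unmapped classes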

# 3. Create function to predict on a given image with a given confidence threshold
def predict_on_image(image, conf_threshold):
  model.eval()

  # Make a prediction on target image
  with torch.no_grad():
    inputs = image_processor(images=[image], return_tensors="pt")
    model_outputs = model(**inputs.to(device))

    target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # one (height, width) pair per image in the batch

    # Post process the raw outputs from the model
    results = image_processor.post_process_object_detection(model_outputs,
                                                            threshold=conf_threshold,
                                                            target_sizes=target_sizes)[0]
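    # `results` is a dict with tensor values under the keys "scores", "labels" and
    # "boxes" (boxes as absolute (xmin, ymin, xmax, ymax) pixel coordinates)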
  # Move all tensors in results to the CPU (so they can be used for drawing with PIL)
  results = {key: value.cpu() for key, value in results.items()}

  # 4. Draw the predictions on the target image
  draw = ImageDraw.Draw(image)

  # Get a font from ImageFont
  font = ImageFont.load_default(size=20)
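  # Note: the `size` argument to ImageFont.load_default requires Pillow 10.1+;
  # on older Pillow versions, use ImageFont.load_default() without arguments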

  # Get a class name as text for print out
  detected_class_name_text_labels = []

  # Iterate through the predictions of the model and draw them on the target image
  for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
    # Create coordinates
    x, y, x2, y2 = tuple(box.tolist())

    # Get label name
    label_name = id2label[label.item()]
    targ_colour = colour_dict[label_name]
    detected_class_name_text_labels.append(label_name)

    # Draw the rectangle
    draw.rectangle(xy=(x, y, x2, y2),
                   outline=targ_colour,
                   width=3)

    # Create a text string to display
    text_string_to_show = f"{label_name} ({round(score.item(), 3)})"

    # Draw the text string on the image
    draw.text(xy=(x, y),
              text=text_string_to_show,
              fill="white",
              font=font)

  # Remove the draw each time
  del draw

  # 5. Create logic for outputting information message
  # Setup set of targets to discover
  target_items = {"trash", "bin", "hand"}
  detected_items = set(detected_class_name_text_labels)

  # If none of the target items were detected, return a notification
  if not (detected_items & target_items):
    return_string = (
        f"No trash, bin or hand detected at confidence threshold {conf_threshold}. "
        "Try another image or lower the confidence threshold."
    )
    print(return_string)
    return image, return_string

  # If there are missing items, say what the missing items are
  missing_items = target_items - detected_items
  if missing_items:
    return_string = (
        f"Detected the following items: {sorted(detected_items & target_items)} "
        f"but missing the following in order to get 1 point: {sorted(missing_items)}. "
        "If this is an error, try another image or alter the confidence threshold. "
        "Otherwise, the model may need to be updated with better data."
    )
    print(return_string)
    return image, return_string

  # If all target items are present
  return_string = f"+1 Point!๐Ÿช™ Found the following items: {sorted(detected_items)}, thank you for cleaning up the area!๐Ÿ‘๐Ÿฝ๐Ÿ˜„"
  print(return_string)
  return image, return_string
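
# Quick local smoke test (a sketch, not part of the app flow; it assumes one of the
# bundled example images exists at this path):
# test_image, test_message = predict_on_image(Image.open("image_examples/trash_example_1.jpeg"), conf_threshold=0.3)
# print(test_message)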

# 6. Setup the demo application to take in an image, make a prediction with the model and return the image with drawn predictions
# Write a description for the gradio interface
description = """
An object detection system that lets the user upload a picture of themselves holding trash in their hand and placing it in a bin. The system detects the hand, the trash and the bin. If all three items are detected,
the user gets 1 point!

The model used for the system is a finetuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the hand-labelled [dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
"""

# Create the gradio interface
demo = gr.Interface(
    fn = predict_on_image,
    inputs = [
        gr.Image(type="pil", label="Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold")
    ],
    outputs=[
        gr.Image(type="pil", label="Image Output"),
        gr.Text(label="Text Output")
    ],
    title = "๐Ÿ—‘๏ธ๐Ÿšฎ Trash Object Detection Model Demo",
    description=description,
    examples=[
        ["image_examples/trash_example_1.jpeg", 0.3],
        ["image_examples/trash_example_2.jpeg", 0.3],
        ["image_examples/trash_example_3.jpeg", 0.3],
    ],
    cache_examples=True
)

# Launch the demo
demo.launch()
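
# Note: on Hugging Face Spaces, `demo.launch()` as above is all that's needed.
# When running locally, `demo.launch(share=True)` can additionally create a
# temporary public link for sharing the demo.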