Spaces:
Sleeping
Sleeping
File size: 5,642 Bytes
cec5894 7608eee cec5894 614d74f cec5894 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# 1. Importing the required libraries and packages
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor, AutoModelForObjectDetection
# 2. Setup preprocessing and helper functions
# Hugging Face Hub repo id of the fine-tuned detector to load.
# NOTE: from_pretrained downloads weights on first run (network required),
# then reads from the local HF cache.
model_path = "Saint5/rt_detrv2_finetuned_trash_box_detector_v1"
# Load the image processor (resizing/normalisation + post-processing helpers)
# and the object-detection model from the same checkpoint.
image_processor = AutoImageProcessor.from_pretrained(model_path)
model = AutoModelForObjectDetection.from_pretrained(model_path)
# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# Mapping from integer class id -> class name string, taken from the
# model's own config so labels always match the checkpoint.
id2label = model.config.id2label
# Colour per class name for drawing bounding boxes; the "not_*" (negative)
# classes all share red. Fix: the original listed "not_bin" twice —
# Python silently keeps only the last duplicate key, so the extra entry
# was dead code and has been removed.
# NOTE(review): if the checkpoint's id2label contains classes not listed
# here (e.g. "not_hand"), drawing falls back to red via .get() in
# predict_on_image — confirm against model.config.id2label.
colour_dict = {
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    "not_trash": "red",
    "not_bin": "red",
}
# 3. Create function to predict on a given image with a given confidence threshold
def predict_on_image(image, conf_threshold):
    """Detect trash/bin/hand objects in an image and annotate it.

    Args:
        image: PIL.Image to run detection on. Drawn on in place.
        conf_threshold: Minimum confidence score (0-1) for keeping a
            predicted box.

    Returns:
        Tuple of (annotated PIL.Image, status message string). The message
        reports which of the target items ("trash", "bin", "hand") were
        found or are missing.
    """
    model.eval()
    # Make a prediction on the target image without tracking gradients
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        model_outputs = model(**inputs.to(device))

    # PIL's image.size is (width, height); post-processing expects
    # [batch_size, (height, width)]
    target_sizes = torch.tensor([[image.size[1], image.size[0]]])

    # Post process the raw outputs from the model
    results = image_processor.post_process_object_detection(model_outputs,
                                                            threshold=conf_threshold,
                                                            target_sizes=target_sizes)[0]

    # Move all result tensors ("boxes", "scores", "labels") to CPU for drawing.
    # Fix: the original assigned to an undefined name `result` and called
    # `.item().cpu()` (invalid — .item() returns a Python scalar); a bare
    # `except:` silently masked both errors. A plain .cpu() per value is all
    # that is needed.
    results = {key: value.cpu() for key, value in results.items()}

    # 4. Draw the predictions on the target image
    draw = ImageDraw.Draw(image)
    # Get a font from ImageFont
    font = ImageFont.load_default(size=20)
    # Collect detected class names for the summary message
    detected_class_name_text_labels = []
    # Iterate through the predictions of the model and draw them on the target image
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        # Create coordinates
        x, y, x2, y2 = tuple(box.tolist())
        # Get label name; fall back to red for any class missing from
        # colour_dict instead of raising KeyError
        label_name = id2label[label.item()]
        targ_colour = colour_dict.get(label_name, "red")
        detected_class_name_text_labels.append(label_name)
        # Draw the rectangle
        draw.rectangle(xy=(x, y, x2, y2),
                       outline=targ_colour,
                       width=3)
        # Create a text string to display
        text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
        # Draw the text string on the image
        draw.text(xy=(x, y),
                  text=text_string_to_show,
                  fill="white",
                  font=font)
    # Remove the draw each time
    del draw

    # 5. Create logic for outputting information message
    # Setup set of targets to discover
    target_items = {"trash", "bin", "hand"}
    detected_items = set(detected_class_name_text_labels)

    # If none of the target items are detected, return notification.
    # Fix: the original `not detected_items and target_items` parsed as
    # `(not detected_items) and target_items`, so an image containing only
    # non-target classes (e.g. "not_trash") skipped this branch; the
    # intended test is "no intersection with the targets".
    if not (detected_items & target_items):
        return_string = (
            f"""No trash, bin or hand detected at confidence threshold {conf_threshold}.
            Try another image or lowering the confidence threshold."""
        )
        print(return_string)
        return image, return_string

    # If there are missing items, say what the missing items are.
    # Fix: the original used `detected_items and target_items`, which
    # evaluates to target_items whenever detected_items is non-empty —
    # it listed ALL targets as detected. Set intersection `&` reports
    # only the targets that were actually found.
    missing_items = target_items - detected_items
    if missing_items:
        return_string = (
            f"""Detected the following items: {sorted(detected_items & target_items)} but missing the following in order to get 1 point: {sorted(missing_items)}.
            If this is an error, try another image or alter the confidence threshold.
            Otherwise, the model may need to be updated with better data."""
        )
        print(return_string)
        return image, return_string

    # If all target items are present
    return_string = f"+1 Point!๐ช Found the following items: {sorted(detected_items)}, thank you for cleaning up the area!๐๐ฝ๐"
    print(return_string)
    return image, return_string
# 6. Setup the demo application to take in image, make a prediction with the model, return the image with drawn predicitons
# Interface description shown at the top of the Gradio page
description = """
An object detection system that lets the user upload a picture of them holding trash in their hand and placing it in a bin. The system will be able to detect the hand, trash and the bin. If all the three items are available,
the user get 1 point!
Model used for the system is a finetuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the manually hand labelled [dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
"""

# Input widgets: the target image plus a confidence-threshold slider
demo_inputs = [
    gr.Image(type="pil", label="Target Image"),
    gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold"),
]

# Output widgets: the annotated image and the status message
demo_outputs = [
    gr.Image(type="pil", label="Image Output"),
    gr.Text(label="Text Output"),
]

# Bundled example images, all at the default 0.3 threshold
example_inputs = [
    [f"image_examples/trash_example_{i}.jpeg", 0.3] for i in (1, 2, 3)
]

# Wire the prediction function into the Gradio interface
demo = gr.Interface(
    fn=predict_on_image,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title="๐๏ธ๐ฎ Trash Object Detection Model Demo",
    description=description,
    examples=example_inputs,
    cache_examples=True,
)

# Launch the demo
demo.launch()
|