# Hugging Face Spaces page header captured along with this source (Space status: Sleeping).
# 1. Importing the required libraries and packages | |
import gradio as gr | |
import torch | |
from PIL import Image, ImageDraw, ImageFont | |
from transformers import AutoImageProcessor, AutoModelForObjectDetection | |
# 2. Setup preprocessing and helper functions

# The model path to load (from the Hugging Face Hub)
model_path = "Saint5/rt_detrv2_finetuned_trash_box_detector_v1"

# Load the image processor and the fine-tuned detection model
image_processor = AutoImageProcessor.from_pretrained(model_path)
model = AutoModelForObjectDetection.from_pretrained(model_path)

# Run on the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Mapping from predicted class index to label name, taken from the model config
id2label = model.config.id2label

# Outline colour used when drawing boxes for each label name.
# Each label appears exactly once; duplicate keys silently overwrite each other.
colour_dict = {
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    "not_trash": "red",
    "not_bin": "red",
}
# 3. Create function to predict on a given image with a given confidence threshold

# Labels that must all be detected in a single image to earn a point.
_TARGET_ITEMS = frozenset({"trash", "bin", "hand"})

# Outline colour for any model label that has no entry in colour_dict.
_DEFAULT_BOX_COLOUR = "orange"


def _draw_predictions(image, results):
    """Draw each detection's box and "label (score)" text onto *image* in place.

    Args:
        image: PIL image to annotate; it is modified in place.
        results: Post-processed detections for one image, holding parallel
            "boxes", "scores" and "labels" tensors (already moved to CPU).

    Returns:
        The label name of every drawn detection, in drawing order.
    """
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default(size=20)
    label_names = []
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()
        label_name = id2label[label.item()]
        label_names.append(label_name)
        # .get() so a model label missing from colour_dict cannot crash the demo.
        box_colour = colour_dict.get(label_name, _DEFAULT_BOX_COLOUR)
        draw.rectangle(xy=(x, y, x2, y2), outline=box_colour, width=3)
        draw.text(
            xy=(x, y),
            text=f"{label_name} ({round(score.item(), 3)})",
            fill="white",
            font=font,
        )
    del draw
    return label_names


def _build_result_message(detected_labels, conf_threshold):
    """Return the user-facing summary for the labels detected in one image.

    Args:
        detected_labels: Iterable of detected label names (duplicates allowed).
        conf_threshold: Confidence threshold used, echoed back in the message.

    Returns:
        One of three messages: no target items found, some target items
        missing, or all target items found (+1 point).
    """
    detected_items = set(detected_labels)
    found_targets = detected_items & _TARGET_ITEMS

    # None of trash/bin/hand was detected (other labels may still be present).
    if not found_targets:
        return (
            f"No trash, bin or hand detected at confidence threshold {conf_threshold}.\n"
            "Try another image or lowering the confidence threshold."
        )

    # Some, but not all, target items were detected: report both sides.
    missing_items = _TARGET_ITEMS - detected_items
    if missing_items:
        return (
            f"Detected the following items: {sorted(found_targets)} but missing the "
            f"following in order to get 1 point: {sorted(missing_items)}.\n"
            "If this is an error, try another image or alter the confidence threshold.\n"
            "Otherwise, the model may need to be updated with better data."
        )

    # NOTE(review): the emoji in this message appear mis-encoded (mojibake) in
    # this copy of the file; confirm the intended characters against the original.
    return f"+1 Point!๐ช Found the following items: {sorted(detected_items)}, thank you for cleaning up the area!๐๐ฝ๐"


def predict_on_image(image, conf_threshold):
    """Detect objects in *image*, draw them, and summarise whether a point is earned.

    Args:
        image: PIL image supplied by the Gradio image input, or None when the
            user submits without uploading one.
        conf_threshold: Minimum detection score kept by post-processing.

    Returns:
        A ``(image, message)`` tuple: the input image with boxes drawn on it
        (or None when no image was given) and the summary text, which is also
        printed.
    """
    if image is None:
        return_string = "No image provided. Please upload an image to run detection."
        print(return_string)
        return None, return_string

    model.eval()
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        model_outputs = model(**inputs.to(device))
        # PIL's size is (width, height); post-processing expects (height, width)
        # per image, giving a tensor of shape [batch_size, 2].
        target_sizes = torch.tensor([[image.size[1], image.size[0]]])
        results = image_processor.post_process_object_detection(
            model_outputs,
            threshold=conf_threshold,
            target_sizes=target_sizes,
        )[0]

    # Move every result tensor to the CPU so it can be read for drawing.
    results = {key: value.cpu() for key, value in results.items()}

    detected_labels = _draw_predictions(image, results)
    return_string = _build_result_message(detected_labels, conf_threshold)
    print(return_string)
    return image, return_string
# 6. Setup the demo application to take in an image, make a prediction with the
#    model, and return the image with the drawn predictions

# Description shown at the top of the Gradio interface (rendered as Markdown)
description = """
An object detection system that lets the user upload a picture of them holding trash in their hand and placing it in a bin. The system will be able to detect the hand, trash and the bin. If all three items are detected,
the user gets 1 point!

Model used for the system is a finetuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the manually hand labelled [dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
"""

# Create the Gradio interface: (image, threshold) in -> (annotated image, message) out.
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Image Output"),
        gr.Text(label="Text Output"),
    ],
    # NOTE(review): the title emoji appear mis-encoded (mojibake) in this copy
    # of the file; confirm the intended characters against the original.
    title="๐๏ธ๐ฎ Trash Object Detection Model Demo",
    description=description,
    # Example files are expected under image_examples/ next to this script.
    examples=[
        ["image_examples/trash_example_1.jpeg", 0.3],
        ["image_examples/trash_example_2.jpeg", 0.3],
        ["image_examples/trash_example_3.jpeg", 0.3],
    ],
    cache_examples=True,
)

# Launch the demo
demo.launch()