# Author: Saint5
# Commit message: Update app.py
# Commit: 614d74f (verified)
# 1. Importing the required libraries and packages
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor, AutoModelForObjectDetection
# 2. Setup preprocessing and helper functions
# Hugging Face Hub repository holding both the fine-tuned detector and its processor config
model_path = "Saint5/rt_detrv2_finetuned_trash_box_detector_v1"

# Preprocessing pipeline loaded from the same checkpoint as the model
image_processor = AutoImageProcessor.from_pretrained(model_path)

# Prefer the GPU when torch can see one; otherwise run on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the detector and place it on the chosen device in one step
model = AutoModelForObjectDetection.from_pretrained(model_path).to(device)

# Class-index -> class-name mapping stored in the model's config
id2label = model.config.id2label
# Setting up colour dictionary for plotting boxes with different colours.
# Keys are class names as produced by id2label; values are PIL colour names.
colour_dict = {
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    # Negative classes are all drawn in red.
    "not_trash": "red",
    "not_bin": "red",
    # Replaces a duplicated "not_bin" key, which silently collapsed into one entry.
    # Assumed to be one of the model's class names - confirm against model.config.id2label.
    "not_hand": "red",
}
# 3. Create function to predict on a given image with a given confidence threshold
def predict_on_image(image, conf_threshold):
model.eval()
# Make a prediction on target image
with torch.no_grad():
inputs = image_processor(images=[image], return_tensors="pt")
model_outputs = model(**inputs.to(device))
target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # -> [batch_size, height, width]
# Post process the raw outputs from the model
results = image_processor.post_process_object_detection(model_outputs,
threshold=conf_threshold,
target_sizes=target_sizes)[0]
# Return all items in results to CPU (for display with matplotlib)
for key, value in results.items():
try:
result[key] = value.item().cpu()
except:
results[key] = value.cpu()
# 4. Draw the predictions on the target image
draw = ImageDraw.Draw(image)
# Get a font from ImageFont
font = ImageFont.load_default(size=20)
# Get a class name as text for print out
detected_class_name_text_labels = []
# Iterate through the predictions of the model and draw them on the target image
for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
# Create coordinates
x, y, x2, y2 = tuple(box.tolist())
# Get label name
label_name = id2label[label.item()]
targ_colour = colour_dict[label_name]
detected_class_name_text_labels.append(label_name)
# Draw the rectangle
draw.rectangle(xy=(x, y, x2, y2),
outline=targ_colour,
width=3)
# Create a text string to display
text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
# Draw the text string on the image
draw.text(xy=(x, y),
text=text_string_to_show,
fill="white",
font=font)
# Remove the draw each time
del draw
# 5. Create logic for outputting information message
# Setup set of targets to discover
target_items = {"trash", "bin", "hand"}
detected_items = set(detected_class_name_text_labels)
# If the items are not detected, return notification
if not detected_items and target_items:
return_string = (
f"""No trash, bin or hand detected at confidence threshold {conf_threshold}.
Try another image or lowering the confidence threshold."""
)
print(return_string)
return image, return_string
# If there are missing items, say what the missing items are
missing_items = target_items - detected_items
if missing_items:
return_string = (
f"""Detected the following items: {sorted(detected_items and target_items)} but missing the following in order to get 1 point: {sorted(missing_items)}.
If this is an error, try another image or alter the confidence threshold.
Otherwise, the model may need to be updated with better data."""
)
print(return_string)
return image, return_string
# If all target items are present
return_string = f"+1 Point!๐Ÿช™ Found the following items: {sorted(detected_items)}, thank you for cleaning up the area!๐Ÿ‘๐Ÿฝ๐Ÿ˜„"
print(return_string)
return image, return_string
# 6. Setup the demo application: take in an image, make a prediction with the model,
# and return the image with the predictions drawn on it.
# Description text shown in the Gradio interface (uses Markdown link syntax).
description = """
An object detection system that lets users upload a picture of themselves holding trash in their hand and placing it in a bin. The system detects the hand, the trash and the bin. If all three items are detected,
the user gets 1 point!
The model used is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config), trained on a manually hand-labelled [dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
"""
# Create the gradio interface: an image plus a confidence-threshold slider go in,
# the annotated image plus a text summary come out (matching predict_on_image's return).
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Image Output"),
        gr.Text(label="Text Output"),
    ],
    title="🗑️🚮 Trash Object Detection Model Demo",
    description=description,
    # Each example is [image path, confidence threshold], mirroring the two inputs above.
    examples=[
        ["image_examples/trash_example_1.jpeg", 0.3],
        ["image_examples/trash_example_2.jpeg", 0.3],
        ["image_examples/trash_example_3.jpeg", 0.3],
    ],
    # Ask Gradio to cache the example outputs rather than recompute them on every click.
    cache_examples=True,
)
# Launch the demo
demo.launch()