# 1. Importing the required libraries and packages
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor, AutoModelForObjectDetection

# 2. Setup preprocessing and helper functions

# The model path to load (from Hugging Face)
model_path = "Saint5/rt_detrv2_finetuned_trash_box_detector_v1"

# Loading the model and the processor
image_processor = AutoImageProcessor.from_pretrained(model_path)
model = AutoModelForObjectDetection.from_pretrained(model_path)

# Set the target device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# The app only ever runs inference, so switching to eval mode once at load time is enough.
model.eval()

# Get the id2label dictionary from the model (maps integer class ids to class names)
id2label = model.config.id2label

# Setting up colour dictionary for plotting boxes with different colours.
# The previous version listed "not_bin" twice and had no "not_hand" entry; any class
# missing from this dict now falls back to DEFAULT_BOX_COLOUR instead of raising KeyError.
colour_dict = {
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    "not_trash": "red",
    "not_bin": "red",
    "not_hand": "red",
}
DEFAULT_BOX_COLOUR = "white"

# Items that must all be detected in one image to earn a point.
TARGET_ITEMS = {"trash", "bin", "hand"}

# Font for box captions, loaded once rather than on every prediction.
LABEL_FONT = ImageFont.load_default(size=20)


def _detect(image, conf_threshold):
    """Run the detector on a PIL image and return its post-processed detections on the CPU.

    Returns the single-image dict produced by ``post_process_object_detection``
    (keys "scores", "labels", "boxes"); the drawing code treats each box as
    (x_min, y_min, x_max, y_max) in pixels of ``image``.
    """
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt").to(device)
        model_outputs = model(**inputs)
        # PIL reports (width, height); post-processing expects [[height, width]] per image.
        target_sizes = torch.tensor([[image.height, image.width]])
        results = image_processor.post_process_object_detection(
            model_outputs, threshold=conf_threshold, target_sizes=target_sizes
        )[0]
    # Move tensors to the CPU so .item()/.tolist() and PIL drawing work on any device.
    return {
        key: value.cpu() if isinstance(value, torch.Tensor) else value
        for key, value in results.items()
    }


def _draw_detections(image, results):
    """Draw each detection's box and "label (score)" caption onto ``image`` in place.

    Returns the detected class names in detection order.
    """
    draw = ImageDraw.Draw(image)
    detected_labels = []
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        x, y, x2, y2 = box.tolist()
        label_name = id2label[label.item()]
        detected_labels.append(label_name)
        colour = colour_dict.get(label_name, DEFAULT_BOX_COLOUR)
        draw.rectangle(xy=(x, y, x2, y2), outline=colour, width=3)
        caption = f"{label_name} ({round(score.item(), 3)})"
        draw.text(xy=(x, y), text=caption, fill="white", font=LABEL_FONT)
    return detected_labels


def _build_message(detected_labels, conf_threshold):
    """Return the user-facing status message for the given detected class names."""
    detected_items = set(detected_labels)
    found_targets = detected_items & TARGET_ITEMS

    # None of the target items were found (other classes such as "not_trash" may still be present).
    if not found_targets:
        return (
            f"No trash, bin or hand detected at confidence threshold {conf_threshold}. "
            "Try another image or lowering the confidence threshold."
        )

    # Some, but not all, target items were found: list which ones are present and which are missing.
    missing_items = TARGET_ITEMS - detected_items
    if missing_items:
        return (
            f"Detected the following items: {sorted(found_targets)} but missing the following "
            f"in order to get 1 point: {sorted(missing_items)}. "
            "If this is an error, try another image or alter the confidence threshold. "
            "Otherwise, the model may need to be updated with better data."
        )

    # All target items are present.
    return (
        f"+1 Point!🪙 Found the following items: {sorted(detected_items)}, "
        "thank you for cleaning up the area!👏🏽😄"
    )


# 3. Create function to predict on a given image with a given confidence threshold
def predict_on_image(image, conf_threshold):
    """Detect trash/bin/hand in ``image`` and return an annotated copy plus a status message.

    Args:
        image: PIL image from the Gradio input, or None when nothing was uploaded.
        conf_threshold: minimum detection score to keep (0-1, from the slider).

    Returns:
        ``(annotated_image, message)``; ``annotated_image`` is None when no image was given.
    """
    if image is None:
        return_string = "Please upload an image to run detection."
        print(return_string)
        return None, return_string

    # Uploads may be RGBA or greyscale; the processor is assumed to want 3-channel RGB.
    # convert() also returns a copy, so the caller's image is never drawn on.
    image = image.convert("RGB")

    results = _detect(image, conf_threshold)
    # 4. Draw the predictions on the target image
    detected_labels = _draw_detections(image, results)
    # 5. Create logic for outputting information message
    return_string = _build_message(detected_labels, conf_threshold)
    print(return_string)
    return image, return_string


# 6. Setup the demo application to take in image, make a prediction with the model, return the image with drawn predictions

# Write a description for the gradio interface
description = """
An object detection system that lets the user upload a picture of them holding trash in their hand and placing it in a bin.
The system will be able to detect the hand, trash and the bin. If all the three items are available, the user get 1 point!

Model used for the system is a finetuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the manually hand labelled [dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
"""

# Create the gradio interface
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Image Output"),
        gr.Text(label="Text Output"),
    ],
    title="🗑️🚮 Trash Object Detection Model Demo",
    description=description,
    examples=[
        ["image_examples/trash_example_1.jpeg", 0.3],
        ["image_examples/trash_example_2.jpeg", 0.3],
        ["image_examples/trash_example_3.jpeg", 0.3],
    ],
    cache_examples=True,
)

# Launch the demo
demo.launch()