Saint5 commited on
Commit
cec5894
ยท
verified ยท
1 Parent(s): cfb049a

Uploading the trash object detection system model app.py

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ image_examples/trash_example_2.jpeg filter=lfs diff=lfs merge=lfs -text
37
+ image_examples/trash_example_3.jpeg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,22 @@
1
  ---
2
- title: Trash Object Detection Demo
3
- emoji: ๐Ÿ†
4
- colorFrom: green
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 5.42.0
8
  app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Trash Object Detection Demo ๐Ÿšฎ
3
+ emoji: ๐Ÿ—‘๏ธ
4
+ colorFrom: purple
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 5.34.0
8
  app_file: app.py
9
+ license: mit
10
  ---
11
 
12
+ # ๐Ÿšฎ Trash Object Detector
13
+ Object detection demo to detect `trash`, `bin`, `hand`, `trash_arm`, `not_trash`, `not_bin`, `not_hand`.
14
+
15
+ Used as an example to encourage people to clean up their local area.
16
+
17
+ If `trash`, `hand`, `bin` are all detected = +1 point.
18
+
19
+ ## Dataset
20
+ The model is trained on a custom, hand-labelled dataset of people picking up trash and placing it in a bin.
21
+
22
+ The dataset is found in Hugging Face as [`mrdbourke/trashify_manual_labelled_images`](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# 1. Import the required libraries and packages
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoImageProcessor, AutoModelForObjectDetection

# 2. Setup preprocessing and helper objects
# The model repository to load from the Hugging Face Hub
model_path = "Saint5/rt_detrv2_finetuned_trash_box_detector_v1"

# Load the model and its matching image processor
image_processor = AutoImageProcessor.from_pretrained(model_path)
model = AutoModelForObjectDetection.from_pretrained(model_path)

# Move the model to GPU when available, otherwise stay on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Mapping from class index -> class name, stored in the model config
id2label = model.config.id2label

# Colour per class name for drawing bounding boxes.
# FIX: the original dict listed "not_bin" twice and omitted "not_hand",
# which would raise a KeyError whenever a "not_hand" box was predicted.
colour_dict = {
    "bin": "green",
    "trash": "blue",
    "hand": "purple",
    "trash_arm": "yellow",
    "not_trash": "red",
    "not_bin": "red",
    "not_hand": "red",
}
33
+
34
# 3. Predict on a given image with a given confidence threshold
def predict_on_image(image, conf_threshold):
    """Run object detection on ``image`` and draw the predicted boxes.

    Args:
        image: PIL.Image to run detection on (drawn on in place).
        conf_threshold: float in [0, 1]; minimum score for a box to be kept.

    Returns:
        Tuple of (annotated PIL.Image, status message string).
    """
    model.eval()

    # Make a prediction on the target image
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        model_outputs = model(**inputs.to(device))

    # Post-processing expects [batch_size, (height, width)]; PIL's
    # image.size is (width, height), hence the swap.
    target_sizes = torch.tensor([[image.size[1], image.size[0]]])

    # Post-process the raw outputs from the model
    results = image_processor.post_process_object_detection(
        model_outputs,
        threshold=conf_threshold,
        target_sizes=target_sizes,
    )[0]

    # Move every result tensor to the CPU for drawing/display.
    # FIX: the original loop had a typo (`result[key]`) and called `.cpu()`
    # on a Python scalar, hidden by a bare `except:`; all values here are
    # tensors, so a plain `.cpu()` is sufficient.
    results = {key: value.cpu() for key, value in results.items()}

    # 4. Draw the predictions on the target image
    draw = ImageDraw.Draw(image)

    # Get a font from ImageFont for the box labels
    font = ImageFont.load_default(size=20)

    # Collect detected class names for the scoring logic below
    detected_class_name_text_labels = []

    # Iterate through the predictions and draw them on the target image
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        # Box coordinates: top-left (x, y) and bottom-right (x2, y2)
        x, y, x2, y2 = tuple(box.tolist())

        # Get the label name and its colour; fall back to red for any
        # class missing from the colour map rather than raising KeyError
        label_name = id2label[label.item()]
        targ_colour = colour_dict.get(label_name, "red")
        detected_class_name_text_labels.append(label_name)

        # FIX: ImageDraw has no `Rectangle` method — the correct call is
        # the lowercase `rectangle`, otherwise this raises AttributeError.
        draw.rectangle(xy=(x, y, x2, y2),
                       outline=targ_colour,
                       width=3)

        # Label text: class name plus rounded confidence score
        text_string_to_show = f"{label_name} ({round(score.item(), 3)})"

        # Draw the text string on the image
        draw.text(xy=(x, y),
                  text=text_string_to_show,
                  fill="white",
                  font=font)

    # Release the drawing handle
    del draw

    # 5. Create logic for the output information message
    # Set of target classes that must all be present to score a point
    target_items = {"trash", "bin", "hand"}
    detected_items = set(detected_class_name_text_labels)

    # FIX: the original used boolean `and` where set intersection `&` is
    # needed — none of the target items were detected at all:
    if not detected_items & target_items:
        return_string = (
            f"""No trash, bin or hand detected at confidence threshold {conf_threshold}.
            Try another image or lowering the confidence threshold."""
        )
        print(return_string)
        return image, return_string

    # Some (but not all) target items detected — report what is missing
    missing_items = target_items - detected_items
    if missing_items:
        return_string = (
            f"""Detected the following items: {sorted(detected_items & target_items)} but missing the following in order to get 1 point: {sorted(missing_items)}.
            If this is an error, try another image or alter the confidence threshold.
            Otherwise, the model may need to be updated with better data."""
        )
        print(return_string)
        return image, return_string

    # All target items are present
    return_string = f"+1 Point!🪙 Found the following items: {sorted(detected_items)}, thank you for cleaning up the area!"
    print(return_string)
    return image, return_string
121
+
122
# 6. Setup the demo application: take in an image, make a prediction with
# the model, and return the image with the drawn predictions.
# Description shown at the top of the Gradio interface.
# FIX: grammar in the user-facing text ("the user get 1 point" etc.).
description = """
An object detection system that lets the user upload a picture of them holding trash in their hand and placing it in a bin. The system will be able to detect the hand, trash and the bin. If all three items are present,
the user gets 1 point!

Model used for the system is a finetuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the manually hand labelled [dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
"""

# Create the Gradio interface: image + confidence slider in,
# annotated image + status text out.
demo = gr.Interface(
    fn=predict_on_image,
    inputs=[
        gr.Image(type="pil", label="Target Image"),
        gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold"),
    ],
    outputs=[
        gr.Image(type="pil", label="Image Output"),
        gr.Text(label="Text Output"),
    ],
    title="🗑️🚮 Trash Object Detection Model Demo",
    description=description,
    examples=[
        ["image_examples/trash_example_1.jpeg", 0.3],
        ["image_examples/trash_example_2.jpeg", 0.3],
        ["image_examples/trash_example_3.jpeg", 0.3],
    ],
    cache_examples=True,
)

# Launch the demo
demo.launch()
image_examples/trash_example_1.jpeg ADDED
image_examples/trash_example_2.jpeg ADDED

Git LFS Details

  • SHA256: e1c170311bdc358d5158049f42aa38fba3794c91bcb2d11578f7eb92d924c55c
  • Pointer size: 131 Bytes
  • Size of remote file: 361 kB
image_examples/trash_example_3.jpeg ADDED

Git LFS Details

  • SHA256: 666068a4e4e92384bce54c5f9fa533ccef96da46df065e8760f03d49a04e3fd3
  • Pointer size: 131 Bytes
  • Size of remote file: 278 kB
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ timm
2
+ gradio
3
+ torch
4
+ transformers