mrdbourke committed
Commit 0f10c71 · verified · 1 Parent(s): f057cf5

Uploading Trashify box detection model app.py

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ trashify_examples/trashify_example_1.jpeg filter=lfs diff=lfs merge=lfs -text
+ trashify_examples/trashify_example_2.jpeg filter=lfs diff=lfs merge=lfs -text
+ trashify_examples/trashify_example_3.jpeg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,34 @@
  ---
- title: Trashify Demo V4
- emoji: 👀
- colorFrom: yellow
+ title: Trashify Demo V4 🚮
+ emoji: 🗑️
+ colorFrom: purple
  colorTo: blue
  sdk: gradio
- sdk_version: 5.30.0
+ sdk_version: 5.29.0
  app_file: app.py
  pinned: false
+ license: apache-2.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # 🚮 Trashify Object Detector V4
+
+ Object detection demo to detect `trash`, `bin`, `hand`, `trash_arm`, `not_trash`, `not_bin`, `not_hand`.
+
+ Used as an example to encourage people to clean up their local area.
+
+ If `trash`, `hand` and `bin` are all detected = +1 point.
+
+ ## Dataset
+
+ All Trashify models are trained on a custom hand-labelled dataset of people picking up trash and placing it in a bin.
+
+ The dataset can be found on Hugging Face as [`mrdbourke/trashify_manual_labelled_images`](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
+
+ ## Demos
+
+ * [V1](https://huggingface.co/spaces/mrdbourke/trashify_demo_v1) = Fine-tuned [Conditional DETR](https://huggingface.co/docs/transformers/en/model_doc/conditional_detr) model trained *without* data augmentation.
+ * [V2](https://huggingface.co/spaces/mrdbourke/trashify_demo_v2) = Fine-tuned Conditional DETR model trained *with* data augmentation.
+ * [V3](https://huggingface.co/spaces/mrdbourke/trashify_demo_v3) = Fine-tuned Conditional DETR model trained *with* data augmentation (same as V2) plus an NMS (Non-Maximum Suppression) post-processing step (see the sketch after this file's diff).
+ * [V4](https://huggingface.co/spaces/mrdbourke/trashify_demo_v4) = Fine-tuned [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2) model trained *without* data augmentation or NMS post-processing (current best mAP).
+
+ TK - add links to resources to learn more
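
As context for the NMS post-processing step mentioned in the V3 bullet above, here is a minimal sketch of class-wise NMS using `torchvision.ops.nms`. It illustrates the general technique only; the `class_wise_nms` helper and its `iou_threshold=0.5` default are illustrative assumptions, not necessarily the exact code the V3 Space runs.

```python
import torch
from torchvision.ops import nms

def class_wise_nms(boxes, scores, labels, iou_threshold=0.5):
    """Keep only the highest-scoring box among overlapping same-class boxes.

    boxes: float tensor of shape (N, 4) in (x1, y1, x2, y2) format
    scores: float tensor of shape (N,)
    labels: int tensor of shape (N,)
    """
    keep_indices = []
    for label in labels.unique():
        class_indices = torch.where(labels == label)[0]
        # nms returns indices (into the class subset) of the boxes to keep
        keep = nms(boxes[class_indices], scores[class_indices], iou_threshold)
        keep_indices.append(class_indices[keep])
    keep_indices = torch.cat(keep_indices)
    return boxes[keep_indices], scores[keep_indices], labels[keep_indices]
```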
app.py ADDED
@@ -0,0 +1,141 @@
+ import gradio as gr
+ import torch
+ from PIL import Image, ImageDraw, ImageFont
+
+ from transformers import AutoImageProcessor
+ from transformers import AutoModelForObjectDetection
+
+ # Note: Can load the model from Hugging Face or from a local path
+ model_save_path = "mrdbourke/rt_detrv2_finetuned_trashify_box_detector_v1"
+
+ # Load the model and preprocessor
+ image_processor = AutoImageProcessor.from_pretrained(model_save_path)
+ model = AutoModelForObjectDetection.from_pretrained(model_save_path)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = model.to(device)
+
+ # Get the id2label dictionary from the model
+ id2label = model.config.id2label
+
+ # Set up a colour dictionary for plotting boxes with different colours
+ color_dict = {
+     "bin": "green",
+     "trash": "blue",
+     "hand": "purple",
+     "trash_arm": "yellow",
+     "not_trash": "red",
+     "not_bin": "red",
+     "not_hand": "red",
+ }
+
+ # Create helper functions for checking if items from one list are in another
+ def any_in_list(list_a, list_b):
+     "Returns True if any item from list_a is in list_b, otherwise False."
+     return any(item in list_b for item in list_a)
+
+ def all_in_list(list_a, list_b):
+     "Returns True if all items from list_a are in list_b, otherwise False."
+     return all(item in list_b for item in list_a)
+
+ def predict_on_image(image, conf_threshold):
+     with torch.no_grad():
+         inputs = image_processor(images=[image], return_tensors="pt")
+         outputs = model(**inputs.to(device))
+
+         target_sizes = torch.tensor([[image.size[1], image.size[0]]])  # height, width
+
+         results = image_processor.post_process_object_detection(outputs,
+                                                                 threshold=conf_threshold,
+                                                                 target_sizes=target_sizes)[0]
+
+     # Move all tensors in results to the CPU (they may be on the GPU)
+     for key, value in results.items():
+         results[key] = value.cpu()
+
+     # Plot the results on the PIL image (then display the image)
+     draw = ImageDraw.Draw(image)
+
+     # Get a font from ImageFont
+     font = ImageFont.load_default(size=20)
+
+     # Collect detected class names as text for the printout
+     class_name_text_labels = []
+
+     for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
+         # Create coordinates
+         x, y, x2, y2 = tuple(box.tolist())
+
+         # Get label_name
+         label_name = id2label[label.item()]
+         targ_color = color_dict[label_name]
+         class_name_text_labels.append(label_name)
+
+         # Draw the rectangle
+         draw.rectangle(xy=(x, y, x2, y2),
+                        outline=targ_color,
+                        width=3)
+
+         # Create a text string to display
+         text_string_to_show = f"{label_name} ({round(score.item(), 3)})"
+
+         # Draw the text on the image
+         draw.text(xy=(x, y),
+                   text=text_string_to_show,
+                   fill="white",
+                   font=font)
+
+     # Remove the draw each time
+     del draw
+
+     # Set up a blank string to print out
+     return_string = ""
+
+     # Set up a list of target items to discover
+     target_items = ["trash", "bin", "hand"]
+
+     # If no items are detected, or none of trash/bin/hand are detected, return a notification
+     if (len(class_name_text_labels) == 0) or not (any_in_list(list_a=target_items, list_b=class_name_text_labels)):
+         return_string = f"No trash, bin or hand detected at confidence threshold {conf_threshold}. Try another image or lower the confidence threshold."
+         return image, return_string
+
+     # If some target items are missing, print the ones which are missing
+     elif not all_in_list(list_a=target_items, list_b=class_name_text_labels):
+         missing_items = []
+         for item in target_items:
+             if item not in class_name_text_labels:
+                 missing_items.append(item)
+         return_string = f"Detected the following items: {class_name_text_labels}. But missing the following in order to get +1: {missing_items}. If this is an error, try another image or alter the confidence threshold. Otherwise, the model may need to be updated with better data."
+
+     # If all three of trash, bin, hand occur = +1
+     if all_in_list(list_a=target_items, list_b=class_name_text_labels):
+         return_string = f"+1! Found the following items: {class_name_text_labels}, thank you for cleaning up the area!"
+
+     print(return_string)
+
+     return image, return_string
+
+ # Create the interface
+ demo = gr.Interface(
+     fn=predict_on_image,
+     inputs=[
+         gr.Image(type="pil", label="Target Image"),
+         gr.Slider(minimum=0, maximum=1, value=0.3, label="Confidence Threshold")
+     ],
+     outputs=[
+         gr.Image(type="pil", label="Image Output"),
+         gr.Text(label="Text Output")
+     ],
+     title="🚮 Trashify Object Detection Demo V4",
+     description="Help clean up your local area! Upload an image and get +1 if all of the following items are detected: trash, bin, hand.",
+     # Examples come as a list of lists, where each inner list prefills the `inputs` parameters
+     examples=[
+         ["trashify_examples/trashify_example_1.jpeg", 0.3],
+         ["trashify_examples/trashify_example_2.jpeg", 0.3],
+         ["trashify_examples/trashify_example_3.jpeg", 0.3],
+     ],
+     cache_examples=True
+ )
+
+ # Launch the demo
+ demo.launch()
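
For readers who want the model outside the Gradio UI, here is a minimal standalone inference sketch derived from `app.py` above. It uses the same checkpoint and `post_process_object_detection` call; the example image path assumes the `trashify_examples/` files added in this commit.

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

checkpoint = "mrdbourke/rt_detrv2_finetuned_trashify_box_detector_v1"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForObjectDetection.from_pretrained(checkpoint).eval()

image = Image.open("trashify_examples/trashify_example_1.jpeg")

with torch.no_grad():
    inputs = image_processor(images=[image], return_tensors="pt")
    outputs = model(**inputs)

# Convert raw model outputs to boxes/scores/labels in pixel coordinates
results = image_processor.post_process_object_detection(
    outputs,
    threshold=0.3,
    target_sizes=[(image.size[1], image.size[0])],  # (height, width)
)[0]

for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(f"{model.config.id2label[label.item()]}: {score.item():.3f} at {[round(c, 1) for c in box.tolist()]}")
```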
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ timm
+ gradio
+ torch
+ transformers
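
Since the requirements above are unpinned, a quick sanity check (illustrative only) that the resolved versions import cleanly in your environment:

```python
# Print the resolved versions of the unpinned requirements above
import gradio
import timm
import torch
import transformers

for pkg in (gradio, timm, torch, transformers):
    print(f"{pkg.__name__}=={pkg.__version__}")
```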
trashify_examples/trashify_example_1.jpeg ADDED

Git LFS Details

  • SHA256: b638855adc58e84e7d88fdcdaa2a44f646af663ee4710adefdd58b071a5eeb6d
  • Pointer size: 131 Bytes
  • Size of remote file: 501 kB
trashify_examples/trashify_example_2.jpeg ADDED

Git LFS Details

  • SHA256: 89ed8acec03b7890e5d2e6fa509c7e842e70a6dd9f6ad4e37d5d1431a1081be7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
trashify_examples/trashify_example_3.jpeg ADDED

Git LFS Details

  • SHA256: bdb687b82fcc52b3b758143a29fa063356e1655a6dba8f8c9e372ac8255113aa
  • Pointer size: 131 Bytes
  • Size of remote file: 927 kB