mrdbourke committed
Commit 0b9c8ac · verified · 1 Parent(s): 8915efe

Uploading Trashify box detection model app.py

README.md CHANGED
@@ -33,4 +33,4 @@ The dataset can be found on Hugging Face as [`mrdbourke/trashify_manual_labelled
 
 ## Learn more
 
-See the full end-to-end code of how this demo was built at [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
+See the full end-to-end code of how this demo was built at [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
app.py CHANGED
@@ -1,17 +1,24 @@
+
+# 1. Import the required libraries and packages
 import gradio as gr
 import torch
-from PIL import Image, ImageDraw, ImageFont
+from PIL import Image, ImageDraw, ImageFont # could also use torch utilities for drawing
 
 from transformers import AutoImageProcessor
 from transformers import AutoModelForObjectDetection
 
+### 2. Setup preprocessing and helper functions ###
+
+# Setup target model path to load
 # Note: Can load from Hugging Face or can load from local
 model_save_path = "mrdbourke/rt_detrv2_finetuned_trashify_box_detector_v1"
 
 # Load the model and preprocessor
+# Because this app.py file is running directly on Hugging Face Spaces, the model will be loaded from the Hugging Face Hub
 image_processor = AutoImageProcessor.from_pretrained(model_save_path)
 model = AutoModelForObjectDetection.from_pretrained(model_save_path)
 
+# Set the target device (use CUDA/GPU if it is available)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model = model.to(device)
 
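The `# Note` comment in the hunk above mentions that the checkpoint can be loaded from a local path as well as from the Hugging Face Hub. As a hedged sketch (the `local_save_path` directory is hypothetical), `from_pretrained` accepts a local directory that was previously written out with `save_pretrained`:

    # Sketch: loading the model/preprocessor from a local directory instead of the Hub.
    # "models/rt_detrv2_finetuned_trashify_box_detector_v1" is a hypothetical path,
    # assumed to have been created earlier with save_pretrained().
    from transformers import AutoImageProcessor, AutoModelForObjectDetection

    local_save_path = "models/rt_detrv2_finetuned_trashify_box_detector_v1"

    # image_processor.save_pretrained(local_save_path)  # writes preprocessor_config.json
    # model.save_pretrained(local_save_path)            # writes config.json + weights

    image_processor = AutoImageProcessor.from_pretrained(local_save_path)
    model = AutoModelForObjectDetection.from_pretrained(local_save_path)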
@@ -31,30 +38,39 @@ color_dict = {
 
 # Create helper functions for seeing if items from one list are in another
 def any_in_list(list_a, list_b):
-    "Returns True if any item from list_a is in list_b, otherwise False."
+    "Returns True if *any* item from list_a is in list_b, otherwise False."
     return any(item in list_b for item in list_a)
 
 def all_in_list(list_a, list_b):
-    "Returns True if all items from list_a are in list_b, otherwise False."
+    "Returns True if *all* items from list_a are in list_b, otherwise False."
     return all(item in list_b for item in list_a)
 
+### 3. Create function to predict on a given image with a given confidence threshold ###
 def predict_on_image(image, conf_threshold):
+    # Make sure model is in eval mode
+    model.eval()
+
+    # Make a prediction on target image
     with torch.no_grad():
         inputs = image_processor(images=[image], return_tensors="pt")
-        outputs = model(**inputs.to(device))
-
-        target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # height, width
+        model_outputs = model(**inputs.to(device))
 
-        results = image_processor.post_process_object_detection(outputs,
+        target_sizes = torch.tensor([[image.size[1], image.size[0]]]) # -> [batch_size, height, width]
+
+        # Post process the raw outputs from the model
+        results = image_processor.post_process_object_detection(model_outputs,
                                                                 threshold=conf_threshold,
                                                                 target_sizes=target_sizes)[0]
-        # Return all items in results to CPU
+
+        # Return all items in results to CPU (we'll want this for displaying outputs with matplotlib)
         for key, value in results.items():
             try:
                 results[key] = value.cpu() # move tensors back to the CPU
             except AttributeError: # plain Python values don't have a .cpu() method
                 results[key] = value
 
+        ### 4. Draw the predictions on the target image ###
+
         # Can return results as plotted on a PIL image (then display the image)
         draw = ImageDraw.Draw(image)
 
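For context on the post-processing step in the hunk above: `post_process_object_detection` turns the model's raw outputs into a list of per-image dicts with `scores`, `labels` and `boxes` keys, keeping only detections above `threshold` and rescaling boxes to `target_sizes` in `(x_min, y_min, x_max, y_max)` pixel format. A minimal sketch of consuming that dict, assuming the `results` produced above and the standard transformers `model.config.id2label` mapping from class id to class name:

    # Sketch: reading one image's post-processed detections.
    # Assumes `results` is the dict produced in the hunk above.
    for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
        class_name = model.config.id2label[label.item()]  # integer id -> class name
        x_min, y_min, x_max, y_max = box.tolist()         # pixel coordinates
        print(f"{class_name}: {score.item():.3f} at ({x_min:.0f}, {y_min:.0f}, {x_max:.0f}, {y_max:.0f})")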
@@ -64,6 +80,7 @@ def predict_on_image(image, conf_threshold):
     # Get class names as text for print out
     class_name_text_labels = []
 
+    # Iterate through the predictions of the model and draw them on the target image
     for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
         # Create coordinates
         x, y, x2, y2 = tuple(box.tolist())
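The body of the drawing loop is truncated in this hunk. A hedged sketch of what drawing one labelled box typically looks like with `ImageDraw`, continuing from the coordinates above (the `color_dict` keyed by class name is defined earlier in app.py; the line width and label placement here are assumptions):

    # Sketch: draw one box plus a "class_name (score)" label at its top-left corner.
    # Assumes the loop variables (box, score, label, x, y, x2, y2) and `draw` from above.
    class_name = model.config.id2label[label.item()]
    color = color_dict[class_name]  # per-class colour defined earlier in app.py

    draw.rectangle(xy=(x, y, x2, y2), outline=color, width=3)
    draw.text(xy=(x, y), text=f"{class_name} ({round(score.item(), 3)})", fill="white")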
@@ -96,6 +113,8 @@ def predict_on_image(image, conf_threshold):
     # Setup list of target items to discover
     target_items = ["trash", "bin", "hand"]
 
+    ### 5. Create logic for outputting information message ###
+
     # If no items detected or trash, bin, hand not in list, return notification
     if (len(class_name_text_labels) == 0) or not (any_in_list(list_a=target_items, list_b=class_name_text_labels)):
         return_string = f"No trash, bin or hand detected at confidence threshold {conf_threshold}. Try another image or lowering the confidence threshold."
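The hunk above shows only the no-detection branch; the "+1" promised in the demo description presumably comes from an `all_in_list` check. A hypothetical reconstruction of the full branching, using the helpers defined earlier (the exact committed strings may differ):

    # Hypothetical sketch of the full message logic (not the committed wording).
    def make_return_string(class_name_text_labels, target_items, conf_threshold):
        if (len(class_name_text_labels) == 0) or not any_in_list(list_a=target_items, list_b=class_name_text_labels):
            return (f"No trash, bin or hand detected at confidence threshold {conf_threshold}. "
                    "Try another image or lowering the confidence threshold.")
        if all_in_list(list_a=target_items, list_b=class_name_text_labels):
            return f"+1! Found all of the target items: {target_items}."
        missing = [item for item in target_items if item not in class_name_text_labels]
        return f"Found some target items but missing: {missing}. No +1 yet, try another image."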
@@ -117,7 +136,20 @@ def predict_on_image(image, conf_threshold):
 
     return image, return_string
 
-# Create the interface
+### 6. Setup the demo application to take in an image, make a prediction with our model and return the image with drawn predictions ###
+
+# Write description for our demo application
+description = """
+Help clean up your local area! Upload an image and get +1 if all of the following items are detected: trash, bin, hand.
+
+Model is a fine-tuned version of [RT-DETRv2](https://huggingface.co/docs/transformers/main/en/model_doc/rt_detr_v2#transformers.RTDetrV2Config) on the [Trashify dataset](https://huggingface.co/datasets/mrdbourke/trashify_manual_labelled_images).
+
+See the full data loading and training code on [learnhuggingface.com](https://www.learnhuggingface.com/notebooks/hugging_face_object_detection_tutorial).
+
+This version is v4 because the first three versions used a different model and did not perform as well; see the [README](https://huggingface.co/spaces/mrdbourke/trashify_demo_v4/blob/main/README.md) for more.
+"""
+
+# Create the Gradio interface to accept an image and confidence threshold and return an image with drawn prediction boxes
 demo = gr.Interface(
     fn=predict_on_image,
     inputs=[
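The final hunk is truncated at the `inputs` list. A sketch of how such an interface is commonly completed, assuming a PIL image input, a slider for the confidence threshold, and the example images committed under `trashify_examples/` (all component arguments here are assumptions, not the committed values):

    # Hypothetical completion of the truncated gr.Interface(...) call.
    demo = gr.Interface(
        fn=predict_on_image,
        inputs=[
            gr.Image(type="pil", label="Target image"),  # predict_on_image expects a PIL image
            gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        ],
        outputs=[
            gr.Image(type="pil", label="Image with predicted boxes"),
            gr.Text(label="Information"),
        ],
        examples=[
            ["trashify_examples/trashify_example_1.jpeg", 0.25],
            ["trashify_examples/trashify_example_2.jpeg", 0.25],
            ["trashify_examples/trashify_example_3.jpeg", 0.25],
        ],
        title="Trashify Box Detection Demo",
        description=description,
    )

    demo.launch()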
trashify_examples/trashify_example_1.jpeg CHANGED

Git LFS Details (before)

  • SHA256: b638855adc58e84e7d88fdcdaa2a44f646af663ee4710adefdd58b071a5eeb6d
  • Pointer size: 131 Bytes
  • Size of remote file: 501 kB

Git LFS Details (after)

  • SHA256: 1151c9b8ece79d8d0e7481c3c9801a15ca4038c12a1dd1de24a472f26b0cf72f
  • Pointer size: 130 Bytes
  • Size of remote file: 82.4 kB
trashify_examples/trashify_example_2.jpeg CHANGED

Git LFS Details (before)

  • SHA256: 89ed8acec03b7890e5d2e6fa509c7e842e70a6dd9f6ad4e37d5d1431a1081be7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB

Git LFS Details (after)

  • SHA256: e1c170311bdc358d5158049f42aa38fba3794c91bcb2d11578f7eb92d924c55c
  • Pointer size: 131 Bytes
  • Size of remote file: 361 kB
trashify_examples/trashify_example_3.jpeg CHANGED

Git LFS Details (before)

  • SHA256: bdb687b82fcc52b3b758143a29fa063356e1655a6dba8f8c9e372ac8255113aa
  • Pointer size: 131 Bytes
  • Size of remote file: 927 kB

Git LFS Details (after)

  • SHA256: 666068a4e4e92384bce54c5f9fa533ccef96da46df065e8760f03d49a04e3fd3
  • Pointer size: 131 Bytes
  • Size of remote file: 278 kB