vsaez commited on
Commit
221337b
·
verified ·
1 Parent(s): f50c0a6

Add first version of object detection

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image, ImageDraw
3
+ from transformers import DetrImageProcessor, DetrForObjectDetection
4
+ import torch
5
+
6
+ # Load DETR model and processor from Hugging Face
7
+ model_name = "facebook/detr-resnet-50"
8
+ processor = DetrImageProcessor.from_pretrained(model_name)
9
+ model = DetrForObjectDetection.from_pretrained(model_name)
10
+
11
+ # Main function: takes an image and returns it with boxes and labels
12
+ def detect_objects(image):
13
+ inputs = processor(images=image, return_tensors="pt")
14
+ outputs = model(**inputs)
15
+
16
+ # Convert model output to usable detection results
17
+ target_sizes = torch.tensor([image.size[::-1]])
18
+ results = processor.post_process_object_detection(
19
+ outputs, threshold=0.9, target_sizes=target_sizes
20
+ )[0]
21
+
22
+ # Draw bounding boxes and labels on a copy of the image
23
+ image_with_boxes = image.copy()
24
+ draw = ImageDraw.Draw(image_with_boxes)
25
+
26
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
27
+ box = [round(x, 2) for x in box.tolist()]
28
+ draw.rectangle(box, outline="red", width=3)
29
+ label_text = f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}"
30
+ draw.text((box[0], box[1]), label_text, fill="white")
31
+
32
+ return image_with_boxes
33
+
34
+ # Gradio interface
35
+ app = gr.Interface(
36
+ fn=detect_objects,
37
+ inputs=gr.Image(type="pil"),
38
+ outputs=gr.Image()
39
+ )
40
+
41
+ # Run app
42
+ if __name__ == "__main__":
43
+ app.launch()