SatwikKambham committed
Commit abf616c · 1 Parent(s): c5253ed

Added gradio app

Files changed (2):
  1. app.py +162 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,162 @@
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+
+ from huggingface_hub import hf_hub_download
+ import gradio as gr
+
+
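+ # Wraps a DETR model fine-tuned for low-light object detection and
+ # exposes predict() as the callback for the Gradio interface below.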
+ class ObjectDetection:
+     def __init__(self, ckpt_path):
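+         # Inference-time preprocessing: resize, CLAHE to lift local
+         # contrast in dark images, then normalize with custom channel
+         # statistics (presumably computed on the training set).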
+         self.test_transform = A.Compose(
+             [
+                 A.Resize(800, 600),
+                 A.CLAHE(clip_limit=10, p=1),
+                 A.Normalize(
+                     [0.29278653, 0.25276296, 0.22975405],
+                     [0.22653664, 0.19836408, 0.17775835],
+                 ),
+                 ToTensorV2(),
+             ],
+         )
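+
+         # Build the DETR-ResNet50 architecture without pretrained COCO
+         # weights and swap in a new classification head sized for the
+         # custom label set; the fine-tuned weights are loaded below.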
+         self.model = torch.hub.load(
+             "facebookresearch/detr", "detr_resnet50", pretrained=False
+         )
+         in_features = self.model.class_embed.in_features
+         self.model.class_embed = nn.Linear(
+             in_features=in_features,
+             out_features=12,
+         )
+         self.labels = [
+             "Dog",
+             "Motorbike",
+             "People",
+             "Cat",
+             "Chair",
+             "Table",
+             "Car",
+             "Bicycle",
+             "Bottle",
+             "Bus",
+             "Cup",
+             "Boat",
+         ]
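+
+         # Load the fine-tuned checkpoint on CPU and switch the model to
+         # evaluation mode for inference.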
+         model_ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
+         self.model.load_state_dict(model_ckpt)
+         self.model.eval()
+
+     @torch.no_grad()
+     def predict(self, img, score_threshold, iou_threshold):
+         img_w, img_h = img.size
+         inp = self.test_transform(image=np.array(img.convert("RGB")))["image"]
+         out = self.model(inp.unsqueeze(0))
+         # Per-query class probabilities, excluding the trailing
+         # no-object logit.
+         probas = out["pred_logits"].softmax(-1)[0, :, :-1]
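+         # Keep queries whose best class score clears the threshold and
+         # convert DETR's normalized (cx, cy, w, h) boxes to absolute
+         # (x1, y1, x2, y2) pixel corners.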
+         bboxes = []
+         scores = []
+         for idx, bbox in enumerate(out["pred_boxes"][0]):
+             if probas[idx].max().item() < score_threshold:
+                 continue
+             x_c, y_c, w, h = bbox.detach().numpy()
+             x1 = int((x_c - w * 0.5) * img_w)
+             y1 = int((y_c - h * 0.5) * img_h)
+             x2 = int((x_c + w * 0.5) * img_w)
+             y2 = int((y_c + h * 0.5) * img_h)
+             label_idx = probas[idx].argmax().item()
+             label = self.labels[label_idx] + f" {probas[idx].max().item():.2f}"
+             bboxes.append(((x1, y1, x2, y2), label))
+             scores.append(probas[idx].max().item())
+         selected_indices = self.non_max_suppression(
+             bboxes,
+             scores,
+             iou_threshold,
+         )
+         bboxes = [bboxes[i] for i in selected_indices]
+         return (img, bboxes)
+
+     def non_max_suppression(self, boxes, scores, iou_threshold):
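+         # Greedy NMS: repeatedly keep the highest-scoring box and drop
+         # any remaining box whose IoU with it exceeds iou_threshold.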
+         if len(boxes) == 0:
+             return []
+
+         sorted_indices = sorted(
+             range(len(scores)), key=lambda i: scores[i], reverse=True
+         )
+         selected_indices = []
+
+         while sorted_indices:
+             current_index = sorted_indices.pop(0)
+             selected_indices.append(current_index)
+
+             ious = [
+                 self.calculate_iou(boxes[current_index][0], boxes[i][0])
+                 for i in sorted_indices
+             ]
+
+             sorted_indices = [
+                 i for i, iou in zip(sorted_indices, ious) if iou <= iou_threshold
+             ]
+
+         return selected_indices
+
+     def calculate_iou(self, box1, box2):
+         """
+         Calculate the Intersection over Union (IoU) of two bounding boxes.
+
+         Args:
+             box1: [x1, y1, x2, y2] for the first box.
+             box2: [x1, y1, x2, y2] for the second box.
+
+         Returns:
+             IoU value.
+         """
+         x1 = max(box1[0], box2[0])
+         y1 = max(box1[1], box2[1])
+         x2 = min(box1[2], box2[2])
+         y2 = min(box1[3], box2[3])
+
+         intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
+         box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+         box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
+
+         iou = intersection_area / (box1_area + box2_area - intersection_area)
+
+         return iou
+
+
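+ # Download the fine-tuned checkpoint from the Hub, build the detector, and
+ # expose it through a Gradio interface with score and IoU threshold sliders.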
+ model_path = hf_hub_download(
+     repo_id="SatwikKambham/detr_low_light",
+     filename="detr.pt",
+ )
+ detector = ObjectDetection(ckpt_path=model_path)
+ iface = gr.Interface(
+     fn=detector.predict,
+     inputs=[
+         gr.Image(type="pil", label="Input"),
+         gr.Slider(
+             minimum=0,
+             maximum=1,
+             step=0.05,
+             value=0.05,
+             label="Score Threshold",
+         ),
+         gr.Slider(
+             minimum=0,
+             maximum=1,
+             step=0.05,
+             value=0.1,
+             label="IoU Threshold",
+         ),
+     ],
+     outputs=gr.AnnotatedImage(
+         height=600,
+         width=800,
+     ),
+ )
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ numpy
+ torch
+ albumentations
+ huggingface_hub
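+ # Note: gradio is not pinned here; it is assumed to be supplied by the
+ # hosting environment (e.g., the Hugging Face Spaces runtime).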