import numpy as np
import torch
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download
import gradio as gr


class ObjectDetection:
    def __init__(self, ckpt_path):
        # Preprocessing: resize to 800x600 (height x width), boost local
        # contrast with CLAHE (helpful for low-light images), then normalize
        # with per-channel dataset statistics.
        self.test_transform = A.Compose(
            [
                A.Resize(800, 600),
                A.CLAHE(clip_limit=10, p=1),
                A.Normalize(
                    [0.29278653, 0.25276296, 0.22975405],
                    [0.22653664, 0.19836408, 0.17775835],
                ),
                ToTensorV2(),
            ],
        )
        # Load the DETR ResNet-50 architecture without pretrained weights and
        # replace the classification head to match the fine-tuned checkpoint.
        self.model = torch.hub.load(
            "facebookresearch/detr", "detr_resnet50", pretrained=False
        )
        in_features = self.model.class_embed.in_features
        self.model.class_embed = nn.Linear(
            in_features=in_features,
            out_features=12,
        )
        self.labels = [
            "Dog",
            "Motorbike",
            "People",
            "Cat",
            "Chair",
            "Table",
            "Car",
            "Bicycle",
            "Bottle",
            "Bus",
            "Cup",
            "Boat",
        ]
        # Load the fine-tuned weights on CPU and switch to inference mode.
        model_ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
        self.model.load_state_dict(model_ckpt)
        self.model.eval()

    def predict(self, img, score_threshold, iou_threshold):
        img_w, img_h = img.size
        inp = self.test_transform(image=np.array(img.convert("RGB")))["image"]
        # Inference only; no gradients needed.
        with torch.no_grad():
            out = self.model(inp.unsqueeze(0))
        # Per-query class probabilities; the last logit is the "no object"
        # class and is dropped before thresholding.
        probas = out["pred_logits"].softmax(-1)[0, :, :-1]
        bboxes = []
        scores = []
        for idx, bbox in enumerate(out["pred_boxes"][0]):
            if probas[idx].max().item() < score_threshold:
                continue
            # DETR predicts normalized (center_x, center_y, width, height);
            # convert to absolute (x1, y1, x2, y2) pixel coordinates.
            x_c, y_c, w, h = bbox.numpy()
            x1 = int((x_c - w * 0.5) * img_w)
            y1 = int((y_c - h * 0.5) * img_h)
            x2 = int((x_c + w * 0.5) * img_w)
            y2 = int((y_c + h * 0.5) * img_h)
            label_idx = probas[idx].argmax().item()
            label = self.labels[label_idx] + f" {probas[idx].max().item():.2f}"
            bboxes.append(((x1, y1, x2, y2), label))
            scores.append(probas[idx].max().item())
        # Suppress overlapping detections, keeping the highest-scoring ones.
        selected_indices = self.non_max_suppression(
            bboxes,
            scores,
            iou_threshold,
        )
        bboxes = [bboxes[i] for i in selected_indices]
        return (img, bboxes)

    def non_max_suppression(self, boxes, scores, iou_threshold):
        if len(boxes) == 0:
            return []
        # Visit detections from highest to lowest score.
        sorted_indices = sorted(
            range(len(scores)), key=lambda i: scores[i], reverse=True
        )
        selected_indices = []
        while sorted_indices:
            # Keep the best remaining detection...
            current_index = sorted_indices.pop(0)
            selected_indices.append(current_index)
            # ...and discard any remaining box that overlaps it beyond the IoU
            # threshold. Each entry of `boxes` is ((x1, y1, x2, y2), label).
            ious = [
                self.calculate_iou(boxes[current_index][0], boxes[i][0])
                for i in sorted_indices
            ]
            sorted_indices = [
                i for j, i in enumerate(sorted_indices) if ious[j] <= iou_threshold
            ]
        return selected_indices

    def calculate_iou(self, box1, box2):
        """
        Calculate the Intersection over Union (IoU) of two bounding boxes.

        Args:
            box1: [x1, y1, x2, y2] for the first box.
            box2: [x1, y1, x2, y2] for the second box.

        Returns:
            IoU value.
        """
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        iou = intersection_area / (box1_area + box2_area - intersection_area)
        return iou
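
# Worked example for calculate_iou above: boxes (0, 0, 10, 10) and
# (5, 5, 15, 15) intersect in a 5x5 region, so
# IoU = 25 / (100 + 100 - 25) = 25 / 175 ≈ 0.14.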

# Download the fine-tuned checkpoint from the Hugging Face Hub and wire the
# detector into a Gradio interface.
model_path = hf_hub_download(
    repo_id="SatwikKambham/detr_low_light",
    filename="detr.pt",
)
detector = ObjectDetection(ckpt_path=model_path)
iface = gr.Interface(
    fn=detector.predict,
    inputs=[
        gr.Image(type="pil", label="Input"),
        gr.Slider(
            minimum=0,
            maximum=1,
            step=0.05,
            value=0.05,
            label="Score Threshold",
        ),
        gr.Slider(
            minimum=0,
            maximum=1,
            step=0.05,
            value=0.1,
            label="IoU Threshold",
        ),
    ],
    outputs=gr.AnnotatedImage(
        height=600,
        width=800,
    ),
)
iface.launch()
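
# A minimal sketch of calling the detector directly, without the Gradio UI.
# The image path is hypothetical:
#
#   from PIL import Image
#
#   image = Image.open("low_light_photo.jpg")
#   annotated, boxes = detector.predict(
#       image, score_threshold=0.5, iou_threshold=0.5
#   )
#   for (x1, y1, x2, y2), label in boxes:
#       print(label, (x1, y1, x2, y2))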