import numpy as np
import torch
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download
import gradio as gr


class ObjectDetection:
    def __init__(self, ckpt_path):
        # Preprocessing: resize, CLAHE for low-light contrast enhancement,
        # then normalize with dataset-specific mean/std statistics.
        self.test_transform = A.Compose(
            [
                A.Resize(800, 600),
                A.CLAHE(clip_limit=10, p=1),
                A.Normalize(
                    [0.29278653, 0.25276296, 0.22975405],
                    [0.22653664, 0.19836408, 0.17775835],
                ),
                ToTensorV2(),
            ],
        )
        # DETR with a ResNet-50 backbone; the classification head is
        # replaced to match the fine-tuned checkpoint. The last logit is
        # treated as the "no object" class when scoring (see predict()).
        self.model = torch.hub.load(
            "facebookresearch/detr", "detr_resnet50", pretrained=False
        )
        in_features = self.model.class_embed.in_features
        self.model.class_embed = nn.Linear(
            in_features=in_features,
            out_features=12,
        )
        self.labels = [
            "Dog",
            "Motorbike",
            "People",
            "Cat",
            "Chair",
            "Table",
            "Car",
            "Bicycle",
            "Bottle",
            "Bus",
            "Cup",
            "Boat",
        ]
        model_ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
        self.model.load_state_dict(model_ckpt)
        self.model.eval()

    def predict(self, img):
        score_threshold, iou_threshold = 0.05, 0.1
        img_w, img_h = img.size
        inp = self.test_transform(image=np.array(img.convert("RGB")))["image"]
        with torch.no_grad():
            out = self.model(inp.unsqueeze(0))
        # Drop the last logit ("no object") before taking class probabilities.
        probas = out["pred_logits"].softmax(-1)[0, :, :-1]
        bboxes = []
        scores = []
        for idx, bbox in enumerate(out["pred_boxes"][0]):
            if probas[idx].max().item() < score_threshold:
                continue
            # DETR boxes are normalized (center x, center y, width, height);
            # convert to absolute (x1, y1, x2, y2) pixel coordinates.
            x_c, y_c, w, h = bbox.detach().numpy()
            x1 = int((x_c - w * 0.5) * img_w)
            y1 = int((y_c - h * 0.5) * img_h)
            x2 = int((x_c + w * 0.5) * img_w)
            y2 = int((y_c + h * 0.5) * img_h)
            label_idx = probas[idx].argmax().item()
            label = self.labels[label_idx] + f" {probas[idx].max().item():.2f}"
            bboxes.append(((x1, y1, x2, y2), label))
            scores.append(probas[idx].max().item())
        selected_indices = self.non_max_suppression(
            bboxes,
            scores,
            iou_threshold,
        )
        bboxes = [bboxes[i] for i in selected_indices]
        return (img, bboxes)

    def non_max_suppression(self, boxes, scores, iou_threshold):
        """Greedy class-agnostic NMS; returns indices of the boxes to keep."""
        if len(boxes) == 0:
            return []
        # Visit boxes in descending order of confidence.
        sorted_indices = sorted(
            range(len(scores)), key=lambda i: scores[i], reverse=True
        )
        selected_indices = []
        while sorted_indices:
            current_index = sorted_indices.pop(0)
            selected_indices.append(current_index)
            # Discard remaining boxes that overlap the kept box too much.
            ious = [
                self.calculate_iou(boxes[current_index][0], boxes[i][0])
                for i in sorted_indices
            ]
            sorted_indices = [
                i for i, iou in zip(sorted_indices, ious) if iou <= iou_threshold
            ]
        return selected_indices

    def calculate_iou(self, box1, box2):
        """
        Calculate the Intersection over Union (IoU) of two bounding boxes.

        Args:
            box1: [x1, y1, x2, y2] for the first box.
            box2: [x1, y1, x2, y2] for the second box.

        Returns:
            IoU value.
        """
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        iou = intersection_area / (box1_area + box2_area - intersection_area)
        return iou


# Download the fine-tuned checkpoint from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="SatwikKambham/detr_low_light",
    filename="detr.pt",
)
detector = ObjectDetection(ckpt_path=model_path)

# Gradio UI: a PIL image in, an annotated image (boxes + labels) out.
iface = gr.Interface(
    fn=detector.predict,
    inputs=[
        gr.Image(
            type="pil",
            label="Input",
            height=400,
        ),
    ],
    outputs=gr.AnnotatedImage(
        height=400,
    ),
    examples="Examples",  # directory of example images
)
iface.launch()
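

# A minimal sketch (not wired into the app above) of how the hand-rolled
# non_max_suppression could be swapped for torchvision's built-in NMS.
# It assumes torchvision is installed alongside torch; the helper name
# `nms_with_torchvision` is hypothetical, and the (box, label) tuple
# layout is the one built in predict().
import torchvision


def nms_with_torchvision(bboxes, scores, iou_threshold):
    """Return the kept-box indices from torchvision.ops.nms.

    bboxes: list of ((x1, y1, x2, y2), label) tuples, as in predict().
    scores: list of confidence scores, one per box.
    """
    if len(bboxes) == 0:
        return []
    boxes_t = torch.tensor([box for box, _ in bboxes], dtype=torch.float32)
    scores_t = torch.tensor(scores, dtype=torch.float32)
    # torchvision.ops.nms returns indices sorted by decreasing score.
    keep = torchvision.ops.nms(boxes_t, scores_t, iou_threshold)
    return keep.tolist()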