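"""Gradio demo for a custom YOLOv3 (PyTorch Lightning) model trained on Pascal VOC, with optional Grad-CAM overlays."""
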
import torch
import cv2
import numpy as np
import gradio as gr
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pytorch_grad_cam.utils.image import show_cam_on_image

from custom_library import config
from custom_library.utils import cells_to_bboxes, non_max_suppression
from custom_library.lightning_model import YOLOv3Lightning
from custom_library.gradio_utils import draw_predictions, YoloCAM


# Build the Lightning module, load the trained weights on CPU, and switch to inference mode.
model = YOLOv3Lightning(config=config)
model.load_state_dict(torch.load("custom_yolo_model.pth", map_location=torch.device("cpu")), strict=False)
model.setup(stage="test")
model.eval()

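# Rescale the normalized anchors to grid-cell units for each of the three prediction scales (config.S).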
scaled_anchors = (
    torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(config.DEVICE)

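# Inference-time preprocessing: letterbox the image to IMAGE_SIZE x IMAGE_SIZE and scale pixels to [0, 1].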
transform = A.Compose(
    [
        A.LongestMaxSize(max_size=config.IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=config.IMAGE_SIZE, min_width=config.IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
        ToTensorV2(),
    ],
)

def model_inference(image, iou_threshold=0.5, threshold=0.4, show_cam="No", transparency=0.5, target_layer=-2):
    # Preprocess the image and add a batch dimension.
    transformed_image = transform(image=image)["image"].unsqueeze(0)
    with torch.no_grad():
        output = model(transformed_image)

    # Decode the three prediction scales into one flat list of boxes for the single image.
    bboxes = [[]]
    for i in range(3):
        _, _, S, _, _ = output[i].shape
        anchor = scaled_anchors[i]
        boxes_scale_i = cells_to_bboxes(output[i], anchor, S=S, is_preds=True)
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    nms_boxes = non_max_suppression(bboxes[0], iou_threshold=iou_threshold, threshold=threshold, box_format="midpoint")
    plot_img = draw_predictions(image.copy(), nms_boxes, class_labels=config.PASCAL_CLASSES)
    if show_cam == "No":
        return [plot_img]

    # Select the layer Grad-CAM should target and compute the heatmap.
    if target_layer == -2:
        layer = [model.model.layers[-3]]
    else:
        layer = [model.model.layers[-2]]
    cam = YoloCAM(model=model, target_layers=layer, use_cuda=False)
    grayscale_cam = cam(transformed_image, scaled_anchors)[0, :, :]

    # Overlay the heatmap on the resized, [0, 1]-scaled input image.
    img = cv2.resize(image, (config.IMAGE_SIZE, config.IMAGE_SIZE))
    img = np.float32(img) / 255
    cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True, image_weight=transparency)
    return [plot_img, cam_image]
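
# Example (hypothetical local check, assuming `image` is an RGB numpy array):
#   image = cv2.cvtColor(cv2.imread("examples/example1.jpg"), cv2.COLOR_BGR2RGB)
#   gallery = model_inference(image, show_cam="Yes")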


title = "Custom YOLOv3"
description = "Pytorch Lightning implemetation of YOLOv3 on Pascal VOC dataset.\
                        Supported classes are aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, dining table, dog, horse, motorbike, person, potted plant, sheep, sofa, train, and TV/monitor."

examples = [["examples/example1.jpg"],
            ["examples/example2.jpg"],
            ["examples/example3.jpg"],
            ["examples/example4.jpg"],
            ["examples/example5.jpg"],
            ["examples/example6.jpg"],
            ["examples/example7.jpg"],
            ["examples/example8.jpg"]]

demo = gr.Interface(
    model_inference,
    inputs=[
        gr.Image(label="Input an image"),
        gr.Slider(0, 1, value=0.5, label="IOU Threshold"),
        gr.Slider(0, 1, value=0.4, label="Confidence Threshold"),
        gr.Radio(["Yes", "No"], value="No", label="Show GradCAM outputs"),
        gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
        gr.Slider(-2, -1, value=-1, step=1, label="GradCAM target layer"),
    ],
    outputs=[gr.Gallery(label="Model Outputs", rows=2, columns=1, object_fit="contain", height="auto")],
    title=title,
    description=description,
    examples=examples,
)
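# demo.launch(share=True) would additionally create a temporary public URL.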
demo.launch()