Sijuade committed
Commit a00793d · 1 Parent(s): 4532c8c

Upload 6 files

Files changed (6)
  1. app.py +46 -0
  2. config.py +113 -0
  3. dataset.py +195 -0
  4. gradio_utils.py +211 -0
  5. requirements.txt +11 -0
  6. utils.py +525 -0
app.py ADDED
@@ -0,0 +1,46 @@
+ import config
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ import torch, torchvision
+ from torchvision import transforms
+ from gradio_utils import (
+     generate_html,
+     get_examples,
+     upload_image_inference
+ )
+
+
+ show_label = True
+ examples = get_examples()
+ iou_thresh, thresh = 0.5, 0.6
+
+ with gr.Blocks() as gradcam:
+     gr.HTML(value=generate_html, show_label=show_label)
+
+     with gr.Row():
+         upload_input = [gr.Image(shape=(config.INFERENCE_IMAGE_SIZE,
+                                         config.INFERENCE_IMAGE_SIZE)),
+                         gr.Slider(0, 1, label='Transparency', value=0.6)]
+
+     with gr.Row():
+         upload_output = [
+             gr.AnnotatedImage(label='BBox Prediction',
+                               height=config.INFERENCE_IMAGE_SIZE,
+                               width=config.INFERENCE_IMAGE_SIZE),
+             gr.Gallery(label="Grad-CAM Output",
+                        show_label=True, min_width=120)]
+
+
+     with gr.Row():
+         inference_button = gr.Button("Perform Inference")
+         inference_button.click(upload_image_inference,
+                                inputs=upload_input,
+                                outputs=upload_output)
+
+     with gr.Row():
+         gr.Examples(examples=examples, inputs=upload_input, outputs=upload_output, fn=upload_image_inference, cache_examples=True,)
+
+
+
+ gradcam.launch(debug=True)
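
The "Perform Inference" button wires `upload_image_inference` to the image and transparency slider. As a quick sanity check outside the Blocks UI, a sketch like the one below calls the same handler directly; it assumes the bundled example .jpg images, the `lightning_utils` module, and the `YoLoV3Model.ckpt` checkpoint that `gradio_utils.py` expects are all present in the working directory (none of them ship in this commit).

import numpy as np
from PIL import Image
from gradio_utils import get_examples, upload_image_inference  # loads the model at import time

image_path, transparency = get_examples()[0]                    # first bundled example image
# Resize to the 416 px square that the gr.Image component enforces in the UI.
img = np.array(Image.open(image_path).convert("RGB").resize((416, 416)))
(annotated, annotations), (cam_map, cam_overlay) = upload_image_inference(img, transparency)
print(len(annotations), cam_map.shape)                          # number of boxes, (416, 416)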
config.py ADDED
@@ -0,0 +1,113 @@
+ import albumentations as A
+ import cv2
+ import torch
+
+ from albumentations.pytorch import ToTensorV2
+ from utils import seed_everything
+
+ DATASET = 'PASCAL_VOC'
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ # seed_everything()  # If you want deterministic behavior
+ NUM_WORKERS = 0
+ BATCH_SIZE = 2
+ DIV = 32
+ IMAGE_SIZES = [416, 416, 416, 608, 608]
+ S = [[x//DIV, x//DIV*2, x//DIV*4] for x in IMAGE_SIZES]
+ NUM_CLASSES = 20
+ LEARNING_RATE = 1e-5
+ WEIGHT_DECAY = 1e-4
+ NUM_EPOCHS = 10
+ CONF_THRESHOLD = 0.05
+ MAP_IOU_THRESH = 0.5
+ NMS_IOU_THRESH = 0.45
+ PIN_MEMORY = True
+ LOAD_MODEL = False
+ SAVE_MODEL = True
+ CHECKPOINT_FILE = "checkpoint.pth.tar"
+ IMG_DIR = DATASET + "/images/"
+ LABEL_DIR = DATASET + "/labels/"
+ MOSAIC_PROB = 0.75
+ INFERENCE_IMAGE_SIZE = 416
+
+ ANCHORS = [
+     [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
+     [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
+     [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
+ ]  # Note these have been rescaled to be between [0, 1]
+
+ means = [0.45484068, 0.43406072, 0.40103856]
+ stds = [0.23936155, 0.23471538, 0.23876129]
+
+ scale = 1.1
+ def train_transform(IMAGE_SIZE):
+     train_transforms = A.Compose(
+         [
+             A.LongestMaxSize(max_size=int(IMAGE_SIZE * scale)),
+             A.PadIfNeeded(
+                 min_height=int(IMAGE_SIZE * scale),
+                 min_width=int(IMAGE_SIZE * scale),
+                 border_mode=cv2.BORDER_CONSTANT,
+             ),
+             A.Rotate(limit=10, interpolation=1, border_mode=4),
+             A.RandomCrop(width=IMAGE_SIZE, height=IMAGE_SIZE),
+             A.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
+             A.OneOf(
+                 [
+                     A.ShiftScaleRotate(
+                         rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
+                     ),
+                     # A.Affine(shear=15, p=0.5, mode="constant"),
+                 ],
+                 p=1.0,
+             ),
+             A.HorizontalFlip(p=0.5),
+             A.Blur(p=0.1),
+             A.CLAHE(p=0.1),
+             A.Posterize(p=0.1),
+             A.ToGray(p=0.1),
+             A.ChannelShuffle(p=0.05),
+             A.Normalize(mean=means, std=stds, max_pixel_value=255,),
+             ToTensorV2(),
+         ],
+         bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[],),
+     )
+
+     return train_transforms
+
+ def test_transform(IMAGE_SIZE=416):
+     test_transforms = A.Compose(
+         [
+             A.LongestMaxSize(max_size=IMAGE_SIZE),
+             A.PadIfNeeded(
+                 min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
+             ),
+             A.Normalize(mean=means, std=stds, max_pixel_value=255,),
+             ToTensorV2(),
+         ],
+         bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
+     )
+     return test_transforms
+
+
+ PASCAL_CLASSES = [
+     "aeroplane",
+     "bicycle",
+     "bird",
+     "boat",
+     "bottle",
+     "bus",
+     "car",
+     "cat",
+     "chair",
+     "cow",
+     "diningtable",
+     "dog",
+     "horse",
+     "motorbike",
+     "person",
+     "pottedplant",
+     "sheep",
+     "sofa",
+     "train",
+     "tvmonitor"
+ ]
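
With DIV = 32, the S entry for a 416 px image works out to [13, 26, 52], and the normalized ANCHORS only become cell units once they are multiplied by the grid sizes (as gradio_utils.py and the decoding code do). A small sketch, assuming config.py imports cleanly, checks the arithmetic:

import torch
import config

print(config.S[0])   # [13, 26, 52]  (416//32, 416//32*2, 416//32*4)

# Each anchor set scaled onto its own grid (13, 26, 52 cells respectively):
scaled = torch.tensor(config.ANCHORS) * torch.tensor(config.S[0]).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
print(scaled.shape)  # torch.Size([3, 3, 2])
print(scaled[0, 0])  # tensor([3.6400, 2.8600]) -- i.e. (0.28, 0.22) expressed in 13x13-grid cells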
dataset.py ADDED
@@ -0,0 +1,195 @@
+ """
+ Creates a PyTorch dataset to load the Pascal VOC & MS COCO datasets
+ """
+ #468 520
+ import config
+ import numpy as np
+ import os
+ import pandas as pd
+ import torch
+ from utils import xywhn2xyxy, xyxy2xywhn
+ import random
+
+ from PIL import Image, ImageFile
+ from torch.utils.data import Dataset, DataLoader
+ from utils import (
+     cells_to_bboxes,
+     iou_width_height as iou,
+     non_max_suppression as nms,
+     plot_image
+ )
+
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+ class YOLODataset(Dataset):
+     def __init__(
+         self,
+         csv_file,
+         img_dir,
+         label_dir,
+         anchors,
+         C=20,
+         transform=None,
+         train=True
+     ):
+         self.annotations = pd.read_csv(csv_file)
+         self.img_dir = img_dir
+         self.label_dir = label_dir
+         self.image_size = 416
+         self.transform = transform
+         self.S = [13, 26, 52]
+         self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2])  # for all 3 scales
+         self.num_anchors = self.anchors.shape[0]
+         self.num_anchors_per_scale = self.num_anchors // 3
+         self.C = C
+         self.ignore_iou_thresh = 0.5
+         self.mosaic_border = [self.image_size//2, self.image_size//2]
+         self.train_data = train
+
+     def __len__(self):
+         return len(self.annotations)
+
+     def set_image_size(self, size_idx):
+         self.image_size = config.IMAGE_SIZES[size_idx]
+         self.S = config.S[size_idx]
+         self.mosaic_border = [self.image_size // 2, self.image_size // 2]
+
+
+     def load_mosaic(self, image_size, index):
+         # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
+         labels4 = []
+         s = image_size
+         yc, xc = (int(random.uniform(x, 2 * s - x)) for x in self.mosaic_border)  # mosaic center x, y
+         indices = [index] + random.choices(range(len(self)), k=3)  # 3 additional image indices
+         random.shuffle(indices)
+         for i, index in enumerate(indices):
+             # Load image
+             label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
+             bboxes = np.roll(np.loadtxt(fname=label_path, delimiter=" ", ndmin=2), 4, axis=1).tolist()
+             img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
+             img = np.array(Image.open(img_path).convert("RGB"))
+
+
+             h, w = img.shape[0], img.shape[1]
+             labels = np.array(bboxes)
+
+             # place img in img4
+             if i == 0:  # top left
+                 img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
+                 x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
+                 x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
+             elif i == 1:  # top right
+                 x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
+                 x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
+             elif i == 2:  # bottom left
+                 x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
+                 x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
+             elif i == 3:  # bottom right
+                 x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
+                 x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
+
+             img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
+             padw = x1a - x1b
+             padh = y1a - y1b
+
+             # Labels
+             if labels.size:
+                 labels[:, :-1] = xywhn2xyxy(labels[:, :-1], w, h, padw, padh)  # normalized xywh to pixel xyxy format
+             labels4.append(labels)
+
+         # Concat/clip labels
+         labels4 = np.concatenate(labels4, 0)
+         for x in (labels4[:, :-1],):
+             np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
+         # img4, labels4 = replicate(img4, labels4)  # replicate
+         labels4[:, :-1] = xyxy2xywhn(labels4[:, :-1], 2 * s, 2 * s)
+         labels4[:, :-1] = np.clip(labels4[:, :-1], 0, 1)
+         labels4 = labels4[labels4[:, 2] > 0]
+         labels4 = labels4[labels4[:, 3] > 0]
+         return img4, labels4
+
+     def __getitem__(self, index):
+
+         if self.train_data and np.random.random() <= config.MOSAIC_PROB:
+             image, bboxes = self.load_mosaic(self.image_size, index)
+         else:
+             label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
+             bboxes = np.roll(np.loadtxt(fname=label_path, delimiter=" ", ndmin=2), 4, axis=1).tolist()
+             img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
+             image = np.array(Image.open(img_path).convert("RGB"))
+
+         if self.transform:
+             transforms = self.transform(self.image_size) if self.train_data else self.transform()
+             augmentations = transforms(image=image, bboxes=bboxes)
+             image = augmentations["image"]
+             bboxes = augmentations["bboxes"]
+
+         # Below assumes 3 scale predictions (as paper) and same num of anchors per scale
+         targets = [torch.zeros((self.num_anchors // 3, S, S, 6)) for S in self.S]
+         for box in bboxes:
+             iou_anchors = iou(torch.tensor(box[2:4]), self.anchors)
+             anchor_indices = iou_anchors.argsort(descending=True, dim=0)
+             x, y, width, height, class_label = box
+             has_anchor = [False] * 3  # each scale should have one anchor
+             for anchor_idx in anchor_indices:
+                 scale_idx = anchor_idx // self.num_anchors_per_scale
+                 anchor_on_scale = anchor_idx % self.num_anchors_per_scale
+                 S = self.S[scale_idx]
+                 i, j = int(S * y), int(S * x)  # which cell
+                 anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]
+                 if not anchor_taken and not has_anchor[scale_idx]:
+                     targets[scale_idx][anchor_on_scale, i, j, 0] = 1
+                     x_cell, y_cell = S * x - j, S * y - i  # both between [0,1]
+                     width_cell, height_cell = (
+                         width * S,
+                         height * S,
+                     )  # can be greater than 1 since it's relative to cell
+                     box_coordinates = torch.tensor(
+                         [x_cell, y_cell, width_cell, height_cell]
+                     )
+                     targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
+                     targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
+                     has_anchor[scale_idx] = True
+
+                 elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
+                     targets[scale_idx][anchor_on_scale, i, j, 0] = -1  # ignore prediction
+
+         return image, tuple(targets)
+
+
+ def test():
+     anchors = config.ANCHORS
+
+     transform = config.test_transform
+
+     dataset = YOLODataset(
+         "COCO/train.csv",
+         "COCO/images/images/",
+         "COCO/labels/labels_new/",
+         anchors=anchors,
+         transform=transform,
+     )
+     S = [13, 26, 52]
+     scaled_anchors = torch.tensor(anchors) / (
+         1 / torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
+     )
+     loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
+     for x, y in loader:
+         boxes = []
+
+         for i in range(y[0].shape[1]):
+             anchor = scaled_anchors[i]
+             print(anchor.shape)
+             print(y[i].shape)
+             boxes += cells_to_bboxes(
+                 y[i], is_preds=False, S=y[i].shape[2], anchors=anchor
+             )[0]
+         boxes = nms(boxes, iou_threshold=1, threshold=0.7, box_format="midpoint")
+         print(boxes)
+         plot_image(x[0].permute(1, 2, 0).to("cpu"), boxes)
+
+
+ if __name__ == "__main__":
+     test()
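
For reference, `__getitem__` hands back the transformed image together with one target tensor per scale. The sketch below is illustrative only: the CSV path is a placeholder for an annotation file listing (image filename, label filename) pairs, which this commit does not include.

import config
from dataset import YOLODataset

dataset = YOLODataset(
    "PASCAL_VOC/train.csv",        # hypothetical CSV of (image file, label file) pairs
    config.IMG_DIR,
    config.LABEL_DIR,
    anchors=config.ANCHORS,
    transform=config.train_transform,
)
image, targets = dataset[0]
# One target per scale, shaped (anchors_per_scale, S, S, 6) with
# [objectness, x_cell, y_cell, w_cell, h_cell, class] on the last axis.
for t in targets:
    print(t.shape)   # [3, 13, 13, 6], [3, 26, 26, 6], [3, 52, 52, 6]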
gradio_utils.py ADDED
@@ -0,0 +1,211 @@
+ import random
+ import torch
+ from albumentations.pytorch import ToTensorV2
+ import albumentations as A
+ import cv2
+ import glob2
+ import config
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ from lightning_utils import YOLOv3Lightning
+ from pytorch_grad_cam import GradCAM, EigenCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+ from pytorch_grad_cam.utils.model_targets import FasterRCNNBoxScoreTarget
+
+ from utils import cells_to_bboxes, non_max_suppression
+
+
+ cmap = plt.get_cmap("tab20b")
+ class_labels = config.PASCAL_CLASSES
+ height, width = config.INFERENCE_IMAGE_SIZE, config.INFERENCE_IMAGE_SIZE
+ colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
+
+ icons = [
+     'flight', 'pedal_bike', 'flutter_dash', 'sailing',
+     'liquor', 'directions_bus', 'directions_car',
+     'pets', "chair", 'pets', 'table_restaurant',
+     'pets', 'bedroom_baby', 'motorcycle', 'person', 'yard',
+     'kebab_dining', 'chair', "train", "tvmonitor"]
+
+ icons_mapping = {config.PASCAL_CLASSES[i]: icons[i] for i in range(len(icons))}
+
+ model = YOLOv3Lightning()
+ model = model.load_from_checkpoint('YoLoV3Model.ckpt',
+                                    map_location=torch.device('cpu'))
+ model.eval()
+
+ scaled_anchors = (
+     torch.tensor(config.ANCHORS)
+     * torch.tensor(config.S[0]).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
+ ).to(config.DEVICE)
+
+ def get_examples():
+     example_images = glob2.glob('*.jpg')
+     example_transparency = [random.choice([0.7, 0.8]) for r in range(len(example_images))]
+     examples = [[example_images[i], example_transparency[i]] for i in range(len(example_images))]
+     return examples
+
+
+
+ def yolov3_reshape_transform(x):
+     activations = []
+     size = x[0].size()[2:4]
+
+     for x_item in x:
+         x_permute = x_item.permute(0, 1, 4, 2, 3)
+         x_permute = x_permute.reshape((x_permute.shape[0],
+                                        x_permute.shape[1]*x_permute.shape[2],
+                                        *x_permute.shape[3:]))
+         activations.append(torch.nn.functional.interpolate(torch.abs(x_permute), size, mode='bilinear'))
+
+     activations = torch.cat(activations, axis=1)
+
+     return activations
+
+
+ def infer_transform(IMAGE_SIZE=config.INFERENCE_IMAGE_SIZE):
+     transforms = A.Compose(
+         [
+             A.LongestMaxSize(max_size=IMAGE_SIZE),
+             A.PadIfNeeded(
+                 min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
+             ),
+             A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
+             ToTensorV2(),
+         ]
+     )
+     return transforms
+
+ def generate_html():
+     # Start the HTML string with some style and the Material Icons stylesheet
+     classes = config.PASCAL_CLASSES
+     html_string = """
+     <link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet">
+     <style>
+         .title {
+             font-size: 24px;
+             font-weight: bold;
+             text-align: center;
+             margin-bottom: 20px;
+             color: #4a4a4a;
+         }
+         .subtitle {
+             font-size: 18px;
+             text-align: center;
+             margin-bottom: 10px;
+             color: #7a7a7a;
+         }
+         .class-container {
+             display: flex;
+             flex-wrap: wrap;
+             justify-content: center;
+             align-items: center;
+             padding: 20px;
+             border: 2px solid #e0e0e0;
+             border-radius: 10px;
+             background-color: #f5f5f5;
+         }
+         .class-item {
+             display: inline-flex; /* Changed from flex to inline-flex */
+             align-items: center;
+             margin: 5px 10px;
+             padding: 5px 8px; /* Adjusted padding */
+             border: 1px solid #d1d1d1;
+             border-radius: 20px;
+             background-color: #ffffff;
+             font-weight: bold;
+             color: #333;
+             box-shadow: 2px 2px 5px rgba(0, 0, 0, 0.1);
+             transition: transform 0.2s, box-shadow 0.2s;
+         }
+         .class-item:hover {
+             transform: scale(1.05);
+             box-shadow: 2px 2px 10px rgba(0, 0, 0, 0.2);
+             background-color: #e7e7e7;
+         }
+         .material-icons {
+             margin-right: 8px;
+         }
+     </style>
+     <div class="title">Object Detection Prediction & Grad-Cam for YOLOv3</div>
+     <div class="subtitle">Supported Classes</div>
+     <div class="class-container">
+     """
+
+     # Loop through each class and add it to the HTML string with its corresponding icon
+     for class_name in classes:
+         icon_name = class_name.lower()  # Assuming the icon name is the lowercase version of the class name
+         icon_name = icons_mapping[icon_name]
+         html_string += f'<div class="class-item"><i class="material-icons">{icon_name}</i>{class_name}</div>'
+
+     # Close the HTML string
+     html_string += "</div>"
+
+     return html_string
+
+
+
+ def upload_image_inference(img, transparency):
+     bboxes = [[] for _ in range(1)]
+     nms_boxes_output, annotations = [], []
+     img_copy = img.copy()
+
+     transform = infer_transform()
+     img = transform(image=img)['image'].unsqueeze(0)
+
+     out = model(img)
+
+     for i in range(3):
+         batch_size, A, S, _, _ = out[i].shape
+         anchor = scaled_anchors[i]
+         boxes_scale_i = cells_to_bboxes(
+             out[i], anchor, S=S, is_preds=True
+         )
+
+         for idx, (box) in enumerate(boxes_scale_i):
+             bboxes[idx] += box
+
+     for i in range(img.shape[0]):
+         iou_thresh, thresh = 0.5, 0.6
+         nms_boxes = non_max_suppression(
+             bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
+         )
+
+         nms_boxes_output.append(nms_boxes)
+
+     for box in nms_boxes_output[0]:
+         class_prediction = int(box[0])
+         box = box[2:]
+
+         upper_left_x = box[0] - box[2] / 2
+         upper_left_y = box[1] - box[3] / 2
+         rect = patches.Rectangle(
+             (upper_left_x * width, upper_left_y * height),
+             box[2] * width,
+             box[3] * height,
+             linewidth=2,
+             edgecolor=colors[class_prediction],
+             facecolor="none",
+         )
+         rect = rect.get_bbox().get_points()
+         annotations.append([rect[0].astype(int).tolist()+rect[1].astype(int).tolist(),
+                             config.PASCAL_CLASSES[class_prediction]])
+
+
+     objs = [b[1] for b in nms_boxes_output[0]]
+     bbox_coord = [b[2:] for b in nms_boxes_output[0]]
+     targets = [FasterRCNNBoxScoreTarget(objs, bbox_coord)]
+
+     cam = EigenCAM(model=model,
+                    target_layers=[model.model],
+                    reshape_transform=yolov3_reshape_transform)
+
+     grayscale_cam = cam(input_tensor=img, targets=targets)
+     grayscale_cam = grayscale_cam[0, :]
+
+     visualization = show_cam_on_image(img_copy/255, grayscale_cam, use_rgb=False, image_weight=transparency)
+
+     return [[img_copy, annotations],
+             [grayscale_cam, visualization]]
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ grad-cam
+ gradio
+ torch
+ torchvision
+ pillow
+ numpy
+ pytorch_lightning
+ torchmetrics
+ albumentations
+ opencv-python
+ glob2
utils.py ADDED
@@ -0,0 +1,525 @@
+ import config
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ import numpy as np
+ import os
+ import random
+ import torch
+ from collections import Counter
+ from tqdm import tqdm
+
+
+
+ def iou_width_height(boxes1, boxes2):
+     """
+     Parameters:
+         boxes1 (tensor): width and height of the first bounding boxes
+         boxes2 (tensor): width and height of the second bounding boxes
+     Returns:
+         tensor: Intersection over union of the corresponding boxes
+     """
+     intersection = torch.min(boxes1[..., 0], boxes2[..., 0]) * torch.min(
+         boxes1[..., 1], boxes2[..., 1]
+     )
+     union = (
+         boxes1[..., 0] * boxes1[..., 1] + boxes2[..., 0] * boxes2[..., 1] - intersection
+     )
+     return intersection / union
+
+
+ def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
+     """
+     Video explanation of this function:
+     https://youtu.be/XXYG5ZWtjj0
+
+     This function calculates intersection over union (iou) given pred boxes
+     and target boxes.
+
+     Parameters:
+         boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
+         boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
+         box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)
+
+     Returns:
+         tensor: Intersection over union for all examples
+     """
+
+     if box_format == "midpoint":
+         box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
+         box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
+         box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
+         box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
+         box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
+         box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
+         box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
+         box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
+
+     if box_format == "corners":
+         box1_x1 = boxes_preds[..., 0:1]
+         box1_y1 = boxes_preds[..., 1:2]
+         box1_x2 = boxes_preds[..., 2:3]
+         box1_y2 = boxes_preds[..., 3:4]
+         box2_x1 = boxes_labels[..., 0:1]
+         box2_y1 = boxes_labels[..., 1:2]
+         box2_x2 = boxes_labels[..., 2:3]
+         box2_y2 = boxes_labels[..., 3:4]
+
+     x1 = torch.max(box1_x1, box2_x1)
+     y1 = torch.max(box1_y1, box2_y1)
+     x2 = torch.min(box1_x2, box2_x2)
+     y2 = torch.min(box1_y2, box2_y2)
+
+     intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
+     box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
+     box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
+
+     return intersection / (box1_area + box2_area - intersection + 1e-6)
+
+
+ def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
+     """
+     Video explanation of this function:
+     https://youtu.be/YDkjWEN8jNA
+
+     Does Non Max Suppression given bboxes
+
+     Parameters:
+         bboxes (list): list of lists containing all bboxes with each bboxes
+         specified as [class_pred, prob_score, x1, y1, x2, y2]
+         iou_threshold (float): threshold where predicted bboxes is correct
+         threshold (float): threshold to remove predicted bboxes (independent of IoU)
+         box_format (str): "midpoint" or "corners" used to specify bboxes
+
+     Returns:
+         list: bboxes after performing NMS given a specific IoU threshold
+     """
+
+     assert type(bboxes) == list
+
+     bboxes = [box for box in bboxes if box[1] > threshold]
+     bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
+     bboxes_after_nms = []
+
+     while bboxes:
+         chosen_box = bboxes.pop(0)
+
+         bboxes = [
+             box
+             for box in bboxes
+             if box[0] != chosen_box[0]
+             or intersection_over_union(
+                 torch.tensor(chosen_box[2:]),
+                 torch.tensor(box[2:]),
+                 box_format=box_format,
+             )
+             < iou_threshold
+         ]
+
+         bboxes_after_nms.append(chosen_box)
+
+     return bboxes_after_nms
+
+
+ def mean_average_precision(
+     pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
+ ):
+     """
+     Video explanation of this function:
+     https://youtu.be/FppOzcDvaDI
+
+     This function calculates mean average precision (mAP)
+
+     Parameters:
+         pred_boxes (list): list of lists containing all bboxes with each bboxes
+         specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
+         true_boxes (list): Similar as pred_boxes except all the correct ones
+         iou_threshold (float): threshold where predicted bboxes is correct
+         box_format (str): "midpoint" or "corners" used to specify bboxes
+         num_classes (int): number of classes
+
+     Returns:
+         float: mAP value across all classes given a specific IoU threshold
+     """
+
+     # list storing all AP for respective classes
+     average_precisions = []
+
+     # used for numerical stability later on
+     epsilon = 1e-6
+
+     for c in range(num_classes):
+         detections = []
+         ground_truths = []
+
+         # Go through all predictions and targets,
+         # and only add the ones that belong to the
+         # current class c
+         for detection in pred_boxes:
+             if detection[1] == c:
+                 detections.append(detection)
+
+         for true_box in true_boxes:
+             if true_box[1] == c:
+                 ground_truths.append(true_box)
+
+         # find the amount of bboxes for each training example
+         # Counter here finds how many ground truth bboxes we get
+         # for each training example, so let's say img 0 has 3,
+         # img 1 has 5 then we will obtain a dictionary with:
+         # amount_bboxes = {0:3, 1:5}
+         amount_bboxes = Counter([gt[0] for gt in ground_truths])
+
+         # We then go through each key, val in this dictionary
+         # and convert to the following (w.r.t same example):
+         # amount_bboxes = {0:torch.tensor[0,0,0], 1:torch.tensor[0,0,0,0,0]}
+         for key, val in amount_bboxes.items():
+             amount_bboxes[key] = torch.zeros(val)
+
+         # sort by box probabilities which is index 2
+         detections.sort(key=lambda x: x[2], reverse=True)
+         TP = torch.zeros((len(detections)))
+         FP = torch.zeros((len(detections)))
+         total_true_bboxes = len(ground_truths)
+
+         # If none exists for this class then we can safely skip
+         if total_true_bboxes == 0:
+             continue
+
+         for detection_idx, detection in enumerate(detections):
+             # Only take out the ground_truths that have the same
+             # training idx as detection
+             ground_truth_img = [
+                 bbox for bbox in ground_truths if bbox[0] == detection[0]
+             ]
+
+             num_gts = len(ground_truth_img)
+             best_iou = 0
+
+             for idx, gt in enumerate(ground_truth_img):
+                 iou = intersection_over_union(
+                     torch.tensor(detection[3:]),
+                     torch.tensor(gt[3:]),
+                     box_format=box_format,
+                 )
+
+                 if iou > best_iou:
+                     best_iou = iou
+                     best_gt_idx = idx
+
+             if best_iou > iou_threshold:
+                 # only detect ground truth detection once
+                 if amount_bboxes[detection[0]][best_gt_idx] == 0:
+                     # true positive and add this bounding box to seen
+                     TP[detection_idx] = 1
+                     amount_bboxes[detection[0]][best_gt_idx] = 1
+                 else:
+                     FP[detection_idx] = 1
+
+             # if IOU is lower then the detection is a false positive
+             else:
+                 FP[detection_idx] = 1
+
+         TP_cumsum = torch.cumsum(TP, dim=0)
+         FP_cumsum = torch.cumsum(FP, dim=0)
+         recalls = TP_cumsum / (total_true_bboxes + epsilon)
+         precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
+         precisions = torch.cat((torch.tensor([1]), precisions))
+         recalls = torch.cat((torch.tensor([0]), recalls))
+         # torch.trapz for numerical integration
+         average_precisions.append(torch.trapz(precisions, recalls))
+
+     return sum(average_precisions) / len(average_precisions)
+
+
+ def plot_image(image, boxes):
+     """Plots predicted bounding boxes on the image"""
+     cmap = plt.get_cmap("tab20b")
+     class_labels = config.COCO_LABELS if config.DATASET=='COCO' else config.PASCAL_CLASSES
+     colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
+     im = np.array(image)
+     height, width, _ = im.shape
+
+     # Create figure and axes
+     fig, ax = plt.subplots(1)
+     # Display the image
+     ax.imshow(im)
+
+     # box[0] is x midpoint, box[2] is width
+     # box[1] is y midpoint, box[3] is height
+
+     # Create a Rectangle patch
+     for box in boxes:
+         assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
+         class_pred = box[0]
+         box = box[2:]
+         upper_left_x = box[0] - box[2] / 2
+         upper_left_y = box[1] - box[3] / 2
+         rect = patches.Rectangle(
+             (upper_left_x * width, upper_left_y * height),
+             box[2] * width,
+             box[3] * height,
+             linewidth=2,
+             edgecolor=colors[int(class_pred)],
+             facecolor="none",
+         )
+         # Add the patch to the Axes
+         ax.add_patch(rect)
+         plt.text(
+             upper_left_x * width,
+             upper_left_y * height,
+             s=class_labels[int(class_pred)],
+             color="white",
+             verticalalignment="top",
+             bbox={"color": colors[int(class_pred)], "pad": 0},
+         )
+
+     plt.show()
+
+
+ def get_evaluation_bboxes(
+     loader,
+     model,
+     iou_threshold,
+     anchors,
+     threshold,
+     box_format="midpoint",
+     device="cuda",
+ ):
+     # make sure model is in eval before get bboxes
+     model.eval()
+     train_idx = 0
+     all_pred_boxes = []
+     all_true_boxes = []
+     for batch_idx, (x, labels) in enumerate(tqdm(loader)):
+         x = x.to(device)
+
+         with torch.no_grad():
+             predictions = model(x)
+
+         batch_size = x.shape[0]
+         bboxes = [[] for _ in range(batch_size)]
+         for i in range(3):
+             S = predictions[i].shape[2]
+             anchor = torch.tensor([*anchors[i]]).to(device) * S
+             boxes_scale_i = cells_to_bboxes(
+                 predictions[i], anchor, S=S, is_preds=True
+             )
+             for idx, (box) in enumerate(boxes_scale_i):
+                 bboxes[idx] += box
+
+         # we just want one bbox for each label, not one for each scale
+         true_bboxes = cells_to_bboxes(
+             labels[2], anchor, S=S, is_preds=False
+         )
+
+         for idx in range(batch_size):
+             nms_boxes = non_max_suppression(
+                 bboxes[idx],
+                 iou_threshold=iou_threshold,
+                 threshold=threshold,
+                 box_format=box_format,
+             )
+
+             for nms_box in nms_boxes:
+                 all_pred_boxes.append([train_idx] + nms_box)
+
+             for box in true_bboxes[idx]:
+                 if box[1] > threshold:
+                     all_true_boxes.append([train_idx] + box)
+
+             train_idx += 1
+
+     model.train()
+     return all_pred_boxes, all_true_boxes
+
+
+ def cells_to_bboxes(predictions, anchors, S, is_preds=True):
+     """
+     Scales the predictions coming from the model to
+     be relative to the entire image such that they for example later
+     can be plotted or.
+     INPUT:
+     predictions: tensor of size (N, 3, S, S, num_classes+5)
+     anchors: the anchors used for the predictions
+     S: the number of cells the image is divided in on the width (and height)
+     is_preds: whether the input is predictions or the true bounding boxes
+     OUTPUT:
+     converted_bboxes: the converted boxes of sizes (N, num_anchors, S, S, 1+5) with class index,
+                       object score, bounding box coordinates
+     """
+     BATCH_SIZE = predictions.shape[0]
+     num_anchors = len(anchors)
+     box_predictions = predictions[..., 1:5]
+     if is_preds:
+         anchors = anchors.reshape(1, len(anchors), 1, 1, 2).to(config.DEVICE)
+         box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
+         box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
+         scores = torch.sigmoid(predictions[..., 0:1])
+         best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
+     else:
+         scores = predictions[..., 0:1]
+         best_class = predictions[..., 5:6]
+
+     cell_indices = (
+         torch.arange(S)
+         .repeat(predictions.shape[0], 3, S, 1)
+         .unsqueeze(-1)
+         .to(predictions.device)
+     )
+     x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
+     y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
+     w_h = 1 / S * box_predictions[..., 2:4]
+     converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(BATCH_SIZE, num_anchors * S * S, 6)
+     return converted_bboxes.tolist()
+
+ def check_class_accuracy(model, loader, threshold):
+     model.eval()
+     tot_class_preds, correct_class = 0, 0
+     tot_noobj, correct_noobj = 0, 0
+     tot_obj, correct_obj = 0, 0
+
+     for idx, (x, y) in enumerate(tqdm(loader)):
+         x = x.to(config.DEVICE)
+         with torch.no_grad():
+             out = model(x)
+
+         for i in range(3):
+             y[i] = y[i].to(config.DEVICE)
+             obj = y[i][..., 0] == 1  # in paper this is Iobj_i
+             noobj = y[i][..., 0] == 0  # in paper this is Inoobj_i
+
+             correct_class += torch.sum(
+                 torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
+             )
+             tot_class_preds += torch.sum(obj)
+
+             obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
+             correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
+             tot_obj += torch.sum(obj)
+             correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
+             tot_noobj += torch.sum(noobj)
+
+     print(f"Class accuracy is: {(correct_class/(tot_class_preds+1e-16))*100:2f}%")
+     print(f"No obj accuracy is: {(correct_noobj/(tot_noobj+1e-16))*100:2f}%")
+     print(f"Obj accuracy is: {(correct_obj/(tot_obj+1e-16))*100:2f}%")
+     model.train()
+
+
+ def get_mean_std(loader):
+     # var[X] = E[X**2] - E[X]**2
+     channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0
+
+     for data, _ in tqdm(loader):
+         channels_sum += torch.mean(data, dim=[0, 2, 3])
+         channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3])
+         num_batches += 1
+
+     mean = channels_sum / num_batches
+     std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5
+
+     return mean, std
+
+
+ def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
+     print("=> Saving checkpoint")
+     checkpoint = {
+         "state_dict": model.state_dict(),
+         "optimizer": optimizer.state_dict(),
+     }
+     torch.save(checkpoint, filename)
+
+
+ def load_checkpoint(checkpoint_file, model, optimizer, lr):
+     print("=> Loading checkpoint")
+     checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
+     model.load_state_dict(checkpoint["state_dict"])
+     optimizer.load_state_dict(checkpoint["optimizer"])
+
+     # If we don't do this then it will just have learning rate of old checkpoint
+     # and it will lead to many hours of debugging \:
+     for param_group in optimizer.param_groups:
+         param_group["lr"] = lr
+
+
+ def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
+     model.eval()
+     x, y = next(iter(loader))
+     x = x.to("cuda")
+     with torch.no_grad():
+         out = model(x)
+         bboxes = [[] for _ in range(x.shape[0])]
+         for i in range(3):
+             batch_size, A, S, _, _ = out[i].shape
+             anchor = anchors[i]
+             boxes_scale_i = cells_to_bboxes(
+                 out[i], anchor, S=S, is_preds=True
+             )
+             for idx, (box) in enumerate(boxes_scale_i):
+                 bboxes[idx] += box
+
+         model.train()
+
+     for i in range(batch_size//4):
+         nms_boxes = non_max_suppression(
+             bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
+         )
+         plot_image(x[i].permute(1,2,0).detach().cpu(), nms_boxes)
+
+
+
+ def seed_everything(seed=42):
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def clip_coords(boxes, img_shape):
+     # Clip bounding xyxy bounding boxes to image shape (height, width)
+     boxes[:, 0].clamp_(0, img_shape[1])  # x1
+     boxes[:, 1].clamp_(0, img_shape[0])  # y1
+     boxes[:, 2].clamp_(0, img_shape[1])  # x2
+     boxes[:, 3].clamp_(0, img_shape[0])  # y2
+
+ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
+     # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
+     y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
+     y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
+     y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
+     return y
+
+
+ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
+     # Convert normalized segments into pixel segments, shape (n,2)
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[..., 0] = w * x[..., 0] + padw  # top left x
+     y[..., 1] = h * x[..., 1] + padh  # top left y
+     return y
+
+ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
+     # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
+     if clip:
+         clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
+     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+     y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
+     y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
+     y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
+     y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
+     return y
+
+ def clip_boxes(boxes, shape):
+     # Clip boxes (xyxy) to image shape (height, width)
+     if isinstance(boxes, torch.Tensor):  # faster individually
+         boxes[..., 0].clamp_(0, shape[1])  # x1
+         boxes[..., 1].clamp_(0, shape[0])  # y1
+         boxes[..., 2].clamp_(0, shape[1])  # x2
+         boxes[..., 3].clamp_(0, shape[0])  # y2
+     else:  # np.array (faster grouped)
+         boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
+         boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
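
The IoU and NMS helpers can be hand-checked on toy boxes. The sketch below (importing config before utils so config's own `from utils import seed_everything` resolves cleanly) uses midpoint-format boxes whose overlap is easy to verify by hand.

import config
import torch
from utils import intersection_over_union, non_max_suppression

# (x, y, w, h) midpoint boxes: the second sits entirely inside the first,
# so IoU = (0.2 * 0.2) / (0.4 * 0.4) = 0.25.
box_a = torch.tensor([0.5, 0.5, 0.4, 0.4])
box_b = torch.tensor([0.5, 0.5, 0.2, 0.2])
print(intersection_over_union(box_a, box_b, box_format="midpoint"))   # ~0.25

# NMS keeps the higher-scoring of two heavily overlapping same-class boxes
# and drops anything below `threshold` outright.
boxes = [
    [0, 0.9, 0.50, 0.5, 0.4, 0.4],   # [class, score, x, y, w, h]
    [0, 0.8, 0.52, 0.5, 0.4, 0.4],   # IoU ~0.9 with the first -> suppressed
    [1, 0.3, 0.20, 0.2, 0.1, 0.1],   # score below 0.5 -> removed
]
print(non_max_suppression(boxes, iou_threshold=0.5, threshold=0.5, box_format="midpoint"))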