Sijuade commited on
Commit
5ace3a9
·
1 Parent(s): a00793d

Upload lightning_utils.py

Browse files
Files changed (1) hide show
  1. lightning_utils.py +171 -0
lightning_utils.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from loss import YoloLoss
4
+ import config
5
+ import torch
6
+ from dataset import YOLODataset
7
+ from torch.optim.lr_scheduler import OneCycleLR
8
+ import random
9
+ from model import YOLOv3
10
+ import lightning.pytorch as pl
11
+
12
+ def criterion(out, y, anchors):
13
+ loss_fn = YoloLoss()
14
+ loss = (
15
+ loss_fn(out[0], y[0], anchors[0])
16
+ + loss_fn(out[1], y[1], anchors[1])
17
+ + loss_fn(out[2], y[2], anchors[2]))
18
+ return loss
19
+
20
+
21
+ def get_loader(train_dataset, test_dataset):
22
+ train_loader = DataLoader(
23
+ dataset=train_dataset,
24
+ batch_size=config.BATCH_SIZE,
25
+ num_workers=config.NUM_WORKERS,
26
+ pin_memory=config.PIN_MEMORY,
27
+ shuffle=True,
28
+ drop_last=False,
29
+ )
30
+
31
+ test_loader = DataLoader(
32
+ dataset=test_dataset,
33
+ batch_size=config.BATCH_SIZE,
34
+ num_workers=config.NUM_WORKERS,
35
+ pin_memory=config.PIN_MEMORY,
36
+ shuffle=False,
37
+ drop_last=False,
38
+
39
+ )
40
+
41
+ return(train_loader, test_loader)
42
+
43
+
44
+ def accuracy_fn(y, out, threshold,
45
+ correct_class, correct_obj,
46
+ correct_noobj, tot_class_preds,
47
+ tot_obj, tot_noobj):
48
+
49
+ for i in range(3):
50
+
51
+ obj = y[i][..., 0] == 1 # in paper this is Iobj_i
52
+ noobj = y[i][..., 0] == 0 # in paper this is Iobj_i
53
+
54
+ correct_class += torch.sum(
55
+ torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
56
+ )
57
+ tot_class_preds += torch.sum(obj)
58
+
59
+ obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
60
+ correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
61
+ tot_obj += torch.sum(obj)
62
+ correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
63
+ tot_noobj += torch.sum(noobj)
64
+
65
+ return((correct_class/(tot_class_preds+1e-16))*100,
66
+ (correct_noobj/(tot_noobj+1e-16))*100,
67
+ (correct_obj/(tot_obj+1e-16))*100)
68
+
69
+
70
+ def get_datasets(train_loc="/train.csv", test_loc="/test.csv"):
71
+
72
+ train_dataset = YOLODataset(
73
+ config.DATASET + train_loc,
74
+ transform=config.train_transform,
75
+ img_dir=config.IMG_DIR,
76
+ label_dir=config.LABEL_DIR,
77
+ anchors=config.ANCHORS,
78
+ )
79
+
80
+ test_dataset = YOLODataset(
81
+ config.DATASET + test_loc,
82
+ transform=config.test_transform,
83
+ img_dir=config.IMG_DIR,
84
+ label_dir=config.LABEL_DIR,
85
+ anchors=config.ANCHORS,
86
+ train=False
87
+ )
88
+
89
+ return(train_dataset, test_dataset)
90
+
91
+
92
+
93
+ class YOLOv3Lightning(pl.LightningModule):
94
+ def __init__(self, dataset=None, lr=config.LEARNING_RATE):
95
+ super().__init__()
96
+
97
+ self.save_hyperparameters()
98
+
99
+ self.model = YOLOv3(num_classes=config.NUM_CLASSES)
100
+ self.lr = lr
101
+ self.criterion = criterion
102
+ self.losses = []
103
+ self.threshold = config.CONF_THRESHOLD
104
+ self.iou_threshold = config.NMS_IOU_THRESH
105
+ self.train_idx = 0
106
+ self.box_format="midpoint"
107
+ self.dataset = dataset
108
+ self.criterion = criterion
109
+ self.accuracy_fn = accuracy_fn
110
+ self.tot_class_preds, self.correct_class = 0, 0
111
+ self.tot_noobj, self.correct_noobj = 0, 0
112
+ self.tot_obj, self.correct_obj = 0, 0
113
+ self.scaled_anchors = 0
114
+
115
+ def forward(self, x):
116
+ return self.model(x)
117
+
118
+ def set_scaled_anchor(self, scaled_anchors):
119
+ self.scaled_anchors = scaled_anchors
120
+
121
+ def on_train_epoch_start(self):
122
+ # Set a new image size for the dataset at the beginning of each epoch
123
+ size_idx = random.choice(range(len(config.IMAGE_SIZES)))
124
+ self.dataset.set_image_size(size_idx)
125
+ self.set_scaled_anchor((
126
+ torch.tensor(config.ANCHORS)
127
+ * torch.tensor(config.S[size_idx]).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
128
+ ))
129
+
130
+ def on_validation_epoch_start(self):
131
+ self.set_scaled_anchor((
132
+ torch.tensor(config.ANCHORS)
133
+ * torch.tensor(config.S[1]).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
134
+ ))
135
+
136
+
137
+ def training_step(self, batch, batch_idx):
138
+ x, y = batch
139
+ out = self(x)
140
+ loss = self.criterion(out, y, self.scaled_anchors)
141
+
142
+ self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
143
+
144
+ return loss
145
+
146
+ def validation_step(self, val_batch, batch_idx):
147
+ x, labels = val_batch
148
+ out = self(x)
149
+
150
+ loss = self.criterion(out, labels, self.scaled_anchors)
151
+ self.log('val_loss', loss, prog_bar=True, on_epoch=True)
152
+
153
+ self.evaluate(x, labels, out, 'val')
154
+
155
+
156
+ def evaluate(self, x, y, out, stage=None):
157
+
158
+ # Class Accuracy
159
+ class_accuracy, no_obj_accuracy, obj_accuracy = self.accuracy_fn(y,
160
+ out,
161
+ self.threshold,
162
+ self.correct_class,
163
+ self.correct_obj,
164
+ self.correct_noobj,
165
+ self.tot_class_preds,
166
+ self.tot_obj,
167
+ self.tot_noobj, )
168
+ if stage:
169
+ self.log(f'{stage}_class_accuracy', class_accuracy, prog_bar=True, on_epoch=True, on_step=True, logger=True)
170
+ self.log(f'{stage}_no_obj_accuracy', no_obj_accuracy, prog_bar=True, on_epoch=True, on_step=True, logger=True)
171
+ self.log(f'{stage}_obj_accuracy', obj_accuracy, prog_bar=True, on_epoch=True, on_step=True, logger=True)