File size: 5,540 Bytes
5ace3a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172

from torch.utils.data import Dataset, DataLoader
from loss import YoloLoss
import config
import torch
from dataset import YOLODataset
from torch.optim.lr_scheduler import OneCycleLR
import random
from model import YOLOv3
import lightning.pytorch as pl

def criterion(out, y, anchors):
  """Return the total YOLO loss summed over the three prediction scales.

  A fresh ``YoloLoss`` is created per call and applied to each of the
  first three (prediction, target, anchors) triples, in scale order 0, 1, 2.

  Args:
    out: model outputs indexed by scale (entries 0, 1 and 2 are used).
    y: targets indexed the same way as ``out``.
    anchors: per-scale anchors indexed the same way as ``out``.

  Returns:
    ``loss(out[0], ...) + loss(out[1], ...) + loss(out[2], ...)``.
  """
  yolo_loss = YoloLoss()
  total = yolo_loss(out[0], y[0], anchors[0])
  for scale in (1, 2):
    total = total + yolo_loss(out[scale], y[scale], anchors[scale])
  return total


def get_loader(train_dataset, test_dataset):
  """Wrap the train and test datasets in DataLoaders.

  Both loaders take batch size, worker count and pin-memory settings from
  ``config`` and keep the final partial batch (``drop_last=False``); only the
  training loader shuffles.

  Returns:
    A ``(train_loader, test_loader)`` tuple.
  """
  shared_kwargs = dict(
      batch_size=config.BATCH_SIZE,
      num_workers=config.NUM_WORKERS,
      pin_memory=config.PIN_MEMORY,
      drop_last=False,
  )
  train_loader = DataLoader(dataset=train_dataset, shuffle=True, **shared_kwargs)
  test_loader = DataLoader(dataset=test_dataset, shuffle=False, **shared_kwargs)
  return train_loader, test_loader


def accuracy_fn(y, out, threshold,
                correct_class, correct_obj,
                correct_noobj, tot_class_preds,
                tot_obj, tot_noobj):
  """Compute class, no-object and object accuracy (in percent) over all scales.

  For each scale, ``y[i][..., 0]`` is read as the objectness label (cells equal
  to 1 count as objects, cells equal to 0 as no-object; any other value is
  counted in neither group) and ``y[i][..., 5]`` as the class index.
  ``out[i][..., 0]`` is read as the objectness logit and ``out[i][..., 5:]`` as
  the class scores.

  The ``correct_*`` / ``tot_*`` arguments are starting values for the
  counters; this call's counts are added to them. NOTE: ``+=`` on a tensor is
  in-place, so tensor counters passed in are updated in the caller too, while
  plain Python numbers are not.

  Args:
    y: sequence of per-scale target tensors.
    out: sequence of per-scale prediction tensors, paired with ``y`` position
      by position (if the lengths differ, extra entries are ignored).
    threshold: a cell is predicted to contain an object when
      ``sigmoid(objectness logit) > threshold``.
    correct_class, correct_obj, correct_noobj, tot_class_preds, tot_obj,
    tot_noobj: initial counter values (usually 0).

  Returns:
    ``(class_accuracy, no_obj_accuracy, obj_accuracy)`` in percent. A tiny
    epsilon in each denominator avoids division by zero, so an empty group
    yields 0.
  """
  eps = 1e-16
  for target, pred in zip(y, out):
    obj = target[..., 0] == 1    # Iobj_i in the paper
    noobj = target[..., 0] == 0  # Inoobj_i in the paper

    # Class accuracy is only measured on cells that contain an object.
    pred_class = torch.argmax(pred[..., 5:][obj], dim=-1)
    correct_class += torch.sum(pred_class == target[..., 5][obj])
    tot_class_preds += torch.sum(obj)

    obj_preds = torch.sigmoid(pred[..., 0]) > threshold
    correct_obj += torch.sum(obj_preds[obj] == target[..., 0][obj])
    tot_obj += torch.sum(obj)
    correct_noobj += torch.sum(obj_preds[noobj] == target[..., 0][noobj])
    tot_noobj += torch.sum(noobj)

  return ((correct_class / (tot_class_preds + eps)) * 100,
          (correct_noobj / (tot_noobj + eps)) * 100,
          (correct_obj / (tot_obj + eps)) * 100)


def get_datasets(train_loc="/train.csv", test_loc="/test.csv"):
  """Build the training and test ``YOLODataset`` objects.

  Each CSV path is ``config.DATASET`` with the given suffix appended by plain
  string concatenation. Both datasets share ``config.IMG_DIR``,
  ``config.LABEL_DIR`` and ``config.ANCHORS``. The training set uses
  ``config.train_transform``; the test set uses ``config.test_transform`` and
  is additionally built with ``train=False``.

  Returns:
    A ``(train_dataset, test_dataset)`` tuple.
  """
  def _build(csv_suffix, transform, **extra):
    # Shared constructor call; ``extra`` carries options only one split uses.
    return YOLODataset(
        config.DATASET + csv_suffix,
        transform=transform,
        img_dir=config.IMG_DIR,
        label_dir=config.LABEL_DIR,
        anchors=config.ANCHORS,
        **extra,
    )

  train_dataset = _build(train_loc, config.train_transform)
  test_dataset = _build(test_loc, config.test_transform, train=False)
  return train_dataset, test_dataset



class YOLOv3Lightning(pl.LightningModule):
  def __init__(self, dataset=None, lr=config.LEARNING_RATE):
    """Wrap a ``YOLOv3`` model for Lightning training and validation.

    Args:
      dataset: object whose ``set_image_size`` is called at the start of every
        training epoch by ``on_train_epoch_start``, so it must not be None when
        training. Presumably the training ``YOLODataset`` — confirm against
        the caller.
      lr: learning rate stored on ``self.lr``. The default
        ``config.LEARNING_RATE`` is evaluated once, when the class is defined.
    """
    super().__init__()

    # Saves the __init__ arguments (dataset, lr) to ``self.hparams``.
    # NOTE(review): this includes the dataset object itself; consider
    # ``self.save_hyperparameters(ignore=["dataset"])`` if it should not be
    # written to checkpoints/loggers.
    self.save_hyperparameters()

    self.model = YOLOv3(num_classes=config.NUM_CLASSES)
    self.lr = lr
    # Module-level function that sums YoloLoss over the three output scales.
    self.criterion = criterion
    self.losses = []
    # Objectness probability threshold handed to accuracy_fn by evaluate().
    self.threshold = config.CONF_THRESHOLD
    self.iou_threshold = config.NMS_IOU_THRESH
    self.train_idx = 0
    self.box_format="midpoint"
    self.dataset = dataset
    self.accuracy_fn = accuracy_fn
    # Starting counter values passed to accuracy_fn by evaluate(). evaluate()
    # does not write updated counts back, so these stay at 0 there.
    self.tot_class_preds, self.correct_class = 0, 0
    self.tot_noobj, self.correct_noobj = 0, 0
    self.tot_obj, self.correct_obj = 0, 0
    # Placeholder; replaced with an anchor tensor via set_scaled_anchor() in
    # the train/validation epoch-start hooks.
    self.scaled_anchors = 0

  def forward(self, x):
    """Run the wrapped YOLOv3 model on ``x`` and return its output as-is."""
    predictions = self.model(x)
    return predictions

  def set_scaled_anchor(self, scaled_anchors):
    """Replace the anchors that the training/validation steps pass to the loss.

    Args:
      scaled_anchors: stored on ``self.scaled_anchors``; in this class it is
        ``config.ANCHORS`` multiplied by the grid sizes ``config.S[...]``.
    """
    self.scaled_anchors = scaled_anchors

  def on_train_epoch_start(self):
    """Pick a random training image size for this epoch and matching anchors.

    A random index into ``config.IMAGE_SIZES`` is applied to ``self.dataset``
    via ``set_image_size``. The same index selects the grid sizes
    ``config.S[size_idx]`` used to scale ``config.ANCHORS``.
    """
    # Set a new image size for the dataset at the beginning of each epoch.
    size_idx = random.choice(range(len(config.IMAGE_SIZES)))
    self.dataset.set_image_size(size_idx)

    anchors = torch.tensor(config.ANCHORS)
    # (k,) grid sizes -> (k, 3, 2) so each anchor (w, h) of scale i is
    # multiplied by S[size_idx][i]. Assumes k == 3 scales with 3 anchors each
    # (the shape of config.ANCHORS) — TODO confirm against config.
    grid_sizes = (torch.tensor(config.S[size_idx])
                  .unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
    # torch.tensor() allocates on the CPU; move the result to the module's
    # device so the loss does not mix CPU anchors with GPU predictions.
    self.set_scaled_anchor((anchors * grid_sizes).to(self.device))

  def on_validation_epoch_start(self):
    """Use the anchors for image-size index 1 during validation.

    ``config.ANCHORS`` is scaled by the grid sizes ``config.S[1]``.
    NOTE(review): index 1 is presumably the image size the validation dataset
    is built with — confirm. This overwrites ``self.scaled_anchors``, which
    ``training_step`` also reads. If validation runs mid-epoch (e.g.
    ``val_check_interval < 1``), the rest of that training epoch would use
    these anchors instead of the ones chosen in ``on_train_epoch_start``.
    """
    anchors = torch.tensor(config.ANCHORS)
    grid_sizes = (torch.tensor(config.S[1])
                  .unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
    # torch.tensor() allocates on the CPU; move the result to the module's
    # device so the loss does not mix CPU anchors with GPU predictions.
    self.set_scaled_anchor((anchors * grid_sizes).to(self.device))


  def training_step(self, batch, batch_idx):
    """Compute, log and return the summed YOLO loss for one training batch.

    ``batch`` is unpacked as ``(images, targets)``. The loss uses
    ``self.scaled_anchors`` and is logged as ``train_loss`` both per step and
    per epoch.
    """
    images, targets = batch
    predictions = self(images)
    step_loss = self.criterion(predictions, targets, self.scaled_anchors)
    self.log('train_loss', step_loss,
             prog_bar=True, on_epoch=True, on_step=True, logger=True)
    return step_loss

  def validation_step(self, val_batch, batch_idx):
    """Log the validation loss and accuracy metrics for one batch.

    ``val_batch`` is unpacked as ``(images, targets)``. The summed loss is
    logged as ``val_loss`` (epoch-level), then accuracy logging is delegated
    to ``evaluate`` with stage ``'val'``. Returns None.
    """
    images, targets = val_batch
    predictions = self(images)
    self.log('val_loss',
             self.criterion(predictions, targets, self.scaled_anchors),
             prog_bar=True, on_epoch=True)
    self.evaluate(images, targets, predictions, 'val')


  def evaluate(self, x, y, out, stage=None):
    """Compute batch accuracy metrics via ``accuracy_fn`` and log them.

    ``x`` is accepted but unused. The current ``self.correct_*`` and
    ``self.tot_*`` values are passed as starting counts. The updated counts
    are not stored back, so each call reports accuracy for this batch only.
    When ``stage`` is truthy, the three percentages are logged as
    ``{stage}_class_accuracy``, ``{stage}_no_obj_accuracy`` and
    ``{stage}_obj_accuracy`` (per step and per epoch).
    """
    class_accuracy, no_obj_accuracy, obj_accuracy = self.accuracy_fn(
        y, out, self.threshold,
        self.correct_class, self.correct_obj, self.correct_noobj,
        self.tot_class_preds, self.tot_obj, self.tot_noobj,
    )
    if stage:
      log_opts = dict(prog_bar=True, on_epoch=True, on_step=True, logger=True)
      self.log(f'{stage}_class_accuracy', class_accuracy, **log_opts)
      self.log(f'{stage}_no_obj_accuracy', no_obj_accuracy, **log_opts)
      self.log(f'{stage}_obj_accuracy', obj_accuracy, **log_opts)