from torch.utils.data import Dataset, DataLoader
from loss import YoloLoss
import config
import torch
from dataset import YOLODataset
from torch.optim.lr_scheduler import OneCycleLR
import random
from model import YOLOv3
import lightning.pytorch as pl


def criterion(out, y, anchors):
    """Sum the YOLOv3 loss over the three prediction scales.

    Args:
        out: Model outputs, indexed per scale (``out[0]`` .. ``out[2]``).
        y: Targets, indexed per scale in the same order as ``out``.
        anchors: Scaled anchors, indexed per scale in the same order.

    Returns:
        ``loss_fn(out[i], y[i], anchors[i])`` summed over ``i`` in 0..2, where
        ``loss_fn`` is a fresh ``YoloLoss()`` built for this call.
    """
    loss_fn = YoloLoss()
    return (
        loss_fn(out[0], y[0], anchors[0])
        + loss_fn(out[1], y[1], anchors[1])
        + loss_fn(out[2], y[2], anchors[2])
    )


def get_loader(train_dataset, test_dataset):
    """Wrap the train/test datasets in DataLoaders configured from ``config``.

    Both loaders use ``config.BATCH_SIZE``, ``config.NUM_WORKERS`` and
    ``config.PIN_MEMORY`` and keep the last partial batch. Only the training
    loader shuffles.

    Returns:
        ``(train_loader, test_loader)``.
    """
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        shuffle=True,
        drop_last=False,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        shuffle=False,
        drop_last=False,
    )
    return (train_loader, test_loader)


def accuracy_fn(y, out, threshold, correct_class, correct_obj, correct_noobj,
                tot_class_preds, tot_obj, tot_noobj):
    """Compute class, no-object and object accuracies (in percent) over 3 scales.

    For each scale ``i``, cells with ``y[i][..., 0] == 1`` count as object
    cells and cells with ``y[i][..., 0] == 0`` count as no-object cells.

    - **Class accuracy.** On object cells, this is the fraction where the
      argmax of ``out[i][..., 5:]`` equals the label ``y[i][..., 5]``.
    - **Object / no-object accuracy.** On each group of cells, this is the
      fraction where ``sigmoid(out[i][..., 0]) > threshold`` agrees with the
      target.

    The ``correct_*`` / ``tot_*`` arguments are starting values that get added
    to. When they are plain Python ints (as ``YOLOv3Lightning`` passes them),
    ``+=`` only rebinds the local names. The caller's variables are therefore
    NOT updated, and the result reflects just this batch plus those offsets.

    Returns:
        ``(class_accuracy, no_obj_accuracy, obj_accuracy)`` as percentages.
        A tiny epsilon in each denominator avoids division by zero when a
        category has no cells.
    """
    for i in range(3):
        obj = y[i][..., 0] == 1  # in paper this is Iobj_i
        noobj = y[i][..., 0] == 0  # in paper this is Inoobj_i

        correct_class += torch.sum(
            torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
        )
        tot_class_preds += torch.sum(obj)

        obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
        correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
        tot_obj += torch.sum(obj)
        correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
        tot_noobj += torch.sum(noobj)

    return (
        (correct_class / (tot_class_preds + 1e-16)) * 100,
        (correct_noobj / (tot_noobj + 1e-16)) * 100,
        (correct_obj / (tot_obj + 1e-16)) * 100,
    )


def get_datasets(train_loc="/train.csv", test_loc="/test.csv"):
    """Build the train and test ``YOLODataset`` objects from ``config``.

    The CSV paths are formed by plain string concatenation with
    ``config.DATASET``, so ``train_loc``/``test_loc`` should include the
    leading path separator, as the defaults do. The test dataset uses
    ``config.test_transform`` and is constructed with ``train=False``.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    train_dataset = YOLODataset(
        config.DATASET + train_loc,
        transform=config.train_transform,
        img_dir=config.IMG_DIR,
        label_dir=config.LABEL_DIR,
        anchors=config.ANCHORS,
    )
    test_dataset = YOLODataset(
        config.DATASET + test_loc,
        transform=config.test_transform,
        img_dir=config.IMG_DIR,
        label_dir=config.LABEL_DIR,
        anchors=config.ANCHORS,
        train=False,
    )
    return (train_dataset, test_dataset)


class YOLOv3Lightning(pl.LightningModule):
    """LightningModule wrapping ``YOLOv3`` with per-epoch multi-scale training.

    At the start of every training epoch, a random index into
    ``config.IMAGE_SIZES`` is chosen. The training dataset is resized via
    ``dataset.set_image_size``, and the anchors are rescaled with the matching
    ``config.S`` entry. Validation always uses the anchors for ``config.S[1]``.
    """

    def __init__(self, dataset=None, lr=config.LEARNING_RATE):
        """
        Args:
            dataset: Training dataset exposing ``set_image_size(idx)``. It is
                required once training starts (see ``on_train_epoch_start``).
            lr: Learning rate, stored on ``self.lr`` and in the hyperparameters.
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = YOLOv3(num_classes=config.NUM_CLASSES)
        self.lr = lr
        self.criterion = criterion
        self.accuracy_fn = accuracy_fn
        self.losses = []
        self.threshold = config.CONF_THRESHOLD
        self.iou_threshold = config.NMS_IOU_THRESH
        self.train_idx = 0
        self.box_format = "midpoint"
        self.dataset = dataset
        # Starting values handed to accuracy_fn. They are ints, so accuracy_fn
        # never updates them: the logged accuracies are per-batch values.
        self.tot_class_preds, self.correct_class = 0, 0
        self.tot_noobj, self.correct_noobj = 0, 0
        self.tot_obj, self.correct_obj = 0, 0
        # Placeholder until an epoch-start hook installs real anchors.
        self.scaled_anchors = 0

    def forward(self, x):
        return self.model(x)

    def set_scaled_anchor(self, scaled_anchors):
        self.scaled_anchors = scaled_anchors

    def _scaled_anchors_for(self, size_idx):
        """Return ``config.ANCHORS`` scaled by the grid sizes ``config.S[size_idx]``.

        The result is moved to ``self.device`` so the loss can combine it with
        the model's predictions when training on an accelerator. A CPU tensor
        here would cause a device-mismatch error on GPU.
        """
        grid_sizes = (
            torch.tensor(config.S[size_idx]).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        )
        return (torch.tensor(config.ANCHORS) * grid_sizes).to(self.device)

    def on_train_epoch_start(self):
        # Set a new image size for the dataset at the beginning of each epoch,
        # and rescale the anchors to the matching grid sizes.
        # NOTE(review): assumes self.dataset is the dataset backing the train
        # loader and that its workers are (re)created after this hook;
        # persistent workers would keep the old size — confirm.
        size_idx = random.randrange(len(config.IMAGE_SIZES))
        self.dataset.set_image_size(size_idx)
        self.set_scaled_anchor(self._scaled_anchors_for(size_idx))

    def on_validation_epoch_start(self):
        # Validation always uses the anchors for size index 1.
        # NOTE(review): if validation runs mid-epoch (val_check_interval < 1),
        # this overwrites the anchors chosen for the current training epoch.
        self.set_scaled_anchor(self._scaled_anchors_for(1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = self.criterion(out, y, self.scaled_anchors)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=True,
                 logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, labels = val_batch
        out = self(x)
        loss = self.criterion(out, labels, self.scaled_anchors)
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
        self.evaluate(x, labels, out, 'val')

    def evaluate(self, x, y, out, stage=None):
        """Compute per-batch accuracies and log them under ``{stage}_*``.

        ``x`` is accepted for interface symmetry but not used. Nothing is
        logged when ``stage`` is falsy.
        """
        class_accuracy, no_obj_accuracy, obj_accuracy = self.accuracy_fn(
            y, out, self.threshold,
            self.correct_class, self.correct_obj, self.correct_noobj,
            self.tot_class_preds, self.tot_obj, self.tot_noobj,
        )

        if stage:
            self.log(f'{stage}_class_accuracy', class_accuracy, prog_bar=True,
                     on_epoch=True, on_step=True, logger=True)
            self.log(f'{stage}_no_obj_accuracy', no_obj_accuracy, prog_bar=True,
                     on_epoch=True, on_step=True, logger=True)
            self.log(f'{stage}_obj_accuracy', obj_accuracy, prog_bar=True,
                     on_epoch=True, on_step=True, logger=True)