# YOLOV3-GradCAM / lightning_utils.py
# Uploaded by Sijuade (commit 5ace3a9)
from torch.utils.data import Dataset, DataLoader
from loss import YoloLoss
import config
import torch
from dataset import YOLODataset
from torch.optim.lr_scheduler import OneCycleLR
import random
from model import YOLOv3
import lightning.pytorch as pl
def criterion(out, y, anchors):
    """Return the YOLOv3 loss summed over the three prediction scales.

    A fresh ``YoloLoss`` instance is applied to each scale in turn
    (indices 0, 1, 2 of ``out``, ``y`` and ``anchors``), and the three
    per-scale losses are added together.

    Args:
        out: per-scale model outputs (only the first three are used).
        y: per-scale targets, aligned with ``out``.
        anchors: per-scale scaled anchors, aligned with ``out``.

    Returns:
        The summed loss.
    """
    yolo_loss = YoloLoss()
    total = yolo_loss(out[0], y[0], anchors[0])
    for scale_idx in (1, 2):
        total = total + yolo_loss(out[scale_idx], y[scale_idx], anchors[scale_idx])
    return total
def get_loader(train_dataset, test_dataset):
    """Build the train and test ``DataLoader`` objects from ``config`` settings.

    Both loaders use the same batch size, worker count and pin-memory setting,
    and both keep the last (possibly partial) batch. Only the training loader
    shuffles.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    common_kwargs = {
        "batch_size": config.BATCH_SIZE,
        "num_workers": config.NUM_WORKERS,
        "pin_memory": config.PIN_MEMORY,
        "drop_last": False,
    }
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **common_kwargs)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **common_kwargs)
    return train_loader, test_loader
def accuracy_fn(y, out, threshold,
                correct_class, correct_obj,
                correct_noobj, tot_class_preds,
                tot_obj, tot_noobj):
    """Compute class, no-object and object accuracy (percent) over all scales.

    The ``correct_*`` / ``tot_*`` arguments seed the running counts. Each
    scale's tallies are added to them before the percentages are computed.
    Only the percentages are returned. Note that if tensors are passed as
    seeds, ``+=`` updates them in place; Python ints are simply rebound locally.

    Args:
        y: per-scale target tensors. ``[..., 0] == 1`` marks cells with an
            object, ``[..., 0] == 0`` cells without one, and ``[..., 5]``
            holds the class index compared against the predicted class.
        out: per-scale prediction tensors aligned with ``y``. ``[..., 0]`` is
            passed through a sigmoid as the objectness score, and ``[..., 5:]``
            is arg-maxed as class scores.
        threshold: sigmoid objectness above which a cell is predicted to
            contain an object.
        correct_class, correct_obj, correct_noobj: initial correct counts.
        tot_class_preds, tot_obj, tot_noobj: initial totals.

    Returns:
        Tuple ``(class_accuracy, no_obj_accuracy, obj_accuracy)`` in percent.
        A ``1e-16`` epsilon in each denominator avoids division by zero.
    """
    # Iterate over however many scales are supplied instead of hard-coding 3.
    for target, pred in zip(y, out):
        obj = target[..., 0] == 1    # cells with an object (Iobj_i in the paper)
        noobj = target[..., 0] == 0  # cells without an object (Inoobj_i in the paper)

        # Class accuracy is only measured on cells that contain an object.
        predicted_class = torch.argmax(pred[..., 5:][obj], dim=-1)
        correct_class += torch.sum(predicted_class == target[..., 5][obj])
        tot_class_preds += torch.sum(obj)

        obj_preds = torch.sigmoid(pred[..., 0]) > threshold
        correct_obj += torch.sum(obj_preds[obj] == target[..., 0][obj])
        tot_obj += torch.sum(obj)
        correct_noobj += torch.sum(obj_preds[noobj] == target[..., 0][noobj])
        tot_noobj += torch.sum(noobj)

    return ((correct_class / (tot_class_preds + 1e-16)) * 100,
            (correct_noobj / (tot_noobj + 1e-16)) * 100,
            (correct_obj / (tot_obj + 1e-16)) * 100)
def get_datasets(train_loc="/train.csv", test_loc="/test.csv"):
    """Create the train and test ``YOLODataset`` instances.

    ``train_loc`` / ``test_loc`` are appended to ``config.DATASET`` by plain
    string concatenation to form each dataset's first argument. Both datasets
    share the image/label directories and anchors from ``config``. The train
    dataset uses ``config.train_transform``; the test dataset uses
    ``config.test_transform`` and is built with ``train=False``.

    Returns:
        Tuple ``(train_dataset, test_dataset)``.
    """
    shared = {
        "img_dir": config.IMG_DIR,
        "label_dir": config.LABEL_DIR,
        "anchors": config.ANCHORS,
    }
    train_dataset = YOLODataset(config.DATASET + train_loc,
                                transform=config.train_transform,
                                **shared)
    test_dataset = YOLODataset(config.DATASET + test_loc,
                               transform=config.test_transform,
                               train=False,
                               **shared)
    return train_dataset, test_dataset
class YOLOv3Lightning(pl.LightningModule):
    """LightningModule wrapping ``YOLOv3`` with multi-scale training.

    At the start of every training epoch a random index into
    ``config.IMAGE_SIZES`` is drawn. The dataset is switched to that size and
    the anchors are rescaled with the matching grid sizes ``config.S[idx]``.
    Validation always uses the anchors scaled by ``config.S[1]``.

    Args:
        dataset: training dataset; ``set_image_size(idx)`` is called on it at
            every training-epoch start, so it must not be None when training.
        lr: learning rate, stored on ``self.lr``.
    """

    def __init__(self, dataset=None, lr=config.LEARNING_RATE):
        super().__init__()
        # The dataset is a data object, not a hyperparameter. Keep it out of
        # hparams so it is not pickled into every checkpoint.
        self.save_hyperparameters(ignore=["dataset"])
        self.model = YOLOv3(num_classes=config.NUM_CLASSES)
        self.lr = lr
        self.criterion = criterion
        self.losses = []
        self.threshold = config.CONF_THRESHOLD
        self.iou_threshold = config.NMS_IOU_THRESH
        self.train_idx = 0
        self.box_format = "midpoint"
        self.dataset = dataset
        self.accuracy_fn = accuracy_fn
        # NOTE(review): these counters are passed to accuracy_fn but never
        # updated here, so the logged accuracies are per-batch values.
        self.tot_class_preds, self.correct_class = 0, 0
        self.tot_noobj, self.correct_noobj = 0, 0
        self.tot_obj, self.correct_obj = 0, 0
        # Placeholder until one of the epoch-start hooks sets real anchors.
        self.scaled_anchors = 0

    def forward(self, x):
        return self.model(x)

    def set_scaled_anchor(self, scaled_anchors):
        self.scaled_anchors = scaled_anchors

    @staticmethod
    def _scale_anchors(size_idx):
        """Return ``config.ANCHORS`` multiplied by the grid sizes ``config.S[size_idx]``."""
        grid_sizes = torch.tensor(config.S[size_idx])
        # Expand the per-scale grid sizes with repeat(1, 3, 2) so they
        # broadcast against config.ANCHORS (assumed (scales, 3, 2) - confirm).
        return torch.tensor(config.ANCHORS) * grid_sizes.unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)

    def on_train_epoch_start(self):
        # Set a new image size for the dataset at the beginning of each epoch
        # and rescale the anchors to the matching grid sizes.
        size_idx = random.choice(range(len(config.IMAGE_SIZES)))
        self.dataset.set_image_size(size_idx)
        # Keep the anchors on the same device as the model's predictions.
        self.set_scaled_anchor(self._scale_anchors(size_idx).to(self.device))

    def on_validation_epoch_start(self):
        # Validation always uses the grid sizes at config.S[1].
        self.set_scaled_anchor(self._scale_anchors(1).to(self.device))

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = self.criterion(out, y, self.scaled_anchors)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, labels = val_batch
        out = self(x)
        loss = self.criterion(out, labels, self.scaled_anchors)
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
        self.evaluate(x, labels, out, 'val')

    def evaluate(self, x, y, out, stage=None):
        """Compute class / no-object / object accuracy and log them under ``stage``.

        ``x`` is unused. Nothing is logged when ``stage`` is falsy.
        """
        class_accuracy, no_obj_accuracy, obj_accuracy = self.accuracy_fn(
            y,
            out,
            self.threshold,
            self.correct_class,
            self.correct_obj,
            self.correct_noobj,
            self.tot_class_preds,
            self.tot_obj,
            self.tot_noobj,
        )
        if stage:
            self.log(f'{stage}_class_accuracy', class_accuracy, prog_bar=True, on_epoch=True, on_step=True, logger=True)
            self.log(f'{stage}_no_obj_accuracy', no_obj_accuracy, prog_bar=True, on_epoch=True, on_step=True, logger=True)
            self.log(f'{stage}_obj_accuracy', obj_accuracy, prog_bar=True, on_epoch=True, on_step=True, logger=True)