# Spaces:
# Runtime error
# Runtime error
from torch.utils.data import Dataset, DataLoader | |
from loss import YoloLoss | |
import config | |
import torch | |
from dataset import YOLODataset | |
from torch.optim.lr_scheduler import OneCycleLR | |
import random | |
from model import YOLOv3 | |
import lightning.pytorch as pl | |
def criterion(out, y, anchors):
    """Return the YOLOv3 loss summed over the three prediction scales.

    For each scale ``s``, the prediction ``out[s]`` is scored against the
    target ``y[s]`` using the anchors ``anchors[s]`` via ``YoloLoss``.
    A fresh ``YoloLoss`` instance is created on every call.
    """
    yolo_loss = YoloLoss()
    scale_losses = [yolo_loss(out[s], y[s], anchors[s]) for s in range(3)]
    # Sum in the same left-to-right order as scale 0 + scale 1 + scale 2.
    return scale_losses[0] + scale_losses[1] + scale_losses[2]
def get_loader(train_dataset, test_dataset):
    """Wrap the train and test datasets in DataLoaders configured from ``config``.

    Both loaders share batch size, worker count and pin-memory settings and
    keep the last partial batch. Only the training loader shuffles.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    common_kwargs = dict(
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        drop_last=False,
    )
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **common_kwargs)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **common_kwargs)
    return train_loader, test_loader
def accuracy_fn(y, out, threshold,
                correct_class=0, correct_obj=0,
                correct_noobj=0, tot_class_preds=0,
                tot_obj=0, tot_noobj=0):
    """Compute class, no-object and object accuracy (percent) over all scales.

    ``y`` and ``out`` are paired per prediction scale. For each pair, the last
    axis is read as follows:

    * Channel 0 of the target is objectness: 1 means object, 0 means no
      object. Cells with any other value are excluded from every count.
    * Channel 0 of the prediction is an objectness logit.
    * Channel 5 of the target is the class index.
    * Channels ``5:`` of the prediction are the per-class scores.

    Args:
        y: sequence of per-scale target tensors.
        out: sequence of per-scale prediction tensors, paired with ``y``.
        threshold: probability above which ``sigmoid(objectness)`` counts as
            "object present".
        correct_class, correct_obj, correct_noobj, tot_class_preds, tot_obj,
        tot_noobj: starting values for the running counters (default 0).
            They are only incremented locally; the caller's variables are
            not updated.

    Returns:
        Tuple ``(class_accuracy, no_obj_accuracy, obj_accuracy)`` in percent.
        A tiny epsilon in each denominator avoids division by zero when a
        category has no cells.
    """
    for target, pred in zip(y, out):
        obj = target[..., 0] == 1    # in paper this is Iobj_i
        noobj = target[..., 0] == 0  # in paper this is Inoobj_i

        # Class accuracy is measured only on cells that contain an object.
        correct_class += torch.sum(
            torch.argmax(pred[..., 5:][obj], dim=-1) == target[..., 5][obj]
        )
        tot_class_preds += torch.sum(obj)

        obj_preds = torch.sigmoid(pred[..., 0]) > threshold
        correct_obj += torch.sum(obj_preds[obj] == target[..., 0][obj])
        tot_obj += torch.sum(obj)
        correct_noobj += torch.sum(obj_preds[noobj] == target[..., 0][noobj])
        tot_noobj += torch.sum(noobj)

    return (
        (correct_class / (tot_class_preds + 1e-16)) * 100,
        (correct_noobj / (tot_noobj + 1e-16)) * 100,
        (correct_obj / (tot_obj + 1e-16)) * 100,
    )
def get_datasets(train_loc="/train.csv", test_loc="/test.csv"):
    """Build the train and test ``YOLODataset`` objects from ``config``.

    Each CSV path is formed by plain string concatenation of
    ``config.DATASET`` and the given location. The defaults begin with "/"
    for that reason. Both datasets share the image dir, label dir and
    anchors. The test dataset uses the test transform and ``train=False``.

    Returns:
        Tuple ``(train_dataset, test_dataset)``.
    """
    shared_kwargs = dict(
        img_dir=config.IMG_DIR,
        label_dir=config.LABEL_DIR,
        anchors=config.ANCHORS,
    )
    train_dataset = YOLODataset(
        config.DATASET + train_loc,
        transform=config.train_transform,
        **shared_kwargs,
    )
    test_dataset = YOLODataset(
        config.DATASET + test_loc,
        transform=config.test_transform,
        train=False,
        **shared_kwargs,
    )
    return train_dataset, test_dataset
class YOLOv3Lightning(pl.LightningModule):
    """Lightning wrapper around ``YOLOv3`` with per-epoch multi-scale training.

    Args:
        dataset: training dataset. ``on_train_epoch_start`` calls its
            ``set_image_size(idx)`` method, so it must be provided before
            training starts.
        lr: learning rate stored on ``self.lr`` (defaults to
            ``config.LEARNING_RATE``).
    """

    def __init__(self, dataset=None, lr=config.LEARNING_RATE):
        super().__init__()
        # The dataset is a runtime object, not a hyperparameter. Keeping it
        # out of hparams stops it from being pickled into every checkpoint
        # and sent to the logger along with the real hyperparameters.
        self.save_hyperparameters(ignore=["dataset"])
        self.model = YOLOv3(num_classes=config.NUM_CLASSES)
        self.lr = lr
        self.criterion = criterion
        self.accuracy_fn = accuracy_fn
        self.losses = []
        self.threshold = config.CONF_THRESHOLD
        self.iou_threshold = config.NMS_IOU_THRESH
        self.train_idx = 0
        self.box_format = "midpoint"
        self.dataset = dataset
        # Starting counter values handed to accuracy_fn on every call.
        # NOTE(review): accuracy_fn returns only percentages, so nothing in
        # this class updates these. The logged accuracies are therefore
        # per-batch values. Confirm this is intended.
        self.tot_class_preds, self.correct_class = 0, 0
        self.tot_noobj, self.correct_noobj = 0, 0
        self.tot_obj, self.correct_obj = 0, 0
        # Placeholder until an epoch-start hook installs real anchors.
        self.scaled_anchors = 0

    def forward(self, x):
        return self.model(x)

    def set_scaled_anchor(self, scaled_anchors):
        self.scaled_anchors = scaled_anchors

    def _scaled_anchors_for(self, size_idx):
        """Return ``config.ANCHORS`` scaled by the grid sizes ``config.S[size_idx]``.

        ``torch.tensor`` builds the anchors on the CPU, so the result is moved
        to ``self.device``. This lets the loss combine it with model outputs
        during GPU training.
        """
        anchors = torch.tensor(config.ANCHORS)
        # assumes config.S[size_idx] is a 1-D list with one grid size per
        # prediction scale (TODO confirm). unsqueeze/repeat expand it to
        # shape (num_scales, 3, 2) so it broadcasts against the anchors.
        grid_sizes = (
            torch.tensor(config.S[size_idx]).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        )
        return (anchors * grid_sizes).to(self.device)

    def on_train_epoch_start(self):
        # Set a new image size for the dataset at the beginning of each epoch
        # and rescale the anchors to the matching grid sizes.
        size_idx = random.choice(range(len(config.IMAGE_SIZES)))
        self.dataset.set_image_size(size_idx)
        self.set_scaled_anchor(self._scaled_anchors_for(size_idx))

    def on_validation_epoch_start(self):
        # Validation always uses the grid sizes at index 1 of config.S.
        self.set_scaled_anchor(self._scaled_anchors_for(1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = self.criterion(out, y, self.scaled_anchors)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, labels = val_batch
        out = self(x)
        loss = self.criterion(out, labels, self.scaled_anchors)
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
        self.evaluate(x, labels, out, 'val')

    def evaluate(self, x, y, out, stage=None):
        """Compute class / no-object / object accuracy and log them under ``stage``.

        ``x`` is unused. Nothing is logged when ``stage`` is falsy.
        """
        class_accuracy, no_obj_accuracy, obj_accuracy = self.accuracy_fn(
            y,
            out,
            self.threshold,
            self.correct_class,
            self.correct_obj,
            self.correct_noobj,
            self.tot_class_preds,
            self.tot_obj,
            self.tot_noobj,
        )
        if stage:
            metrics = {
                'class_accuracy': class_accuracy,
                'no_obj_accuracy': no_obj_accuracy,
                'obj_accuracy': obj_accuracy,
            }
            for metric_name, value in metrics.items():
                self.log(f'{stage}_{metric_name}', value, prog_bar=True, on_epoch=True, on_step=True, logger=True)