Spaces:
Sleeping
Sleeping
File size: 3,959 Bytes
be6ec35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import torch
import torch.optim as optim
import lightning.pytorch as pl
from tqdm import tqdm
from .model import YOLOv3
from .loss import YoloLoss
from .utils import get_loaders, load_checkpoint, check_class_accuracy, intersection_over_union
from . import config
from torch.optim.lr_scheduler import OneCycleLR
class YOLOv3Lightning(pl.LightningModule):
    """Lightning wrapper around YOLOv3 trained with a three-scale YOLO loss.

    The network predicts at three grid scales; the step loss is the sum of
    ``YoloLoss`` over the three heads, each paired with its scale-adjusted
    anchors (``scaled_anchors``).

    Args:
        config: project config exposing NUM_CLASSES, LEARNING_RATE,
            WEIGHT_DECAY, NUM_EPOCHS, ANCHORS, S, DATASET, LOAD_MODEL and
            CHECKPOINT_FILE.
        lr_value: optional learning-rate override; 0 (the default) is a
            sentinel meaning "use ``config.LEARNING_RATE``".
    """

    def __init__(self, config, lr_value=0):
        super().__init__()
        self.automatic_optimization = True
        self.config = config
        self.model = YOLOv3(num_classes=self.config.NUM_CLASSES)
        self.loss_fn = YoloLoss()
        # 0 is the sentinel for "use the configured default".
        self.learning_rate = lr_value if lr_value != 0 else self.config.LEARNING_RATE
        # Anchors rescaled to each output grid size. Registered as a buffer so
        # it follows the module across devices automatically and exists even
        # when validate()/test() runs without a prior fit (the original built
        # it only in on_train_start, which crashed standalone evaluation).
        self.register_buffer(
            "scaled_anchors",
            torch.tensor(self.config.ANCHORS)
            * torch.tensor(self.config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2),
        )

    def forward(self, x):
        """Run the raw YOLOv3 network; returns one output per prediction scale."""
        return self.model(x)

    def configure_optimizers(self):
        """Adam + OneCycleLR stepped every batch.

        Fix: the optimizer now honours ``self.learning_rate`` (the ``lr_value``
        override resolved in ``__init__``) instead of always reading
        ``config.LEARNING_RATE``.
        """
        optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.config.WEIGHT_DECAY,
        )
        # Schedule spans 2/5 of the configured epochs; guarded against 0 so the
        # pct_start division below cannot raise for tiny NUM_EPOCHS.
        epochs = max(1, self.config.NUM_EPOCHS * 2 // 5)
        scheduler = OneCycleLR(
            optimizer,
            max_lr=1e-3,
            steps_per_epoch=len(self.train_dataloader()),
            epochs=epochs,
            # NOTE(review): 5/epochs > 1 (invalid for OneCycleLR) when
            # epochs < 5, i.e. NUM_EPOCHS < 13 — confirm the config guarantees this.
            pct_start=5 / epochs,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy='linear',
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def _loaders(self):
        """Build the (train, test, val) loader triple from the configured CSV splits."""
        return get_loaders(
            train_csv_path=self.config.DATASET + "/train.csv",
            test_csv_path=self.config.DATASET + "/test.csv",
        )

    def train_dataloader(self):
        """First element of the get_loaders triple."""
        return self._loaders()[0]

    def val_dataloader(self):
        """Third element of the get_loaders triple."""
        return self._loaders()[2]

    def test_dataloader(self):
        """Second element of the get_loaders triple."""
        return self._loaders()[1]

    def _compute_loss(self, batch):
        """Sum YoloLoss over the three prediction scales for one (x, targets) batch.

        The per-scale ``.to(self.device)`` is a no-op under Lightning's batch
        transfer but kept as a harmless safety net for manual invocation.
        """
        x, targets = batch
        out = self(x)
        return sum(
            self.loss_fn(out[i], targets[i].to(self.device), self.scaled_anchors[i])
            for i in range(3)
        )

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('train_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('val_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)

    def test_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('test_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)

    def on_train_start(self):
        """Optionally restore model/optimizer state before the first training batch."""
        if self.config.LOAD_MODEL:
            load_checkpoint(
                self.config.CHECKPOINT_FILE,
                self.model,
                self.optimizers(),
                self.config.LEARNING_RATE,
            )
|