File size: 3,959 Bytes
be6ec35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.optim as optim
import lightning.pytorch as pl
from tqdm import tqdm
from .model import YOLOv3
from .loss import YoloLoss
from .utils import get_loaders, load_checkpoint, check_class_accuracy, intersection_over_union
from . import config
from torch.optim.lr_scheduler import OneCycleLR


class YOLOv3Lightning(pl.LightningModule):
    """LightningModule that wraps ``YOLOv3`` together with ``YoloLoss``.

    The module owns its data loaders (built through ``get_loaders``), its
    Adam optimizer and a per-step ``OneCycleLR`` schedule.  The loss of a
    batch is the sum of ``YoloLoss`` over the three prediction scales
    returned by the model.
    """

    def __init__(self, config, lr_value=0):
        """Build the network and the loss function.

        Args:
            config: Object exposing the settings read by this class
                (``NUM_CLASSES``, ``LEARNING_RATE``, ``WEIGHT_DECAY``,
                ``NUM_EPOCHS``, ``DATASET``, ``ANCHORS``, ``S``,
                ``LOAD_MODEL``, ``CHECKPOINT_FILE``).
            lr_value: Learning-rate override.  The default ``0`` means
                "use ``config.LEARNING_RATE``".
        """
        super().__init__()
        self.automatic_optimization = True
        self.config = config
        self.model = YOLOv3(num_classes=self.config.NUM_CLASSES)
        self.loss_fn = YoloLoss()

        if lr_value == 0:
            self.learning_rate = self.config.LEARNING_RATE
        else:
            self.learning_rate = lr_value

    def forward(self, x):
        """Run the wrapped YOLOv3 network on ``x`` and return its output."""
        return self.model(x)

    def configure_optimizers(self):
        """Create the Adam optimizer and a step-wise OneCycleLR scheduler.

        Returns:
            A ``([optimizer], [scheduler_config])`` pair; the scheduler is
            stepped after every optimizer step.
        """
        # Use the resolved learning rate so that ``lr_value`` passed to the
        # constructor is honoured (previously it was silently ignored).
        # NOTE(review): OneCycleLR derives the running lr from ``max_lr``, so
        # the peak stays at 1e-3 regardless of this value -- confirm intent.
        optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.config.WEIGHT_DECAY,
        )
        # The one-cycle schedule spans 2/5 of the configured epoch count.
        schedule_epochs = self.config.NUM_EPOCHS * 2 // 5
        scheduler = OneCycleLR(
            optimizer,
            max_lr=1E-3,
            # Builds a fresh train loader only to count batches per epoch.
            steps_per_epoch=len(self.train_dataloader()),
            epochs=schedule_epochs,
            # Ramp-up lasts 5 epochs; requires schedule_epochs >= 5.
            pct_start=5 / schedule_epochs,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy='linear',
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def _build_scaled_anchors(self):
        """Return ``config.ANCHORS`` multiplied by the grid sizes ``config.S``.

        Each entry of ``S`` is broadcast over a ``(3, 2)`` slice, so the
        result has one anchor set per prediction scale.
        """
        # presumably ANCHORS is (num_scales, 3, 2) -- verify against config
        return (
            torch.tensor(self.config.ANCHORS)
            * torch.tensor(self.config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        ).to(self.device)

    def _get_scaled_anchors(self):
        """Return the scaled anchors on the module's device, creating them on demand.

        ``on_train_start`` is the only hook that used to create
        ``self.scaled_anchors``; validation and test steps can run before (or
        entirely without) that hook, so the tensor is built lazily here.
        """
        anchors = getattr(self, "scaled_anchors", None)
        if anchors is None:
            anchors = self._build_scaled_anchors()
        # ``Tensor.to`` returns the same tensor when it already lives there.
        self.scaled_anchors = anchors.to(self.device)
        return self.scaled_anchors

    def _compute_loss(self, batch):
        """Return the summed ``YoloLoss`` over the three scales for ``batch``.

        Args:
            batch: ``(x, y)`` pair where ``y`` holds one target per scale at
                indices 0, 1 and 2.
        """
        x, y = batch
        y0, y1, y2 = (
            y[0].to(self.device),
            y[1].to(self.device),
            y[2].to(self.device),
        )
        out = self(x)
        anchors = self._get_scaled_anchors()
        return (
            self.loss_fn(out[0], y0, anchors[0])
            + self.loss_fn(out[1], y1, anchors[1])
            + self.loss_fn(out[2], y2, anchors[2])
        )

    def train_dataloader(self):
        """Return the first loader produced by ``get_loaders`` (training)."""
        train_loader, _, _ = get_loaders(
            train_csv_path=self.config.DATASET + "/train.csv",
            test_csv_path=self.config.DATASET + "/test.csv",
        )
        return train_loader

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the loss for one batch."""
        loss = self._compute_loss(batch)
        self.log('train_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def val_dataloader(self):
        """Return the third loader produced by ``get_loaders`` (validation)."""
        _, _, val_loader = get_loaders(
            train_csv_path=self.config.DATASET + "/train.csv",
            test_csv_path=self.config.DATASET + "/test.csv",
        )
        return val_loader

    def validation_step(self, batch, batch_idx):
        """Compute, log (``val_loss``) and return the loss for one batch."""
        loss = self._compute_loss(batch)
        self.log('val_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def test_dataloader(self):
        """Return the second loader produced by ``get_loaders`` (test)."""
        _, test_loader, _ = get_loaders(
            train_csv_path=self.config.DATASET + "/train.csv",
            test_csv_path=self.config.DATASET + "/test.csv",
        )
        return test_loader

    def test_step(self, batch, batch_idx):
        """Compute, log (``test_loss``) and return the loss for one batch."""
        loss = self._compute_loss(batch)
        self.log('test_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def on_train_start(self):
        """Optionally restore a checkpoint and (re)build the scaled anchors."""
        if self.config.LOAD_MODEL:
            # Pass the resolved learning rate so a constructor override is
            # not replaced by the config default when a checkpoint is loaded.
            load_checkpoint(
                self.config.CHECKPOINT_FILE,
                self.model,
                self.optimizers(),
                self.learning_rate,
            )
        # Rebuild on the (now final) training device.
        self.scaled_anchors = self._build_scaled_anchors()