|
from lightning import LightningModule |
|
from torchmetrics.detection import MeanAveragePrecision |
|
|
|
from yolo.config.config import Config |
|
from yolo.model.yolo import create_model |
|
from yolo.tools.data_loader import create_dataloader |
|
from yolo.tools.loss_functions import create_loss_function |
|
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format |
|
from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler |
|
|
|
|
|
class BaseModel(LightningModule):
    """Lightning wrapper that owns the YOLO network and delegates ``forward`` to it."""

    def __init__(self, cfg: Config):
        """Build the network described by ``cfg``.

        Args:
            cfg: Project configuration. ``cfg.model`` is passed to ``create_model``,
                ``cfg.dataset.class_num`` becomes its ``class_num`` and ``cfg.weight``
                its ``weight_path``.
        """
        super().__init__()
        model_cfg = cfg.model
        num_classes = cfg.dataset.class_num
        self.model = create_model(model_cfg, class_num=num_classes, weight_path=cfg.weight)

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output unchanged."""
        network = self.model
        return network(x)
|
|
|
|
|
class ValidateModel(BaseModel):
    """Adds a validation loop that scores post-processed predictions with mAP.

    Validation settings are read from ``cfg.task`` itself when the task is
    ``"validation"``; for any other task (e.g. training) they are read from the
    nested ``cfg.task.validation`` section.
    """

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.cfg = cfg
        task_cfg = self.cfg.task
        # A standalone validation run configures validation at the top level of
        # the task; other tasks nest it under `validation`.
        is_standalone = task_cfg.task == "validation"
        self.validation_cfg = task_cfg if is_standalone else task_cfg.validation

        metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
        metric.warn_on_many_detections = False  # turn off the metric's many-detections warning
        self.metric = metric

        self.val_loader = create_dataloader(
            self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task
        )

    def setup(self, stage):
        """Build the vector-to-box converter (on ``self.device``) and the NMS post-processor."""
        model_cfg = self.cfg.model
        self.vec2box = create_converter(
            model_cfg.name, self.model, model_cfg.anchor, self.cfg.image_size, self.device
        )
        self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)

    def val_dataloader(self):
        """Return the validation loader built in ``__init__``."""
        return self.val_loader

    def validation_step(self, batch, batch_idx):
        """Post-process one batch, update the mAP metric and log batch-level scores.

        Returns:
            The post-processed predictions for the batch.
        """
        batch_size, images, targets, _rev_tensor, _img_paths = batch
        raw_output = self(images)
        predicts = self.post_proccess(raw_output)

        pred_records = [to_metrics_format(p) for p in predicts]
        target_records = [to_metrics_format(t) for t in targets]
        # Calling a torchmetrics metric returns the batch value and also
        # accumulates state for the epoch-level compute() at epoch end.
        batch_metrics = self.metric(pred_records, target_records)

        step_scores = {"map": batch_metrics["map"], "map_50": batch_metrics["map_50"]}
        self.log_dict(step_scores, on_step=True, batch_size=batch_size)
        return predicts

    def on_validation_epoch_end(self):
        """Compute epoch mAP, log it under COCO-style and raw names, then reset the metric."""
        epoch_metrics = self.metric.compute()
        # Drop the per-class id entry so only the score entries are logged.
        del epoch_metrics["classes"]

        coco_scores = {
            "PyCOCO/AP @ .5:.95": epoch_metrics["map"],
            "PyCOCO/AP @ .5": epoch_metrics["map_50"],
        }
        self.log_dict(coco_scores, rank_zero_only=True)
        self.log_dict(epoch_metrics, prog_bar=True, logger=True, on_epoch=True, rank_zero_only=True)
        self.metric.reset()
|
|
|
|
|
class TrainModel(ValidateModel):
    """Training module: adds the training loader, loss function and optimizer wiring."""

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        # `self.cfg` is already stored by ValidateModel.__init__.
        task_cfg = self.cfg.task
        self.train_loader = create_dataloader(task_cfg.data, self.cfg.dataset, task_cfg.task)

    def setup(self, stage):
        """Run the validation setup, then build the loss, which depends on ``vec2box``."""
        super().setup(stage)
        self.loss_fn = create_loss_function(self.cfg, self.vec2box)

    def train_dataloader(self):
        """Return the training loader built in ``__init__``."""
        return self.train_loader

    def on_train_epoch_start(self):
        """Notify the (project-specific) optimizer that a new epoch is starting.

        It passes ``len(self.train_loader)``, presumably the number of batches in
        the epoch.
        """
        optimizer = self.trainer.optimizers[0]
        optimizer.next_epoch(len(self.train_loader))

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log its components.

        Returns:
            The loss multiplied by the batch size.
        """
        self.trainer.optimizers[0].next_batch()
        batch_size, images, targets, *_ = batch

        outputs = self(images)
        aux_predicts = self.vec2box(outputs["AUX"])
        main_predicts = self.vec2box(outputs["Main"])
        loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)

        self.log_dict(
            loss_item,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            batch_size=batch_size,
            rank_zero_only=True,
        )
        # NOTE(review): the loss is scaled by batch_size before being returned;
        # presumably loss_fn yields a per-sample average — confirm the intent.
        return loss * batch_size

    def configure_optimizers(self):
        """Build the optimizer and scheduler from ``cfg.task``."""
        task_cfg = self.cfg.task
        optimizer = create_optimizer(self.model, task_cfg.optimizer)
        scheduler = create_scheduler(optimizer, task_cfg.scheduler)
        return [optimizer], [scheduler]
|
|