from lightning import LightningModule
from torchmetrics.detection import MeanAveragePrecision

from yolo.config.config import Config
from yolo.model.yolo import create_model
from yolo.tools.data_loader import create_dataloader
from yolo.tools.loss_functions import create_loss_function
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler


class BaseModel(LightningModule):
    """Wraps the YOLO network in a LightningModule; forward() simply delegates to the model."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)

    def forward(self, x):
        return self.model(x)


class ValidateModel(BaseModel):
    """Adds the validation dataloader, prediction post-processing, and COCO-style mAP metrics."""

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.cfg = cfg
        # When running a standalone validation task, the task config itself holds the
        # validation settings; otherwise they live under the task's `validation` sub-config.
        if self.cfg.task.task == "validation":
            self.validation_cfg = self.cfg.task
        else:
            self.validation_cfg = self.cfg.task.validation
        self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
        self.metric.warn_on_many_detections = False
        self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)

    def setup(self, stage):
        # Build the prediction-vector-to-box converter and the NMS post-processor once the
        # device is known.
        self.vec2box = create_converter(
            self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
        )
        self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)

    def val_dataloader(self):
        return self.val_loader

    def validation_step(self, batch, batch_idx):
        batch_size, images, targets, rev_tensor, img_paths = batch
        predicts = self.post_proccess(self(images))
        batch_metrics = self.metric(
            [to_metrics_format(predict) for predict in predicts], [to_metrics_format(target) for target in targets]
        )

        self.log_dict(
            {
                "map": batch_metrics["map"],
                "map_50": batch_metrics["map_50"],
            },
            on_step=True,
            batch_size=batch_size,
        )
        return predicts

    def on_validation_epoch_end(self):
        epoch_metrics = self.metric.compute()
        del epoch_metrics["classes"]
        self.log_dict(
            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]},
            rank_zero_only=True,
        )
        self.log_dict(epoch_metrics, prog_bar=True, logger=True, on_epoch=True, rank_zero_only=True)
        self.metric.reset()


class TrainModel(ValidateModel):
    """Adds the training dataloader, loss function, optimizer, and scheduler on top of validation."""

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.cfg = cfg
        self.train_loader = create_dataloader(self.cfg.task.data, self.cfg.dataset, self.cfg.task.task)

    def setup(self, stage):
        super().setup(stage)
        self.loss_fn = create_loss_function(self.cfg, self.vec2box)

    def train_dataloader(self):
        return self.train_loader

    def on_train_epoch_start(self):
        # `next_epoch` is a custom hook attached to the optimizer by create_optimizer.
        self.trainer.optimizers[0].next_epoch(len(self.train_loader))

    def training_step(self, batch, batch_idx):
        # Per-batch optimizer hook (e.g. warm-up bookkeeping), also attached by create_optimizer.
        self.trainer.optimizers[0].next_batch()
        batch_size, images, targets, *_ = batch
        predicts = self(images)
        # Decode both the auxiliary and main branch predictions before computing the loss.
        aux_predicts = self.vec2box(predicts["AUX"])
        main_predicts = self.vec2box(predicts["Main"])
        loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
        self.log_dict(
            loss_item,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            batch_size=batch_size,
            rank_zero_only=True,
        )
        # Loss is scaled by the batch size before being returned to Lightning.
        return loss * batch_size

    def configure_optimizers(self):
        optimizer = create_optimizer(self.model, self.cfg.task.optimizer)
        scheduler = create_scheduler(optimizer, self.cfg.task.scheduler)
        return [optimizer], [scheduler]
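

# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal illustration of how these LightningModules might be driven by a Lightning
# Trainer, assuming a fully populated `Config` object `cfg` is constructed elsewhere
# (e.g. by the project's configuration loading). The Trainer arguments shown are
# placeholders, not the project's actual settings.
#
#     from lightning import Trainer
#
#     model = TrainModel(cfg)            # or ValidateModel(cfg) for evaluation only
#     trainer = Trainer(max_epochs=10, accelerator="auto")
#     trainer.fit(model)                 # uses train_dataloader()/val_dataloader() defined above
#     trainer.validate(model)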