File size: 3,898 Bytes
8b3b3ef ecf6aba 8b3b3ef d1477fc 8b3b3ef 1197f7d 73b88fc 8b3b3ef 1197f7d 8b3b3ef afa32b4 b2baf14 8b3b3ef 1197f7d 8b3b3ef 604c897 6e46676 8b3b3ef b2baf14 8b3b3ef 604c897 73b88fc 8b3b3ef 73b88fc 8b3b3ef 240dcb0 8b3b3ef 240dcb0 8b3b3ef 604c897 8b3b3ef 604c897 8b3b3ef 604c897 8b3b3ef 240dcb0 8b3b3ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
from lightning import LightningModule
from torchmetrics.detection import MeanAveragePrecision
from yolo.config.config import Config
from yolo.model.yolo import create_model
from yolo.tools.data_loader import create_dataloader
from yolo.tools.loss_functions import create_loss_function
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
from yolo.utils.model_utils import PostProccess, create_optimizer, create_scheduler
class BaseModel(LightningModule):
    """Thinnest Lightning wrapper around the YOLO network.

    Builds the model from the config at construction time and delegates
    the forward pass directly to the wrapped network.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        # Construct the YOLO network; weight_path optionally restores pretrained weights.
        self.model = create_model(cfg.model, weight_path=cfg.weight, class_num=cfg.dataset.class_num)

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its raw output."""
        return self.model(x)
class ValidateModel(BaseModel):
    """Lightning module that evaluates the model with COCO-style mAP metrics."""

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.cfg = cfg
        # A standalone "validation" task carries the settings itself; otherwise
        # they live in the task's nested ``validation`` section.
        task_cfg = cfg.task
        self.validation_cfg = task_cfg if task_cfg.task == "validation" else task_cfg.validation
        self.metric = MeanAveragePrecision(iou_type="bbox", box_format="xyxy")
        # Silence torchmetrics' "many detections" warning during evaluation.
        self.metric.warn_on_many_detections = False
        self.val_loader = create_dataloader(self.validation_cfg.data, self.cfg.dataset, self.validation_cfg.task)

    def setup(self, stage):
        """Build the tensor→box converter and NMS post-processor once the device is known."""
        cfg = self.cfg
        self.vec2box = create_converter(cfg.model.name, self.model, cfg.model.anchor, cfg.image_size, self.device)
        self.post_proccess = PostProccess(self.vec2box, self.validation_cfg.nms)

    def val_dataloader(self):
        """Return the dataloader created in ``__init__``."""
        return self.val_loader

    def validation_step(self, batch, batch_idx):
        """Post-process one batch, update the mAP metric, and log per-step scores."""
        batch_size, images, targets, rev_tensor, img_paths = batch
        predicts = self.post_proccess(self(images))
        pred_list = [to_metrics_format(pred) for pred in predicts]
        target_list = [to_metrics_format(tgt) for tgt in targets]
        batch_metrics = self.metric(pred_list, target_list)
        step_scores = {"map": batch_metrics["map"], "map_50": batch_metrics["map_50"]}
        self.log_dict(step_scores, on_step=True, batch_size=batch_size)
        return predicts

    def on_validation_epoch_end(self):
        """Compute, log, and reset the accumulated epoch-level mAP metrics."""
        epoch_metrics = self.metric.compute()
        # "classes" is a tensor of class ids, not a scalar metric — drop it before logging.
        del epoch_metrics["classes"]
        self.log_dict(
            {"PyCOCO/AP @ .5:.95": epoch_metrics["map"], "PyCOCO/AP @ .5": epoch_metrics["map_50"]}, rank_zero_only=True
        )
        self.log_dict(epoch_metrics, prog_bar=True, logger=True, on_epoch=True, rank_zero_only=True)
        self.metric.reset()
class TrainModel(ValidateModel):
    """Training module: adds a train loader, the loss function, and the optimization loop."""

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.cfg = cfg
        task_cfg = cfg.task
        self.train_loader = create_dataloader(task_cfg.data, cfg.dataset, task_cfg.task)

    def setup(self, stage):
        """Extend the parent setup with the loss function (needs ``self.vec2box``)."""
        super().setup(stage)
        self.loss_fn = create_loss_function(self.cfg, self.vec2box)

    def train_dataloader(self):
        """Return the dataloader created in ``__init__``."""
        return self.train_loader

    def on_train_epoch_start(self):
        # The project's optimizer wrapper advances its own per-epoch schedule;
        # it needs the number of batches in the coming epoch.
        self.trainer.optimizers[0].next_epoch(len(self.train_loader))

    def training_step(self, batch, batch_idx):
        """Run one forward/loss pass and log the individual loss terms."""
        # Advance the optimizer wrapper's per-batch schedule before the forward pass.
        self.trainer.optimizers[0].next_batch()
        batch_size, images, targets, *_ = batch
        raw_predicts = self(images)
        aux_predicts = self.vec2box(raw_predicts["AUX"])
        main_predicts = self.vec2box(raw_predicts["Main"])
        loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)
        self.log_dict(
            loss_item,
            logger=True,
            prog_bar=True,
            on_epoch=True,
            rank_zero_only=True,
            batch_size=batch_size,
        )
        # Loss is scaled by batch size, matching the project's loss convention.
        return loss * batch_size

    def configure_optimizers(self):
        """Build the optimizer and its scheduler from the task config."""
        task_cfg = self.cfg.task
        opt = create_optimizer(self.model, task_cfg.optimizer)
        sched = create_scheduler(opt, task_cfg.scheduler)
        return [opt], [sched]
|