File size: 2,679 Bytes
e703e79 28a0087 e703e79 28a0087 e703e79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import os
import yaml
import json
import argparse
import logging
from datetime import datetime
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from train.train import train_model
from utils.logging_utils import init_logging
from dataset.preprocess_multi_processing import run_parallel
from train.train import evaluate_on_val
from train.test import run_test_inference
from dataset.lidar_dataset import LidarFusionDataset
from model.model import SimpleMLP
def _load_model(cfg, weights_path):
    """Build a ``SimpleMLP`` from ``cfg`` and load its weights.

    Args:
        cfg: Parsed configuration mapping. Reads ``cfg["dataset"]["n_classes"]``,
            ``cfg["model"]["hidden_dims"]`` and ``cfg["training"]["device"]``.
        weights_path: Path to a file that ``torch.load`` can read and whose
            content is accepted by ``model.load_state_dict``.

    Returns:
        The model, moved to the configured device, with the weights loaded.
    """
    device = cfg["training"]["device"]
    n_classes = cfg["dataset"]["n_classes"]
    model = SimpleMLP(
        input_dim=n_classes * 2,
        hidden_dims=cfg["model"]["hidden_dims"],
        n_classes=n_classes,
    ).to(device)
    # Remap stored tensors onto the configured device; without map_location a
    # checkpoint saved on a GPU cannot be loaded on a machine lacking that GPU.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    return model


def _make_eval_loader(cfg, split):
    """Return a deterministic, single-sample ``DataLoader`` for ``split``."""
    dataset = LidarFusionDataset(cfg, split=split, shuffle_zones=False)
    return DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)


def main():
    """Command-line entry point: preprocess, train, validate or test.

    Command-line arguments:
        --config: Path to the YAML configuration (default ``config.yaml``).
        --mode: One of ``preprocess``, ``train``, ``val``, ``test``
            (default ``train``).
        --weights_path: Model weights used by the ``val`` and ``test`` modes
            (default ``best_model.pth``).

    A time-stamped output directory derived from ``cfg["logging"]["save_dir"]``
    is created for every mode, and logging is initialised to ``run.log``
    inside it before the selected mode runs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=Path, default="config.yaml", help="Path to YAML configuration.")
    parser.add_argument("--mode", choices=["preprocess", "train", "val", "test"], default="train")
    parser.add_argument("--weights_path", default="best_model.pth", help="Path to model weights.")
    args = parser.parse_args()

    cfg = yaml.safe_load(args.config.read_text(encoding="utf-8"))

    # Suffix the save directory with a timestamp so runs do not overwrite
    # each other's outputs.
    now = datetime.now()
    out_dir = Path(cfg["logging"]["save_dir"] + now.strftime("_%m_%d_%H_%M_%S"))
    out_dir.mkdir(parents=True, exist_ok=True)
    init_logging(out_dir / "run.log")

    mode = args.mode
    if mode == "preprocess":
        split_file = os.path.join(cfg["dataset"]["root"], cfg["dataset"]["split_file"])
        with open(split_file, encoding="utf-8") as f:
            split_dict = json.load(f)
        # Preprocess every zone of every split in a single parallel pass.
        all_zones = split_dict["train"] + split_dict["val"] + split_dict["test"]
        run_parallel(cfg, all_zones, num_workers=cfg["dataset"]["pre-processing_num_workers"])
    elif mode == "train":
        train_model(cfg, out_dir)
    elif mode == "val":
        model = _load_model(cfg, args.weights_path)
        val_loader = _make_eval_loader(cfg, "val")
        miou, mf1, ious = evaluate_on_val(model, val_loader, cfg)
        # Record the returned metrics so a validation-only run leaves a trace;
        # a named logger is used to avoid the implicit basicConfig() that the
        # module-level logging.info() would trigger.
        logging.getLogger(__name__).info(
            "Validation results: miou=%s, mf1=%s, ious=%s", miou, mf1, ious
        )
    elif mode == "test":
        model = _load_model(cfg, args.weights_path)
        test_loader = _make_eval_loader(cfg, "test")
        run_test_inference(model, test_loader, cfg, out_dir)


if __name__ == "__main__":
    main()
|