English
File size: 2,679 Bytes
e703e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28a0087
e703e79
 
 
 
 
 
 
 
28a0087
e703e79
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import yaml
import json
import argparse
import logging
from datetime import datetime
from pathlib import Path
import torch
from torch.utils.data import DataLoader


from train.train import train_model
from utils.logging_utils import init_logging
from dataset.preprocess_multi_processing import run_parallel
from train.train import evaluate_on_val
from train.test import run_test_inference
from dataset.lidar_dataset import LidarFusionDataset
from model.model import SimpleMLP


def _build_model(cfg, weights_path):
    """Construct a ``SimpleMLP`` from *cfg* and load the weights at *weights_path*.

    The model is moved to ``cfg['training']['device']``. The checkpoint is
    mapped onto that same device, so weights saved on a GPU can be loaded on a
    CPU-only machine (and vice versa). The returned model is in eval mode.
    """
    device = cfg['training']['device']
    model = SimpleMLP(
        input_dim=cfg['dataset']['n_classes'] * 2,
        hidden_dims=cfg['model']['hidden_dims'],
        n_classes=cfg['dataset']['n_classes']).to(device)
    # Without map_location, torch.load restores tensors onto the device they
    # were saved from, which fails if that device is unavailable here.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    # Put any train/eval-sensitive layers into inference mode. This is harmless
    # if the downstream evaluation helper also calls eval().
    model.eval()
    return model


def _make_eval_loader(cfg, split):
    """Return a deterministic, one-sample-per-batch ``DataLoader`` for *split*."""
    dataset = LidarFusionDataset(cfg, split=split, shuffle_zones=False)
    return DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)


def main():
    """Command-line entry point.

    Parses ``--config``, ``--mode`` and ``--weights_path`` and loads the YAML
    configuration. It then creates a timestamped output directory
    (``cfg['logging']['save_dir']`` + ``_MM_DD_HH_MM_SS``) and initialises
    logging to ``run.log`` inside it. Finally it dispatches on ``--mode``:

    * ``preprocess`` -- run ``run_parallel`` over every zone listed in the
      train/val/test entries of the split file.
    * ``train`` -- call ``train_model(cfg, out_dir)``.
    * ``val`` -- load the weights, evaluate on the validation split, and log
      the returned mIoU, mF1 and per-class IoU.
    * ``test`` -- load the weights and call ``run_test_inference`` on the
      test split, passing ``out_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=Path, default="config.yaml", help="Path to YAML configuration.")
    parser.add_argument("--mode", choices=["preprocess", "train", "val", "test"], default="train")
    parser.add_argument("--weights_path", default="best_model.pth", help="Path to model weights.")
    args = parser.parse_args()

    # Explicit encoding so config parsing does not depend on the platform locale.
    cfg = yaml.safe_load(args.config.read_text(encoding="utf-8"))

    now = datetime.now()
    out_dir = Path(cfg["logging"]["save_dir"] + now.strftime("_%m_%d_%H_%M_%S"))
    os.makedirs(out_dir, exist_ok=True)
    init_logging(out_dir / "run.log")

    mode = args.mode
    if mode == "preprocess":
        split_file = os.path.join(cfg["dataset"]["root"], cfg["dataset"]["split_file"])
        with open(split_file, encoding="utf-8") as f:
            split_dict = json.load(f)
        # Preprocessing covers every zone regardless of which split it belongs to.
        all_zones = split_dict["train"] + split_dict["val"] + split_dict["test"]
        run_parallel(cfg, all_zones, num_workers=cfg["dataset"]["pre-processing_num_workers"])

    elif mode == "train":
        train_model(cfg, out_dir)

    elif mode == "val":
        model = _build_model(cfg, args.weights_path)
        val_loader = _make_eval_loader(cfg, "val")
        miou, mf1, ious = evaluate_on_val(model, val_loader, cfg)
        # Previously these results were computed and silently discarded.
        logging.info("Validation results: mIoU=%s, mF1=%s, per-class IoU=%s", miou, mf1, ious)

    elif mode == "test":
        model = _build_model(cfg, args.weights_path)
        test_loader = _make_eval_loader(cfg, "test")
        run_test_inference(model, test_loader, cfg, out_dir)
    
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()