|
import os |
|
import yaml |
|
import json |
|
import argparse |
|
import logging |
|
from datetime import datetime |
|
from pathlib import Path |
|
import torch |
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from train.train import train_model |
|
from utils.logging_utils import init_logging |
|
from dataset.preprocess_multi_processing import run_parallel |
|
from train.train import evaluate_on_val |
|
from train.test import run_test_inference |
|
from dataset.lidar_dataset import LidarFusionDataset |
|
from model.model import SimpleMLP |
|
|
|
|
|
def _run_preprocess(cfg):
    """Preprocess every zone listed in the dataset split file, in parallel.

    Reads ``cfg['dataset']['root'] / cfg['dataset']['split_file']`` as JSON and
    concatenates its ``train``, ``val`` and ``test`` lists before handing them
    to ``run_parallel``.
    """
    split_file = Path(cfg["dataset"]["root"]) / cfg["dataset"]["split_file"]
    with open(split_file, encoding="utf-8") as f:
        split_dict = json.load(f)
    all_zones = split_dict["train"] + split_dict["val"] + split_dict["test"]
    run_parallel(cfg, all_zones, num_workers=cfg["dataset"]["pre-processing_num_workers"])


def _load_model(cfg, weights_path):
    """Build a ``SimpleMLP`` from *cfg*, load *weights_path* into it and return it in eval mode.

    Args:
        cfg: Parsed YAML configuration.
        weights_path: Path to a saved ``state_dict``.
    """
    device = cfg["training"]["device"]
    n_classes = cfg["dataset"]["n_classes"]
    model = SimpleMLP(
        # NOTE(review): input width is 2 * n_classes; presumably two per-class
        # score vectors are concatenated per point — confirm against the dataset.
        input_dim=n_classes * 2,
        hidden_dims=cfg["model"]["hidden_dims"],
        n_classes=n_classes,
    ).to(device)
    # map_location lets a checkpoint saved on one device (e.g. CUDA) be loaded
    # on the configured device (e.g. CPU) instead of failing at deserialization.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    # Standalone evaluation/inference: disable train-time behaviour such as
    # dropout / batch-norm statistic updates.
    model.eval()
    return model


def _make_loader(cfg, split):
    """Return an unshuffled, batch-size-1 ``DataLoader`` over the *split* subset."""
    dataset = LidarFusionDataset(cfg, split=split, shuffle_zones=False)
    return DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)


def main():
    """Command-line entry point: preprocess data, train, validate or run test inference.

    A timestamped output directory ``<logging.save_dir>_MM_DD_HH_MM_SS`` is
    created for every invocation and ``run.log`` is written inside it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=Path, default="config.yaml", help="Path to YAML configuration.")
    parser.add_argument("--mode", choices=["preprocess", "train", "val", "test"], default="train")
    parser.add_argument("--weights_path", default="best_model.pth", help="Path to model weights.")
    args = parser.parse_args()

    cfg = yaml.safe_load(args.config.read_text())

    now = datetime.now()
    out_dir = Path(cfg["logging"]["save_dir"] + now.strftime("_%m_%d_%H_%M_%S"))
    out_dir.mkdir(parents=True, exist_ok=True)
    init_logging(out_dir / "run.log")
    # Named logger: propagates to whatever init_logging configured without
    # triggering an implicit logging.basicConfig() on the root logger.
    logger = logging.getLogger(__name__)

    mode = args.mode
    if mode == "preprocess":
        _run_preprocess(cfg)

    elif mode == "train":
        train_model(cfg, out_dir)

    elif mode == "val":
        model = _load_model(cfg, args.weights_path)
        val_loader = _make_loader(cfg, "val")
        miou, mf1, ious = evaluate_on_val(model, val_loader, cfg)
        # Previously these results were computed and silently discarded.
        logger.info("Validation results: mIoU=%s, mF1=%s, per-class IoU=%s", miou, mf1, ious)

    elif mode == "test":
        model = _load_model(cfg, args.weights_path)
        test_loader = _make_loader(cfg, "test")
        run_test_inference(model, test_loader, cfg, out_dir)
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
|
|