English
Shanci's picture
Upload folder using huggingface_hub
28a0087 verified
import os
import yaml
import json
import argparse
import logging
from datetime import datetime
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from train.train import train_model
from utils.logging_utils import init_logging
from dataset.preprocess_multi_processing import run_parallel
from train.train import evaluate_on_val
from train.test import run_test_inference
from dataset.lidar_dataset import LidarFusionDataset
from model.model import SimpleMLP
def _load_model(cfg, weights_path):
    """Build the SimpleMLP described by *cfg* and load weights from *weights_path*.

    The model is moved to ``cfg['training']['device']``. The checkpoint is
    loaded with ``map_location`` set to that same device, so weights saved on
    one device (e.g. a GPU) can be loaded on another (e.g. a CPU-only machine).

    NOTE(review): ``torch.load`` unpickles the file; only load trusted
    checkpoints.
    """
    device = cfg['training']['device']
    n_classes = cfg['dataset']['n_classes']
    model = SimpleMLP(
        # Input width is 2 * n_classes (presumably two per-class score vectors
        # being fused -- TODO confirm against LidarFusionDataset).
        input_dim=n_classes * 2,
        hidden_dims=cfg['model']['hidden_dims'],
        n_classes=n_classes,
    ).to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    return model


def _make_eval_loader(cfg, split):
    """Return an unshuffled, batch-size-1, single-process DataLoader over *split*."""
    dataset = LidarFusionDataset(cfg, split=split, shuffle_zones=False)
    return DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)


def main():
    """Command-line entry point: preprocess data, train, validate, or test.

    Reads the YAML config given by ``--config`` and creates an output directory
    named ``<cfg['logging']['save_dir']>_MM_DD_HH_MM_SS``. It initialises logging
    to ``run.log`` inside that directory, then dispatches on ``--mode``:

    * ``preprocess`` -- run ``run_parallel`` over every zone listed in the
      train/val/test lists of the split file.
    * ``train`` -- call ``train_model(cfg, out_dir)``.
    * ``val`` -- load weights from ``--weights_path``, evaluate on the
      validation split and log the returned metrics.
    * ``test`` -- load weights from ``--weights_path`` and call
      ``run_test_inference`` on the test split with the output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=Path, default="config.yaml", help="Path to YAML configuration.")
    parser.add_argument("--mode", choices=["preprocess", "train", "val", "test"], default="train")
    parser.add_argument("--weights_path", default="best_model.pth", help="Path to model weights.")
    args = parser.parse_args()

    cfg = yaml.safe_load(args.config.read_text(encoding="utf-8"))

    # Timestamp suffix (second resolution) so separate runs don't overwrite
    # each other's logs/outputs.
    out_dir = Path(cfg["logging"]["save_dir"] + datetime.now().strftime("_%m_%d_%H_%M_%S"))
    out_dir.mkdir(parents=True, exist_ok=True)
    init_logging(out_dir / "run.log")

    mode = args.mode
    if mode == "preprocess":
        split_file = Path(cfg["dataset"]["root"]) / cfg["dataset"]["split_file"]
        with open(split_file, encoding="utf-8") as f:
            split_dict = json.load(f)
        # Preprocess every zone regardless of which split it belongs to.
        all_zones = split_dict["train"] + split_dict["val"] + split_dict["test"]
        run_parallel(cfg, all_zones, num_workers=cfg["dataset"]["pre-processing_num_workers"])
    elif mode == "train":
        train_model(cfg, out_dir)
    elif mode == "val":
        model = _load_model(cfg, args.weights_path)
        val_loader = _make_eval_loader(cfg, "val")
        miou, mf1, ious = evaluate_on_val(model, val_loader, cfg)
        # Record the metrics so a val run leaves a trace. Presumably these reach
        # run.log via the handlers installed by init_logging -- confirm.
        logging.info("Validation mIoU: %s | mF1: %s", miou, mf1)
        logging.info("Validation per-class IoU: %s", ious)
    elif mode == "test":
        model = _load_model(cfg, args.weights_path)
        test_loader = _make_eval_loader(cfg, "test")
        run_test_inference(model, test_loader, cfg, out_dir)
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()