Spaces:
Running
Running
| """ | |
| Main file to launch training and testing experiments. | |
| """ | |
| import yaml | |
| import os | |
| import argparse | |
| import numpy as np | |
| import torch | |
| from .config.project_config import Config as cfg | |
| from .train import train_net | |
| from .export import export_predictions, export_homograpy_adaptation | |
# PyTorch runtime setup, applied once when this module is imported:
# let cuDNN benchmark and pick convolution algorithms, and release any
# unoccupied memory held by the CUDA caching allocator.
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()
def load_config(config_path):
    """Read a YAML configuration file and return its parsed contents.

    Args:
        config_path: Path to the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the file
        (e.g. a dict for a YAML mapping).

    Raises:
        ValueError: If ``config_path`` does not exist.
    """
    if os.path.exists(config_path):
        with open(config_path, "r") as stream:
            return yaml.safe_load(stream)
    raise ValueError("[Error] The provided config path is not valid.")
def update_config(path, model_cfg=None, dataset_cfg=None):
    """Merge the configs saved in ``path`` into the given configs.

    Reads ``model_cfg.yaml`` and ``dataset_cfg.yaml`` from ``path`` and
    merges each on top of the corresponding argument. Values from the saved
    files win on key collisions. If a merged config differs from the saved
    one (i.e. the caller supplied keys the file lacked), the saved file is
    rewritten with the merged content.

    Args:
        path: Directory containing ``model_cfg.yaml`` and ``dataset_cfg.yaml``.
        model_cfg: Optional base model config dict. It is not modified.
        dataset_cfg: Optional base dataset config dict. It is not modified.

    Returns:
        Tuple ``(model_cfg, dataset_cfg)`` holding the merged dicts.
    """
    # Work on shallow copies so the caller's dicts are never mutated.
    model_cfg = {} if model_cfg is None else dict(model_cfg)
    dataset_cfg = {} if dataset_cfg is None else dict(dataset_cfg)
    model_cfg_path = os.path.join(path, "model_cfg.yaml")
    dataset_cfg_path = os.path.join(path, "dataset_cfg.yaml")

    # Load the saved configs. An empty YAML file parses to None; treat it
    # as an empty mapping so that dict.update() does not raise TypeError.
    with open(model_cfg_path, "r") as f:
        model_cfg_saved = yaml.safe_load(f) or {}
    with open(dataset_cfg_path, "r") as f:
        dataset_cfg_saved = yaml.safe_load(f) or {}

    # Saved values take precedence over the provided ones.
    model_cfg.update(model_cfg_saved)
    dataset_cfg.update(dataset_cfg_saved)

    # Rewrite only the files whose content changed. Use safe_dump (as
    # record_config does) so the files stay loadable with yaml.safe_load;
    # plain yaml.dump can emit python-specific tags that safe_load rejects.
    if model_cfg != model_cfg_saved:
        with open(model_cfg_path, "w") as f:
            yaml.safe_dump(model_cfg, f)
    if dataset_cfg != dataset_cfg_saved:
        with open(dataset_cfg_path, "w") as f:
            yaml.safe_dump(dataset_cfg, f)
    return model_cfg, dataset_cfg
def record_config(model_cfg, dataset_cfg, output_path):
    """Save both the model and the dataset config into ``output_path``.

    Writes ``model_cfg.yaml`` first and then ``dataset_cfg.yaml``,
    overwriting existing files, using ``yaml.safe_dump``.

    Args:
        model_cfg: Model configuration to serialize.
        dataset_cfg: Dataset configuration to serialize.
        output_path: Existing directory that receives the two YAML files.
    """
    targets = (
        ("model_cfg.yaml", model_cfg),
        ("dataset_cfg.yaml", dataset_cfg),
    )
    for file_name, cfg_dict in targets:
        with open(os.path.join(output_path, file_name), "w") as out_file:
            yaml.safe_dump(cfg_dict, out_file)
def train(args, dataset_cfg, model_cfg, output_path):
    """Persist the experiment configs, then hand off to ``train_net``.

    The configs are written to ``output_path`` in two cases: for a fresh
    run, and for a resumed run whose ``output_path`` resolves to a different
    directory than ``args.resume_path``. When resuming in place (same real
    directory), the existing config files are left untouched.
    """
    # `and` short-circuits, so resume_path is only resolved when resuming.
    resuming_in_place = args.resume and (
        os.path.realpath(output_path) == os.path.realpath(args.resume_path)
    )
    if not resuming_in_place:
        record_config(model_cfg, dataset_cfg, output_path)
    train_net(args, dataset_cfg, model_cfg, output_path)
def export(
    args,
    dataset_cfg,
    model_cfg,
    output_path,
    export_dataset_mode=None,
    device=torch.device("cuda"),
):
    """Export model predictions, optionally with homography adaptation.

    Calls ``export_homograpy_adaptation`` when ``dataset_cfg`` has a
    non-None ``"homography_adaptation"`` entry. Otherwise it calls
    ``export_predictions``. ``device`` is forwarded only on the
    homography-adaptation path.
    """
    use_homography = dataset_cfg.get("homography_adaptation") is not None
    if not use_homography:
        print("[Info] Export predictions normally.")
        export_predictions(
            args, dataset_cfg, model_cfg, output_path, export_dataset_mode
        )
        return
    print("[Info] Export predictions with homography adaptation.")
    export_homograpy_adaptation(
        args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device
    )
def main(
    args, dataset_cfg, model_cfg, export_dataset_mode=None, device=torch.device("cuda")
):
    """Route to training or export according to ``args.mode``.

    Training writes under ``cfg.EXP_PATH/<exp_name>`` and creates the
    directory if missing. Export writes under
    ``cfg.export_dataroot/<exp_name>``.

    Raises:
        ValueError: If ``args.mode`` is neither "train" nor "export".
    """
    mode = args.mode
    exp_output_path = os.path.join(cfg.EXP_PATH, args.exp_name)
    if mode not in ("train", "export"):
        raise ValueError("[Error]: Unknown mode: " + mode)

    if mode == "export":
        # Export uses its own root instead of the experiment directory.
        export_output_path = os.path.join(cfg.export_dataroot, args.exp_name)
        print("[Info] Export mode")
        print("\t Output path: %s" % export_output_path)
        export(
            args,
            dataset_cfg,
            model_cfg,
            export_output_path,
            export_dataset_mode,
            device=device,
        )
        return

    if not os.path.exists(exp_output_path):
        os.makedirs(exp_output_path)
    print("[Info] Training mode")
    print("\t Output path: %s" % exp_output_path)
    train(args, dataset_cfg, model_cfg, exp_output_path)
def set_random_seed(seed):
    """Seed every RNG the pipeline may draw from, for reproducible runs.

    Seeds Python's built-in ``random`` module, NumPy's global RNG, and
    PyTorch. ``torch.manual_seed`` seeds the generators of all devices.

    Args:
        seed: Integer seed value.
    """
    import random

    # Python's stdlib RNG was previously left unseeded, so any code path
    # using `random` (e.g. data augmentation) was not reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
if __name__ == "__main__":
    # Parse input arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode", type=str, default="train", help="'train' or 'export'."
    )
    parser.add_argument(
        "--dataset_config", type=str, default=None, help="Path to the dataset config."
    )
    parser.add_argument(
        "--model_config", type=str, default=None, help="Path to the model config."
    )
    parser.add_argument("--exp_name", type=str, default="exp", help="Experiment name.")
    parser.add_argument(
        "--resume",
        action="store_true",
        default=False,
        help="Load a previously trained model.",
    )
    parser.add_argument(
        "--pretrained",
        action="store_true",
        default=False,
        help="Start training from a pre-trained model.",
    )
    parser.add_argument(
        "--resume_path", default=None, help="Path from which to resume training."
    )
    parser.add_argument(
        "--pretrained_path", default=None, help="Path to the pre-trained model."
    )
    parser.add_argument(
        "--checkpoint_name", default=None, help="Name of the checkpoint to use."
    )
    parser.add_argument(
        "--export_dataset_mode", default=None, help="'train' or 'test'."
    )
    parser.add_argument(
        "--export_batch_size", default=4, type=int, help="Export batch size."
    )
    args = parser.parse_args()

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A fresh training run needs both configs on the command line; resumed
    # training and export can fall back to the copies in the checkpoint folder.
    if (
        ((args.dataset_config is None) or (args.model_config is None))
        and (not args.resume)
        and (args.mode == "train")
    ):
        raise ValueError(
            "[Error] The dataset config and model config should be given in non-resume mode"
        )
    # Resumed training and export both read the checkpoint from --resume_path,
    # so it is required in either case (otherwise os.path.join(None, ...)
    # below fails with an opaque TypeError).
    if (args.resume or args.mode == "export") and (args.resume_path is None):
        raise ValueError("[Error] Missing resume path.")

    # [Training] Load the config files given on the command line.
    if args.mode == "train" and (not args.resume):
        # Check the pretrained checkpoint exists before starting.
        if args.pretrained:
            if args.pretrained_path is None or args.checkpoint_name is None:
                raise ValueError(
                    "[Error] --pretrained requires --pretrained_path and --checkpoint_name."
                )
            checkpoint_path = os.path.join(args.pretrained_path, args.checkpoint_name)
            if not os.path.exists(checkpoint_path):
                raise ValueError("[Error] Missing checkpoint: " + checkpoint_path)
        dataset_cfg = load_config(args.dataset_config)
        model_cfg = load_config(args.model_config)
    # [Resume training, Export] Load the configs, falling back to the
    # checkpoint folder for any config not given on the command line.
    elif (args.mode == "train" and args.resume) or (args.mode == "export"):
        if args.checkpoint_name is None:
            raise ValueError("[Error] Missing checkpoint name (--checkpoint_name).")
        # Check checkpoint path exists
        checkpoint_folder = args.resume_path
        checkpoint_path = os.path.join(checkpoint_folder, args.checkpoint_name)
        if not os.path.exists(checkpoint_path):
            raise ValueError("[Error] Missing checkpoint: " + checkpoint_path)
        # Load model_cfg from checkpoint folder if not provided
        if args.model_config is None:
            print("[Info] No model config provided. Loading from checkpoint folder.")
            model_cfg_path = os.path.join(checkpoint_folder, "model_cfg.yaml")
            if not os.path.exists(model_cfg_path):
                raise ValueError("[Error] Missing model config in checkpoint path.")
            model_cfg = load_config(model_cfg_path)
        else:
            model_cfg = load_config(args.model_config)
        # Load dataset_cfg from checkpoint folder if not provided
        if args.dataset_config is None:
            print("[Info] No dataset config provided. Loading from checkpoint folder.")
            dataset_cfg_path = os.path.join(checkpoint_folder, "dataset_cfg.yaml")
            if not os.path.exists(dataset_cfg_path):
                raise ValueError("[Error] Missing dataset config in checkpoint path.")
            dataset_cfg = load_config(dataset_cfg_path)
        else:
            dataset_cfg = load_config(args.dataset_config)
        # Check the --export_dataset_mode flag
        if (args.mode == "export") and (args.export_dataset_mode is None):
            raise ValueError("[Error] Empty --export_dataset_mode flag.")
    else:
        raise ValueError("[Error] Unknown mode: " + args.mode)

    # Set the random seed
    seed = dataset_cfg.get("random_seed", 0)
    set_random_seed(seed)
    main(
        args,
        dataset_cfg,
        model_cfg,
        export_dataset_mode=args.export_dataset_mode,
        device=device,
    )