import os
import sys
import json
import glob
import random
import argparse

import numpy as np
import torch
import torch.multiprocessing as mp
from easydict import EasyDict as edict

from trellis import models, datasets, trainers
from trellis.utils.dist_utils import setup_dist


def find_ckpt(cfg):
    """Resolve the checkpoint step to resume from, storing it in cfg.load_ckpt."""
    cfg['load_ckpt'] = None
    if cfg.load_dir != '':
        if cfg.ckpt == 'latest':
            # Checkpoints are saved as 'misc_step<N>.pt'; resume from the largest step.
            files = glob.glob(os.path.join(cfg.load_dir, 'ckpts', 'misc_*.pt'))
            if len(files) != 0:
                cfg.load_ckpt = max([
                    int(os.path.basename(f).split('step')[-1].split('.')[0])
                    for f in files
                ])
        elif cfg.ckpt == 'none':
            cfg.load_ckpt = None
        else:
            cfg.load_ckpt = int(cfg.ckpt)
    return cfg


def setup_rng(rank):
    """Seed all RNGs with the process rank so each worker draws a distinct stream."""
    torch.manual_seed(rank)
    torch.cuda.manual_seed_all(rank)
    np.random.seed(rank)
    random.seed(rank)


def get_model_summary(model):
    """Build a table of parameter names, shapes, dtypes, and grad flags, plus totals."""
    model_summary = 'Parameters:\n'
    model_summary += '=' * 128 + '\n'
    model_summary += f'{"Name":<{72}}{"Shape":<{32}}{"Type":<{16}}{"Grad"}\n'
    num_params = 0
    num_trainable_params = 0
    for name, param in model.named_parameters():
        model_summary += f'{name:<{72}}{str(param.shape):<{32}}{str(param.dtype):<{16}}{param.requires_grad}\n'
        num_params += param.numel()
        if param.requires_grad:
            num_trainable_params += param.numel()
    model_summary += '\n'
    model_summary += f'Number of parameters: {num_params}\n'
    model_summary += f'Number of trainable parameters: {num_trainable_params}\n'
    return model_summary


def main(local_rank, cfg):
    # Set up distributed training
    rank = cfg.node_rank * cfg.num_gpus + local_rank
    world_size = cfg.num_nodes * cfg.num_gpus
    if world_size > 1:
        setup_dist(rank, local_rank, world_size, cfg.master_addr, cfg.master_port)

    # Seed rngs
    setup_rng(rank)

    # Load data
    dataset = getattr(datasets, cfg.dataset.name)(cfg.data_dir, **cfg.dataset.args)

    # Build models
    model_dict = {
        name: getattr(models, model.name)(**model.args).cuda()
        for name, model in cfg.models.items()
    }

    # Model summary
    if rank == 0:
        for name, backbone in model_dict.items():
            model_summary = get_model_summary(backbone)
            print(f'\n\nBackbone: {name}\n' + model_summary)
            with open(os.path.join(cfg.output_dir, f'{name}_model_summary.txt'), 'w') as fp:
                print(model_summary, file=fp)

    # Build trainer
    trainer = getattr(trainers, cfg.trainer.name)(
        model_dict, dataset,
        **cfg.trainer.args,
        output_dir=cfg.output_dir,
        load_dir=cfg.load_dir,
        step=cfg.load_ckpt,
    )

    # Train
    if not cfg.tryrun:
        if cfg.profile:
            trainer.profile()
        else:
            trainer.run()


if __name__ == '__main__':
    # Arguments and config
    parser = argparse.ArgumentParser()
    ## config
    parser.add_argument('--config', type=str, required=True, help='Experiment config file')
    ## io and resume
    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')
    parser.add_argument('--load_dir', type=str, default='', help='Load directory, defaults to output_dir')
    parser.add_argument('--ckpt', type=str, default='latest', help='Checkpoint step to resume training from, defaults to latest')
    parser.add_argument('--data_dir', type=str, default='./data/', help='Data directory')
    parser.add_argument('--auto_retry', type=int, default=3, help='Number of retries on error')
    ## debug
    parser.add_argument('--tryrun', action='store_true', help='Try run without training')
    parser.add_argument('--profile', action='store_true', help='Profile training')
    ## multi-node and multi-gpu
    parser.add_argument('--num_nodes', type=int, default=1, help='Number of nodes')
    parser.add_argument('--node_rank', type=int, default=0, help='Node rank')
    parser.add_argument('--num_gpus', type=int, default=-1, help='Number of GPUs per node, defaults to all')
    parser.add_argument('--master_addr', type=str, default='localhost', help='Master address for distributed training')
    parser.add_argument('--master_port', type=str, default='12345', help='Port for distributed training')
    opt = parser.parse_args()
    opt.load_dir = opt.load_dir if opt.load_dir != '' else opt.output_dir
    opt.num_gpus = torch.cuda.device_count() if opt.num_gpus == -1 else opt.num_gpus

    ## Load config
    with open(opt.config, 'r') as fp:
        config = json.load(fp)

    ## Combine arguments and config; config values take precedence
    cfg = edict()
    cfg.update(opt.__dict__)
    cfg.update(config)
    print('\n\nConfig:')
    print('=' * 80)
    print(json.dumps(cfg.__dict__, indent=4))

    # Prepare output directory
    if cfg.node_rank == 0:
        os.makedirs(cfg.output_dir, exist_ok=True)
        ## Save command and config
        with open(os.path.join(cfg.output_dir, 'command.txt'), 'w') as fp:
            print(' '.join(['python'] + sys.argv), file=fp)
        with open(os.path.join(cfg.output_dir, 'config.json'), 'w') as fp:
            json.dump(config, fp, indent=4)

    # Run, optionally retrying on error and resuming from the latest checkpoint
    if cfg.auto_retry == 0:
        cfg = find_ckpt(cfg)
        if cfg.num_gpus > 1:
            mp.spawn(main, args=(cfg,), nprocs=cfg.num_gpus, join=True)
        else:
            main(0, cfg)
    else:
        for rty in range(cfg.auto_retry):
            try:
                cfg = find_ckpt(cfg)
                if cfg.num_gpus > 1:
                    mp.spawn(main, args=(cfg,), nprocs=cfg.num_gpus, join=True)
                else:
                    main(0, cfg)
                break
            except Exception as e:
                print(f'Error: {e}')
                print(f'Retrying ({rty + 1}/{cfg.auto_retry})...')
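
# A minimal launch sketch, assuming this file is saved as train.py and that
# configs/example.json and outputs/exp1 are hypothetical paths; the flags are
# the ones defined by the argparse parser above.
#
# Single-node run on all visible GPUs, auto-resuming from the latest checkpoint:
#   python train.py --config configs/example.json --output_dir outputs/exp1
#
# Resume from a specific checkpoint step instead of the latest:
#   python train.py --config configs/example.json --output_dir outputs/exp1 --ckpt 10000
#
# Sanity-check the config and model construction without training:
#   python train.py --config configs/example.json --output_dir outputs/exp1 --tryrun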