Spaces:
Build error
Build error
import argparse
import os
import shutil
import sys
import time
import traceback

import torch
import torch.nn as nn

from lib import utility
from lib.utils import AverageMeter, convert_secs2time
# Select MKL's GNU threading layer for this process. Presumably this avoids the
# Intel-OpenMP vs. libgomp conflict some PyTorch/MKL installs report — confirm.
os.environ.update(MKL_THREADING_LAYER="GNU")
# To restrict which GPUs are visible, set CUDA_VISIBLE_DEVICES (e.g. "0,1,2,3")
# before CUDA is initialised.
def train(args):
    """Launch training on the devices listed in ``args.device_ids``.

    With several devices, one worker per device is started through
    ``torch.multiprocessing.spawn``; each worker receives its process index
    (0..n-1) as ``world_rank``. With exactly one device, ``train_worker`` runs
    in the current process and receives that device index as ``world_rank``.

    Args:
        args: Parsed command-line arguments; ``args.device_ids`` must be a
            sequence of device indices.

    Raises:
        ValueError: If ``args.device_ids`` is empty.
    """
    device_ids = args.device_ids
    nprocs = len(device_ids)
    if nprocs == 0:
        # Explicit exception rather than `assert False`, which `python -O` strips.
        raise ValueError("args.device_ids must list at least one device")
    if nprocs == 1:
        train_worker(device_ids[0], nprocs, 1, args)
    else:
        # NOTE(review): spawned workers get their process index, not the entries
        # of device_ids; presumably device selection is done elsewhere (e.g.
        # CUDA_VISIBLE_DEVICES) — confirm.
        torch.multiprocessing.spawn(
            train_worker, args=(nprocs, 1, args), nprocs=nprocs, join=True)
def _log(config, level, message):
    """Emit ``message`` via ``config.logger.<level>`` when a logger is configured."""
    if config.logger is not None:
        getattr(config.logger, level)(message)


def _build_net(config, world_size):
    """Create a float32 network on ``config.device`` (SyncBatchNorm when multi-process)."""
    net = utility.get_net(config)
    if world_size > 1:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    return net.float().to(config.device)


def _load_pretrained(config, weight, net, net_ema, optimizer, scheduler):
    """Restore network weights and, when available, the training state from a checkpoint.

    Args:
        config: Training config; ``work_dir`` resolves ``weight`` and ``logger``
            reports progress.
        weight: Checkpoint path. When it does not exist as given, it is looked
            up relative to ``config.work_dir``.
        net: Module receiving the checkpoint's ``"net"`` entry (non-strict).
        net_ema: Module receiving ``"net_ema"`` (non-strict), or None to skip.
        optimizer: Receives the ``"optimizer"`` entry.
        scheduler: Receives the ``"scheduler"`` entry.

    Returns:
        The checkpoint's ``"epoch"`` when weights and training state were all
        restored; otherwise 0 (train from the start of the schedule).
    """
    path = weight if os.path.exists(weight) else os.path.join(config.work_dir, weight)
    try:
        # Map to CPU so every DDP rank does not allocate the tensors on the GPU
        # they were saved from; load_state_dict copies onto each module's device.
        checkpoint = torch.load(path, map_location="cpu")
        net.load_state_dict(checkpoint["net"], strict=False)
        if net_ema is not None:
            net_ema.load_state_dict(checkpoint["net_ema"], strict=False)
    except Exception:
        _log(config, "warning", "Failed to load pretrain model %s." % path)
        return 0
    _log(config, "warning", "Successed to load pretrain model %s." % path)
    try:
        start_epoch = checkpoint["epoch"]
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
    except Exception:
        # The weights alone are still usable (e.g. a plain pretrained model).
        _log(config, "warning",
             "Failed to restore training state from %s; starting from epoch 0." % path)
        return 0
    return start_epoch


def _validate(config, val_loader, epoch, net, net_ema, optimizer, scheduler, best_metric):
    """Evaluate ``net`` and ``net_ema`` (if any) and save ``best_model.pkl`` on improvement.

    The key metric is ``metrics[config.key_metric_index][0]``; a value lower
    than ``best_metric`` counts as an improvement.

    Returns:
        The updated best (lowest) key metric.
    """
    for net_name, epoch_net in (("net", net), ("net_ema", net_ema)):
        if epoch_net is None:
            continue
        _, metrics = utility.forward(config, val_loader, epoch_net)
        for k, metric in enumerate(metrics):
            if config.logger is not None and len(metric) != 0:
                config.logger.info(
                    "Val_{}/Metric{:3d} in this epoch: [NME {:.6f}, FR {:.6f}, AUC {:.6f}]".format(
                        net_name, k, metric[0], metric[1], metric[2]))
        cur_metric = metrics[config.key_metric_index][0]
        if best_metric is None or best_metric > cur_metric:
            best_metric = cur_metric
            # The improving network is passed as save_model's primary net;
            # net_ema is passed unchanged.
            utility.save_model(
                config, epoch, epoch_net, net_ema, optimizer, scheduler,
                os.path.join(config.model_dir, "best_model.pkl"))
    if best_metric is not None:
        _log(config, "info",
             "Val/Best_Metric%03d in this epoch: %.6f" % (config.key_metric_index, best_metric))
    return best_metric


def train_worker(world_rank, world_size, nodes_size, args):
    """Run the complete training loop inside one worker.

    Builds config, model (plus an optional EMA copy on the main process),
    optimizer, scheduler and data loaders; optionally restores a checkpoint;
    then trains for epochs ``0..config.max_epoch``. Validation and
    checkpointing happen on the main process only.

    Args:
        world_rank: Process index when ``world_size > 1``; in the
            single-process path ``train`` passes a device index here instead.
        world_size: Number of training processes.
        nodes_size: Number of nodes. When 1, ``world_rank`` is used directly as
            the CUDA device id and processes rendezvous on localhost.
        args: Parsed command-line arguments (``args.pretrained_weight`` is read).
    """
    # initialize config.
    config = utility.get_config(args)
    config.device_id = world_rank if nodes_size == 1 else world_rank % torch.cuda.device_count()
    utility.set_environment(config)

    # In the single-process path `world_rank` carries a (possibly non-zero)
    # device index; that process is still the only, hence main, worker.
    rank = world_rank if world_size > 1 else 0
    is_main = rank == 0

    # initialize instances, such as writer, logger and wandb.
    # NOTE(review): presumably config.logger stays None on non-main workers.
    if is_main:
        config.init_instance()
    if config.logger is not None:
        config.logger.info("\n" + "\n".join(["%s: %s" % item for item in config.__dict__.items()]))
        config.logger.info("Loaded configure file %s: %s" % (config.type, config.id))

    # worker communication
    if world_size > 1:
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="tcp://localhost:23456" if nodes_size == 1 else "env://",
            rank=world_rank, world_size=world_size)
        torch.cuda.set_device(config.device)

    # model; the EMA copy exists only on the main process.
    net = _build_net(config, world_size)
    net.train(True)
    net_ema = None
    if config.ema and is_main:
        net_ema = _build_net(config, world_size)
        net_ema.eval()
        # NOTE(review): presumably initialises net_ema from net (decay 0) — confirm.
        utility.accumulate_net(net_ema, net, 0)

    # multi-GPU training
    if world_size > 1:
        net_module = nn.parallel.DistributedDataParallel(
            net, device_ids=[config.device_id], output_device=config.device_id,
            find_unused_parameters=True)
    else:
        net_module = net

    criterions = utility.get_criterions(config)
    optimizer = utility.get_optimizer(config, net_module)
    scheduler = utility.get_scheduler(config, optimizer)

    # load pretrain model / resume training state
    start_epoch = 0
    if args.pretrained_weight is not None:
        start_epoch = _load_pretrained(
            config, args.pretrained_weight, net, net_ema, optimizer, scheduler)
    _log(config, "info", "Loaded network")

    # data - train, val (train loader gets the logical rank)
    train_loader = utility.get_dataloader(config, "train", rank, world_size)
    val_loader = utility.get_dataloader(config, "val") if is_main else None
    _log(config, "info", "Loaded data")

    _log(config, "info", "Optimizer type %s. Start training..." % (config.optimizer))
    if is_main:
        # exist_ok avoids the race of a separate exists() check.
        os.makedirs(config.model_dir, exist_ok=True)

    # training
    best_metric = None
    epoch_time, eval_time = AverageMeter(), AverageMeter()
    for epoch in range(config.max_epoch + 1):
        if epoch < start_epoch:
            # Covered by the restored checkpoint: its scheduler state already
            # reflects these epochs, so they must not be stepped again.
            continue
        try:
            epoch_start_time = time.time()
            # The first visited epoch (0, or the resumed one) is not trained.
            if epoch != start_epoch:
                utility.forward_backward(config, train_loader, net_module, net, net_ema,
                                         criterions, optimizer, epoch)
            if world_size > 1:
                torch.distributed.barrier()

            # validating
            if epoch % config.val_epoch == 0 and epoch != 0 and is_main:
                eval_start_time = time.time()
                best_metric = _validate(config, val_loader, epoch, net, net_ema,
                                        optimizer, scheduler, best_metric)
                eval_time.update(time.time() - eval_start_time)

            # saving model
            if epoch == config.max_epoch and is_main:
                utility.save_model(
                    config, epoch, net, net_ema, optimizer, scheduler,
                    os.path.join(config.model_dir, "last_model.pkl"))
            if world_size > 1:
                torch.distributed.barrier()

            # adjusting learning rate
            if epoch > 0:
                scheduler.step()
            epoch_time.update(time.time() - epoch_start_time)
            last_time = convert_secs2time(epoch_time.avg * (config.max_epoch - epoch), True)
            _log(config, "info",
                 "Train/Epoch: %d/%d, Learning rate decays to %s, " % (
                     epoch, config.max_epoch, str(scheduler.get_last_lr()))
                 + last_time + 'eval_time: {:4.2f}, '.format(eval_time.avg) + '\n\n')
        except Exception:
            # Best-effort: report and continue with the next epoch. Catching
            # Exception (not a bare except) lets Ctrl-C / SystemExit stop training.
            traceback.print_exc()
            _log(config, "error", "Exception happened in training steps")

    _log(config, "info", "Training finished")

    # Rename the work dir to record the best metric.
    try:
        if config.logger is not None and best_metric is not None:
            new_folder_name = config.folder + '-fin-{:.4f}'.format(best_metric)
            new_work_dir = os.path.join(config.ckpt_dir, config.data_definition, new_folder_name)
            # shutil.move instead of os.system("mv ..."): no shell parsing of
            # the paths, and failures raise (reported below) instead of vanishing.
            shutil.move(config.work_dir, new_work_dir)
    except Exception:
        traceback.print_exc()

    if world_size > 1:
        torch.distributed.destroy_process_group()