# newTryOn/models/STAR/trainer.py
import os
import sys
import time
import argparse
import traceback

import torch
import torch.nn as nn

from lib import utility
from lib.utils import AverageMeter, convert_secs2time

os.environ["MKL_THREADING_LAYER"] = "GNU"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
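
# train(): entry point that dispatches training. With several GPU ids it spawns
# one worker process per device via torch.multiprocessing.spawn; with a single
# id it calls train_worker directly. `args` is expected to provide
# `device_ids` (a list of GPU ids) and `pretrained_weight`.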
def train(args):
    device_ids = args.device_ids
    nprocs = len(device_ids)
    if nprocs > 1:
        torch.multiprocessing.spawn(
            train_worker, args=(nprocs, 1, args), nprocs=nprocs,
            join=True)
    elif nprocs == 1:
        train_worker(device_ids[0], nprocs, 1, args)
    else:
        assert False
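
# train_worker(): per-process training routine. `world_rank` identifies this
# process among `world_size` workers; rank 0 additionally owns logging,
# validation, and checkpoint saving, while gradients are synchronized across
# ranks through DistributedDataParallel whenever world_size > 1.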
def train_worker(world_rank, world_size, nodes_size, args):
    # initialize config.
    config = utility.get_config(args)
    config.device_id = world_rank if nodes_size == 1 else world_rank % torch.cuda.device_count()

    # set environment
    utility.set_environment(config)

    # initialize instances, such as writer, logger and wandb.
    if world_rank == 0:
        config.init_instance()
        if config.logger is not None:
            config.logger.info("\n" + "\n".join(["%s: %s" % item for item in config.__dict__.items()]))
            config.logger.info("Loaded configuration file %s: %s" % (config.type, config.id))

    # worker communication
    if world_size > 1:
        torch.distributed.init_process_group(
            backend="nccl", init_method="tcp://localhost:23456" if nodes_size == 1 else "env://",
            rank=world_rank, world_size=world_size)
        torch.cuda.set_device(config.device)
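
    # Note: for single-node runs the NCCL process group rendezvous uses a fixed
    # local TCP address; for multi-node runs the "env://" method expects
    # MASTER_ADDR and MASTER_PORT to be provided by the launcher.
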
    # model
    net = utility.get_net(config)
    if world_size > 1:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.float().to(config.device)
    net.train(True)
    if config.ema and world_rank == 0:
        net_ema = utility.get_net(config)
        if world_size > 1:
            net_ema = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net_ema)
        net_ema = net_ema.float().to(config.device)
        net_ema.eval()
        utility.accumulate_net(net_ema, net, 0)
    else:
        net_ema = None

    # multi-GPU training
    if world_size > 1:
        net_module = nn.parallel.DistributedDataParallel(net, device_ids=[config.device_id],
                                                         output_device=config.device_id,
                                                         find_unused_parameters=True)
    else:
        net_module = net

    criterions = utility.get_criterions(config)
    optimizer = utility.get_optimizer(config, net_module)
    scheduler = utility.get_scheduler(config, optimizer)
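
    # Optionally resume from a checkpoint: the file is expected to hold "net",
    # "net_ema", "epoch", "optimizer" and "scheduler" entries; any failure to
    # load falls back to training from scratch (start_epoch = 0).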
    # load pretrained model
    if args.pretrained_weight is not None:
        if not os.path.exists(args.pretrained_weight):
            pretrained_weight = os.path.join(config.work_dir, args.pretrained_weight)
        else:
            pretrained_weight = args.pretrained_weight
        try:
            checkpoint = torch.load(pretrained_weight)
            net.load_state_dict(checkpoint["net"], strict=False)
            if net_ema is not None:
                net_ema.load_state_dict(checkpoint["net_ema"], strict=False)
            if config.logger is not None:
                config.logger.warn("Succeeded to load pretrained model %s." % pretrained_weight)
            start_epoch = checkpoint["epoch"]
            optimizer.load_state_dict(checkpoint["optimizer"])
            scheduler.load_state_dict(checkpoint["scheduler"])
        except Exception:
            start_epoch = 0
            if config.logger is not None:
                config.logger.warn("Failed to load pretrained model %s." % pretrained_weight)
    else:
        start_epoch = 0

    if config.logger is not None:
        config.logger.info("Loaded network")
    # data - train, val
    train_loader = utility.get_dataloader(config, "train", world_rank, world_size)
    if world_rank == 0:
        val_loader = utility.get_dataloader(config, "val")
    if config.logger is not None:
        config.logger.info("Loaded data")

    # forward & backward
    if config.logger is not None:
        config.logger.info("Optimizer type %s. Start training..." % (config.optimizer))
    if not os.path.exists(config.model_dir) and world_rank == 0:
        os.makedirs(config.model_dir)

    # training
    best_metric, best_net = None, None
    epoch_time, eval_time = AverageMeter(), AverageMeter()
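
    # Main loop: one pass of forward/backward over the training set per epoch,
    # validation on rank 0 every `val_epoch` epochs, a final checkpoint at
    # `max_epoch`, and a scheduler step at the end of each epoch.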
    for i_epoch, epoch in enumerate(range(config.max_epoch + 1)):
        try:
            epoch_start_time = time.time()
            if epoch >= start_epoch:
                # forward and backward
                if epoch != start_epoch:
                    utility.forward_backward(config, train_loader, net_module, net, net_ema, criterions, optimizer,
                                             epoch)
                if world_size > 1:
                    torch.distributed.barrier()
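
                # Models are ranked by the first value (NME) of the metric at
                # `key_metric_index`; lower is better, and an improvement
                # checkpoints "best_model.pkl". Both the raw net and its EMA
                # copy are evaluated when EMA is enabled.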
                # validating
                if epoch % config.val_epoch == 0 and epoch != 0 and world_rank == 0:
                    eval_start_time = time.time()
                    epoch_nets = {"net": net, "net_ema": net_ema}
                    for net_name, epoch_net in epoch_nets.items():
                        if epoch_net is None:
                            continue
                        result, metrics = utility.forward(config, val_loader, epoch_net)
                        for k, metric in enumerate(metrics):
                            if config.logger is not None and len(metric) != 0:
                                config.logger.info(
                                    "Val_{}/Metric{:3d} in this epoch: [NME {:.6f}, FR {:.6f}, AUC {:.6f}]".format(
                                        net_name, k, metric[0], metric[1], metric[2]))

                        # update best model.
                        cur_metric = metrics[config.key_metric_index][0]
                        if best_metric is None or best_metric > cur_metric:
                            best_metric = cur_metric
                            best_net = epoch_net
                            current_pytorch_model_path = os.path.join(config.model_dir, "best_model.pkl")
                            # current_onnx_model_path = os.path.join(config.model_dir, "train.onnx")
                            utility.save_model(
                                config,
                                epoch,
                                best_net,
                                net_ema,
                                optimizer,
                                scheduler,
                                current_pytorch_model_path)
                    if best_metric is not None:
                        config.logger.info(
                            "Val/Best_Metric%03d in this epoch: %.6f" % (config.key_metric_index, best_metric))
                    eval_time.update(time.time() - eval_start_time)
                # saving model
                if epoch == config.max_epoch and world_rank == 0:
                    current_pytorch_model_path = os.path.join(config.model_dir, "last_model.pkl")
                    # current_onnx_model_path = os.path.join(config.model_dir, "model_epoch_%s.onnx" % epoch)
                    utility.save_model(
                        config,
                        epoch,
                        net,
                        net_ema,
                        optimizer,
                        scheduler,
                        current_pytorch_model_path)

                if world_size > 1:
                    torch.distributed.barrier()

            # adjusting learning rate
            if epoch > 0:
                scheduler.step()
            epoch_time.update(time.time() - epoch_start_time)
            last_time = convert_secs2time(epoch_time.avg * (config.max_epoch - i_epoch), True)
            if config.logger is not None:
                config.logger.info(
                    "Train/Epoch: %d/%d, Learning rate decays to %s, " % (
                        epoch, config.max_epoch, str(scheduler.get_last_lr()))
                    + last_time + 'eval_time: {:4.2f}, '.format(eval_time.avg) + '\n\n')
        except Exception:
            traceback.print_exc()
            if config.logger is not None:
                config.logger.error("Exception happened in training steps")
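
    # After the last epoch the work directory is renamed to append the final
    # best metric, and the distributed process group is torn down.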
    if config.logger is not None:
        config.logger.info("Training finished")

    try:
        if config.logger is not None and best_metric is not None:
            new_folder_name = config.folder + '-fin-{:.4f}'.format(best_metric)
            new_work_dir = os.path.join(config.ckpt_dir, config.data_definition, new_folder_name)
            os.system('mv {} {}'.format(config.work_dir, new_work_dir))
    except Exception:
        traceback.print_exc()

    if world_size > 1:
        torch.distributed.destroy_process_group()
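
# ---------------------------------------------------------------------------
# Illustrative entry point (a sketch, not part of the original training code;
# the real CLI normally lives in a separate launcher script). The flags below
# are assumptions inferred from the attributes this module reads from `args`
# (`device_ids`, `pretrained_weight`); `--config_name` is a hypothetical flag
# standing in for whatever `utility.get_config` actually expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="STAR training (illustrative sketch)")
    parser.add_argument("--config_name", type=str, required=True)  # hypothetical flag
    parser.add_argument("--device_ids", type=str, default="0",
                        help="comma-separated GPU ids, e.g. '0,1,2,3'")
    parser.add_argument("--pretrained_weight", type=str, default=None)
    cli_args = parser.parse_args()
    # train() expects a list of integer GPU ids.
    cli_args.device_ids = [int(i) for i in cli_args.device_ids.split(",")]
    train(cli_args)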