import torch
import numpy as np

import socket
import argparse
import os
import json
import subprocess
from hps import Hyperparams, parse_args_and_update_hparams, add_vae_arguments
from utils import logger, local_mpi_rank, mpi_size, maybe_download, mpi_rank
from data import mkdir_p
from contextlib import contextmanager
import torch.distributed as dist

from vae import VAE
from torch.nn.parallel.distributed import DistributedDataParallel
from train_helpers import restore_params


def set_up_hyperparams(s=None):
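    """Parse command-line flags into an H(yperparams) object, set up MPI and
    the save directories, seed all RNGs, and return (H, logprint)."""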
    H = Hyperparams()
    parser = argparse.ArgumentParser()
    parser = add_vae_arguments(parser)
    parse_args_and_update_hparams(H, parser, s=s)
    setup_mpi(H)
    setup_save_dirs(H)
    logprint = logger(H.logdir)
    for k in sorted(H):
        logprint(type='hparam', key=k, value=H[k])
    np.random.seed(H.seed)
    torch.manual_seed(H.seed)
    torch.cuda.manual_seed(H.seed)
    logprint('training model', H.desc, 'on', H.dataset)
    return H, logprint


def set_up_data(H):
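    """Fix the dataset settings and build the preprocessing function.

    `shift`/`scale` standardize the input pixels (the hard-coded constants
    are presumably training-set pixel statistics for 64x64 RGB data), while
    `shift_loss`/`scale_loss` map raw [0, 255] pixels to [-1, 1] for the
    reconstruction loss.
    """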
    shift_loss = -127.5
    scale_loss = 1. / 127.5

    H.image_size = 64
    H.image_channels = 3
    shift = -115.92961967
    scale = 1. / 69.37404

    # Reshape the scalars to (1, 1, 1, 1) so they broadcast over 4D image batches.
    shift = torch.tensor([shift]).cuda().view(1, 1, 1, 1)
    scale = torch.tensor([scale]).cuda().view(1, 1, 1, 1)
    shift_loss = torch.tensor([shift_loss]).cuda().view(1, 1, 1, 1)
    scale_loss = torch.tensor([scale_loss]).cuda().view(1, 1, 1, 1)

    def preprocess_func(x):
        """Take a raw image batch and return the preprocessed network input
        as well as the target processed for the loss."""
        inp = x.cuda(non_blocking=True).float()
        out = inp.clone()
        # Standardize the encoder input; rescale the loss target to [-1, 1].
        inp.add_(shift).mul_(scale)
        out.add_(shift_loss).mul_(scale_loss)
        return inp, out
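
    # Usage sketch: for a uint8 image batch `x`,
    #   inp, target = preprocess_func(x)
    # gives `inp` standardized for the encoder and `target` scaled to [-1, 1].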
    return H, preprocess_func


def load_vaes(H, logprint=None):
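    """Build the EMA VAE: restore weights from H.restore_ema_path when given,
    freeze its parameters, and move it onto this rank's GPU."""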
    ema_vae = VAE(H)
    if H.restore_ema_path:
        (logprint or print)(f'Restoring ema vae from {H.restore_ema_path}')
        restore_params(ema_vae, H.restore_ema_path, map_cpu=True,
                       local_rank=H.local_rank, mpi_size=H.mpi_size)
    # Without a restored checkpoint, ema_vae keeps its freshly initialized weights.
    ema_vae.requires_grad_(False)
    ema_vae = ema_vae.cuda(H.local_rank)
    return ema_vae
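

# Minimal usage sketch: wiring the helpers together for sampling. The flag
# names come from add_vae_arguments; `forward_uncond_samples` is assumed to
# be the VAE's unconditional sampler and may be named differently in vae.py.
#
#   H, logprint = set_up_hyperparams()
#   H, preprocess_func = set_up_data(H)
#   ema_vae = load_vaes(H, logprint)
#   with torch.no_grad():
#       samples = ema_vae.forward_uncond_samples(4, t=1.0)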