# brain-diffuser/vdvae/model_utils.py
import torch
import numpy as np
#from mpi4py import MPI
import socket
import argparse
import os
import json
import subprocess
from hps import Hyperparams, parse_args_and_update_hparams, add_vae_arguments
from utils import (logger,
local_mpi_rank,
mpi_size,
maybe_download,
mpi_rank)
from data import mkdir_p
from contextlib import contextmanager
import torch.distributed as dist
#from apex.optimizers import FusedAdam as AdamW
from vae import VAE
from torch.nn.parallel.distributed import DistributedDataParallel
from train_helpers import restore_params


def set_up_hyperparams(s=None):
    H = Hyperparams()
    parser = argparse.ArgumentParser()
    parser = add_vae_arguments(parser)
    parse_args_and_update_hparams(H, parser, s=s)
    setup_mpi(H)
    setup_save_dirs(H)
    logprint = logger(H.logdir)
    for i, k in enumerate(sorted(H)):
        logprint(type='hparam', key=k, value=H[k])
    np.random.seed(H.seed)
    torch.manual_seed(H.seed)
    torch.cuda.manual_seed(H.seed)
    logprint('training model', H.desc, 'on', H.dataset)
    return H, logprint
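

# Usage sketch (illustrative, not from the original file): a script entry point
# typically calls this once to parse hyperparameters and obtain the run logger:
#
#     H, logprint = set_up_hyperparams()
#     logprint('parsed hyperparameters for', H.desc)
#
# `s` may be a list of argument strings to parse in place of sys.argv.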


def set_up_data(H):
    shift_loss = -127.5
    scale_loss = 1. / 127.5
    #trX, vaX, teX = imagenet64(H.data_root)
    H.image_size = 64
    H.image_channels = 3
    shift = -115.92961967
    scale = 1. / 69.37404
    #if H.test_eval:
    #    print('DOING TEST')
    #    eval_dataset = teX
    #else:
    #    eval_dataset = vaX
    shift = torch.tensor([shift]).cuda().view(1, 1, 1, 1)
    scale = torch.tensor([scale]).cuda().view(1, 1, 1, 1)
    shift_loss = torch.tensor([shift_loss]).cuda().view(1, 1, 1, 1)
    scale_loss = torch.tensor([scale_loss]).cuda().view(1, 1, 1, 1)
    #train_data = TensorDataset(torch.as_tensor(trX))
    #valid_data = TensorDataset(torch.as_tensor(eval_dataset))
    #untranspose = False

    def preprocess_func(x):
        """Takes in a data example and returns the preprocessed input
        as well as the input processed for the loss."""
        nonlocal shift
        nonlocal scale
        nonlocal shift_loss
        nonlocal scale_loss
        #untranspose = False
        #if untranspose:
        #    x[0] = x[0].permute(0, 2, 3, 1)
        inp = x.cuda(non_blocking=True).float()
        out = inp.clone()
        inp.add_(shift).mul_(scale)
        out.add_(shift_loss).mul_(scale_loss)
        return inp, out

    return H, preprocess_func
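

# Usage sketch for `preprocess_func` (illustrative, assuming a uint8-valued
# image batch shaped (N, 64, 64, 3) as in the ImageNet-64 pipeline):
#
#     H, preprocess_fn = set_up_data(H)
#     x = torch.from_numpy(images)        # uint8 pixel values in [0, 255]
#     inp, out = preprocess_fn(x)
#
# `inp` is standardized with the dataset mean/std (shift/scale above) and is
# fed to the encoder; `out` is rescaled to roughly [-1, 1] via shift_loss and
# scale_loss and serves as the reconstruction target for the loss.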


def load_vaes(H, logprint=None):
    ema_vae = VAE(H)
    if H.restore_ema_path:
        print(f'Restoring ema vae from {H.restore_ema_path}')
        restore_params(ema_vae, H.restore_ema_path, map_cpu=True, local_rank=H.local_rank, mpi_size=H.mpi_size)
    else:
        # Upstream vdvae copied weights from a freshly restored `vae` here; that
        # model is no longer built in this trimmed-down loader, so fail loudly
        # instead of hitting a NameError on the undefined `vae`.
        raise ValueError('H.restore_ema_path must be set to load the EMA VAE')
    ema_vae.requires_grad_(False)
    ema_vae = ema_vae.cuda(H.local_rank)
    #vae = DistributedDataParallel(vae, device_ids=[H.local_rank], output_device=H.local_rank)
    #if len(list(vae.named_parameters())) != len(list(vae.parameters())):
    #    raise ValueError('Some params are not named. Please name all params.')
    #total_params = 0
    #for name, p in vae.named_parameters():
    #    total_params += np.prod(p.shape)
    #print(total_params=total_params, readable=f'{total_params:,}')
    return ema_vae
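

# Minimal end-to-end sketch (assumptions: a CUDA device is available and the
# checkpoint path below is illustrative only; point it at a downloaded VDVAE
# EMA checkpoint before running).
if __name__ == '__main__':
    H, logprint = set_up_hyperparams()
    H, preprocess_fn = set_up_data(H)
    if not H.restore_ema_path:
        # Hypothetical path; replace with the actual location of the weights.
        H.restore_ema_path = 'model/imagenet64-iter-1600000-model-ema.th'
    ema_vae = load_vaes(H)
    n_params = sum(int(np.prod(p.shape)) for p in ema_vae.parameters())
    logprint(f'loaded EMA VAE with {n_params:,} parameters')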