File size: 5,092 Bytes
46a8d8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import sys
sys.path.append('vdvae')
import torch
import numpy as np
#from mpi4py import MPI
import socket
import argparse
import os
import json
import subprocess
from hps import Hyperparams, parse_args_and_update_hparams, add_vae_arguments
from utils import (logger,
local_mpi_rank,
mpi_size,
maybe_download,
mpi_rank)
from data import mkdir_p
from contextlib import contextmanager
import torch.distributed as dist
#from apex.optimizers import FusedAdam as AdamW
from vae import VAE
from torch.nn.parallel.distributed import DistributedDataParallel
from train_helpers import restore_params
from image_utils import *
from model_utils import *
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as T
import pickle
import argparse
parser = argparse.ArgumentParser(description='Argument Parser')
# Let argparse convert and validate the options so bad input yields a usage
# error; the previous `assert` check disappears when Python runs with -O.
parser.add_argument("-sub", "--sub", help="Subject Number",
                    type=int, choices=[1, 2, 5, 7], default=1)
parser.add_argument("-bs", "--bs", help="Batch Size", type=int, default=30)
args = parser.parse_args()
sub = args.sub
batch_size = args.bs
# Fail before the (slow) model load rather than inside DataLoader later.
if batch_size < 1:
    parser.error('--bs must be a positive integer')
print('Libs imported')
# VDVAE hyperparameters for the 'imagenet64' setting, one entry per line.
# Checkpoint entries point at the iteration-1600000 files; only
# restore_ema_path carries a directory prefix ('vdvae/model/') — presumably
# because only the EMA weights are loaded by this script (confirm).
H = {
    'image_size': 64,
    'image_channels': 3,
    'seed': 0,
    'port': 29500,
    'save_dir': './saved_models/test',
    'data_root': './',
    'desc': 'test',
    'hparam_sets': 'imagenet64',
    # checkpoint files
    'restore_path': 'imagenet64-iter-1600000-model.th',
    'restore_ema_path': 'vdvae/model/imagenet64-iter-1600000-model-ema.th',
    'restore_log_path': 'imagenet64-iter-1600000-log.jsonl',
    'restore_optimizer_path': 'imagenet64-iter-1600000-opt.th',
    'dataset': 'imagenet64',
    'ema_rate': 0.999,
    # encoder / decoder block specifications
    'enc_blocks': '64x11,64d2,32x20,32d2,16x9,16d2,8x8,8d2,4x7,4d4,1x5',
    'dec_blocks': '1x2,4m1,4x3,8m4,8x7,16m8,16x15,32m16,32x31,64m32,64x12',
    'zdim': 16,
    'width': 512,
    'custom_width_str': '',
    'bottleneck_multiple': 0.25,
    'no_bias_above': 64,
    'scale_encblock': False,
    'test_eval': True,
    'warmup_iters': 100,
    'num_mixtures': 10,
    'grad_clip': 220.0,
    'skip_threshold': 380.0,
    'lr': 0.00015,
    'lr_prior': 0.00015,
    'wd': 0.01,
    'wd_prior': 0.0,
    'num_epochs': 10000,
    'n_batch': 4,
    'adam_beta1': 0.9,
    'adam_beta2': 0.9,
    'temperature': 1.0,
    'iters_per_ckpt': 25000,
    'iters_per_print': 1000,
    'iters_per_save': 10000,
    'iters_per_images': 10000,
    'epochs_per_eval': 1,
    'epochs_per_probe': None,
    'epochs_per_eval_save': 1,
    'num_images_visualize': 8,
    'num_variables_visualize': 6,
    'num_temperatures_visualize': 3,
    # single-process run
    'mpi_size': 1,
    'local_rank': 0,
    'rank': 0,
    'logdir': './saved_models/test/log',
}
class dotdict(dict):
    """dot.notation access to dictionary attributes.

    ``d.key`` reads ``d['key']``; ``d.key = v`` and ``del d.key`` write and
    delete dictionary entries.

    A missing ordinary key reads as ``None`` (the behaviour of the original
    ``__getattr__ = dict.get``), kept for backward compatibility.

    A missing *dunder* name raises ``AttributeError`` instead. Protocol probes
    such as ``hasattr(obj, '__setstate__')`` or ``getattr(obj, '__getstate__')``,
    used by ``copy``/``pickle``, then behave normally rather than receiving
    ``None`` and attempting to call it.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        try:
            return self[name]
        except KeyError:
            if name.startswith('__') and name.endswith('__'):
                raise AttributeError(name) from None
            return None

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
# Wrap the plain config for attribute access and hand it to set_up_data
# (provided by the star imports, not visible here). set_up_data returns the
# hparams to use from now on together with preprocess_fn. Then build the
# model from those hparams.
H, preprocess_fn = set_up_data(dotdict(H))
print('Models is Loading')
ema_vae = load_vaes(H)
class batch_generator_external_images(Dataset):
    """Map-style dataset over an image array stored in a ``.npy`` file.

    The whole array is loaded into memory and cast to ``uint8`` once. Each
    item is resized with ``T.functional.resize`` and returned as a float
    tensor of raw pixel values; no rescaling is applied here.

    Args:
        data_path: path of the ``.npy`` file. Each element ``im[idx]`` must be
            accepted by ``PIL.Image.fromarray``; presumably ``(H, W, 3)`` RGB
            images — TODO confirm against the code that wrote the file.
        size: target ``(height, width)`` for the resize. Defaults to
            ``(64, 64)``, the value previously hard-coded here.
    """

    def __init__(self, data_path, size=(64, 64)):
        self.data_path = data_path
        self.size = size
        # copy=False avoids a second full copy when the file is already uint8.
        self.im = np.load(data_path).astype(np.uint8, copy=False)

    def __getitem__(self, idx):
        img = Image.fromarray(self.im[idx])
        img = T.functional.resize(img, self.size)
        # Values stay in [0, 255]. Scaling was deliberately left out here;
        # presumably preprocess_fn handles it downstream — confirm before
        # changing.
        return torch.tensor(np.array(img)).float()

    def __len__(self):
        return len(self.im)
# Per-subject stimulus arrays for the train and test splits. Both loaders use
# shuffle=False, so batches come out in stored index order.
image_path = 'data/processed_data/subj{:02d}/nsd_train_stim_sub{}.npy'.format(sub, sub)
train_images = batch_generator_external_images(data_path=image_path)
trainloader = DataLoader(train_images, batch_size=batch_size, shuffle=False)

image_path = 'data/processed_data/subj{:02d}/nsd_test_stim_sub{}.npy'.format(sub, sub)
test_images = batch_generator_external_images(data_path=image_path)
testloader = DataLoader(test_images, batch_size=batch_size, shuffle=False)
num_latents = 31  # number of leading decoder `stats` entries whose 'z' is kept


def _extract_vdvae_latents(loader, n_latents=num_latents):
    """Run the VDVAE over every batch of *loader* and return stacked latents.

    For each batch:
      - ``preprocess_fn`` prepares the input;
      - the EMA model's encoder produces activations;
      - the decoder is run with ``get_latents=True``.

    The ``'z'`` tensor of each of the first ``n_latents`` entries of the
    returned ``stats`` is flattened per sample. The pieces are concatenated
    along the feature axis, giving one row per image. Rows are returned in
    loader order as a single 2-D NumPy array.

    Prints the index of the first sample of each batch as progress.
    """
    batches = []
    for batch_idx, x in enumerate(loader):
        data_input, _target = preprocess_fn(x)
        print(batch_idx * batch_size)
        with torch.no_grad():
            # Call the modules rather than .forward() so registered hooks run.
            activations = ema_vae.encoder(data_input)
            _px_z, stats = ema_vae.decoder(activations, get_latents=True)
        n_samples = len(data_input)
        batches.append(np.hstack([
            stats[k]['z'].cpu().numpy().reshape(n_samples, -1)
            for k in range(n_latents)
        ]))
    return np.concatenate(batches)


test_latents = _extract_vdvae_latents(testloader)
train_latents = _extract_vdvae_latents(trainloader)
# Save both splits to one archive. Create the per-subject output directory
# first so a long extraction run does not fail at the final write.
out_path = "data/extracted_features/subj{:02d}/nsd_vdvae_features_31l.npz".format(sub)
os.makedirs(os.path.dirname(out_path), exist_ok=True)
np.savez(out_path, train_latents=train_latents, test_latents=test_latents)
|