File size: 4,400 Bytes
46a8d8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import os
import sys
import numpy as np
import h5py
import scipy.io as spio
import nibabel as nib
import torch
import torchvision
import torchvision.models as tvmodels
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from PIL import Image
import clip
import skimage.io as sio
from skimage import data, img_as_float
from skimage.transform import resize as imresize
from skimage.metrics import structural_similarity as ssim
import scipy as sp
import argparse
# Command-line configuration.
# --sub 0 reads the images in data/nsddata_stimuli/test_images; --sub 1/2/5/7
# reads that subject's images under results/versatile_diffusion/subjXX.
# The allowed values are enforced by argparse (clean usage error, exit code 2)
# rather than by `assert`, which is removed when Python runs with -O.
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument(
    "-sub", "--sub",
    help="Subject Number",
    type=int,
    choices=[0, 1, 2, 5, 7],
    default=1,
)
args = parser.parse_args()
sub = int(args.sub)

# Input image directory and output feature directory for the selected subject.
images_dir = 'data/nsddata_stimuli/test_images'
feats_dir = 'data/eval_features/test_images'
if sub in [1, 2, 5, 7]:
    feats_dir = f'data/eval_features/subj{sub:02d}'
    images_dir = f'results/versatile_diffusion/subj{sub:02d}'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(feats_dir, exist_ok=True)
class batch_generator_external_images(Dataset):
    """Dataset yielding normalized 224x224 image tensors read from PNG files.

    Item ``idx`` is loaded from ``{data_path}/{prefix}{idx}.png``, so the files
    are expected to be named with consecutive integers starting at 0.

    Args:
        data_path: Directory containing the PNG files.
        prefix: Filename prefix placed before the integer index.
        net_name: ``'clip'`` selects CLIP normalization statistics; any other
            value selects the ImageNet statistics.
        num_test: Number of images (indices ``0 .. num_test - 1``). Defaults to
            982, the value that was previously hard-coded.
    """

    def __init__(self, data_path='', prefix='', net_name='clip', num_test=982):
        self.data_path = data_path
        self.prefix = prefix
        self.net_name = net_name
        if self.net_name == 'clip':
            # CLIP image-preprocessing mean/std.
            self.normalize = transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            )
        else:
            # ImageNet mean/std used by the torchvision-style models.
            self.normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
        self.num_test = num_test

    def __getitem__(self, idx):
        """Return image ``idx`` as a normalized float tensor of shape (3, 224, 224)."""
        path = f'{self.data_path}/{self.prefix}{idx}.png'
        # The context manager closes the file handle deterministically.
        # convert('RGB') guarantees 3 channels, so the 3-element Normalize also
        # works for RGBA or grayscale PNGs (it fails on those otherwise).
        with Image.open(path) as img:
            img = img.convert('RGB')
        img = T.functional.resize(img, (224, 224))
        img = T.functional.to_tensor(img).float()
        return self.normalize(img)

    def __len__(self):
        """Return the number of images served by this dataset."""
        return self.num_test
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Module-level buffer that `fn` appends to. `fn` looks the name up at call
# time, so rebinding the global `feat_list` redirects subsequent appends.
feat_list = []


def fn(module, inputs, outputs):
    """Forward hook: append the hooked module's output to ``feat_list``.

    ``module`` and ``inputs`` belong to the PyTorch forward-hook signature and
    are unused. The output is detached before conversion, so calling
    ``.numpy()`` also works when the tensor requires grad (e.g. outside
    ``torch.no_grad()``).
    """
    feat_list.append(outputs.detach().cpu().numpy())
# (network, layer) pairs to extract. For each pair, every image in
# `images_dir` is run through the network and the hooked layer's activations
# are saved to f'{feats_dir}/{net_name}_{layer}.npy'.
net_list = [
    ('inceptionv3', 'avgpool'),
    ('clip', 'final'),
    ('alexnet', 2),
    ('alexnet', 5),
    ('efficientnet', 'avgpool'),
    ('swav', 'avgpool'),
]
batchsize = 64


def _load_net_with_hook(net_name, layer, hook):
    """Build network ``net_name`` and register ``hook`` on the module for ``layer``.

    Returns:
        ``(net, handle)``, where ``handle`` removes the hook when done.

    Raises:
        ValueError: for an unknown network or an unsupported layer. Before this
            check, no hook was registered (or avgpool was hooked regardless of
            ``layer``), which surfaced later as an empty-list ``np.concatenate``
            error or as features saved under the wrong filename.
    """
    if net_name == 'inceptionv3':
        # `weights=` replaces the deprecated `pretrained=True` (same weights).
        net = tvmodels.inception_v3(weights='IMAGENET1K_V1')
        modules = {'avgpool': net.avgpool, 'lastconv': net.Mixed_7c}
    elif net_name == 'alexnet':
        net = tvmodels.alexnet(weights='IMAGENET1K_V1')
        # Layer numbers map to specific submodules of `features`/`classifier`.
        modules = {2: net.features[4], 5: net.features[11], 7: net.classifier[5]}
    elif net_name == 'clip':
        model, _ = clip.load("ViT-L/14", device=device)
        net = model.visual.to(torch.float32)
        modules = {
            7: net.transformer.resblocks[7],
            12: net.transformer.resblocks[12],
            # Hook on the whole visual encoder: captures its final output.
            'final': net,
        }
    elif net_name == 'efficientnet':
        net = tvmodels.efficientnet_b1(weights='IMAGENET1K_V1')
        modules = {'avgpool': net.avgpool}
    elif net_name == 'swav':
        net = torch.hub.load('facebookresearch/swav:main', 'resnet50')
        modules = {'avgpool': net.avgpool}
    else:
        raise ValueError(f'unknown network: {net_name!r}')
    if layer not in modules:
        raise ValueError(
            f'unsupported layer {layer!r} for {net_name!r}; '
            f'expected one of {list(modules)}'
        )
    handle = modules[layer].register_forward_hook(hook)
    return net, handle


for (net_name, layer) in net_list:
    # Fresh buffer for this network; `fn` appends to the global `feat_list`.
    feat_list = []
    print(net_name, layer)
    dataset = batch_generator_external_images(data_path=images_dir, net_name=net_name, prefix='')
    loader = DataLoader(dataset, batchsize, shuffle=False)
    net, handle = _load_net_with_hook(net_name, layer, fn)
    net.eval()
    net = net.to(device)
    with torch.no_grad():
        for i, x in enumerate(loader):
            print(i * batchsize)
            x = x.to(device)
            _ = net(x)
    handle.remove()
    # Drop the model before building the next one to lower peak memory.
    del net
    if net_name == 'clip' and layer in (7, 12):
        # Resblock outputs are 3-D with the batch on axis 1 (presumably
        # tokens x batch x width); stack batches there, then make batch-first.
        feats = np.concatenate(feat_list, axis=1).transpose((1, 0, 2))
    else:
        feats = np.concatenate(feat_list)
    file_name = f'{feats_dir}/{net_name}_{layer}.npy'
    np.save(file_name, feats)
|