"""Extract intermediate network features from a directory of numbered PNGs.

For every ``(network, layer)`` pair in ``net_list`` the script pushes images
``0.png`` .. ``981.png`` from ``images_dir`` through the network, captures the
output of the selected layer with a forward hook and stores the stacked
activations in ``{feats_dir}/{net_name}_{layer}.npy``.

``--sub 0`` reads ``data/nsddata_stimuli/test_images``; subjects 1, 2, 5 and 7
read ``results/versatile_diffusion/subjXX``.
"""
import os
import sys
import numpy as np
import h5py
import scipy.io as spio
import nibabel as nib

import torch
import torchvision
import torchvision.models as tvmodels
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from PIL import Image
import clip

import skimage.io as sio
from skimage import data, img_as_float
from skimage.transform import resize as imresize
from skimage.metrics import structural_similarity as ssim
import scipy as sp

import argparse

# Subject ids this script accepts on the command line.
VALID_SUBJECTS = (0, 1, 2, 5, 7)

parser = argparse.ArgumentParser(description='Argument Parser')
# `type`/`choices` make argparse reject bad ids with a usage message; the
# previous `assert` disappeared under `python -O`.
parser.add_argument("-sub", "--sub", help="Subject Number", type=int,
                    choices=VALID_SUBJECTS, default=1)
args = parser.parse_args()
sub = args.sub

# Subject 0 keeps these defaults; any other accepted subject overrides both.
images_dir = 'data/nsddata_stimuli/test_images'
feats_dir = 'data/eval_features/test_images'

if sub in [1, 2, 5, 7]:
    feats_dir = f'data/eval_features/subj{sub:02d}'
    images_dir = f'results/versatile_diffusion/subj{sub:02d}'

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(feats_dir, exist_ok=True)


class batch_generator_external_images(Dataset):
    """Dataset over ``{data_path}/{prefix}{idx}.png`` files.

    Each item is a float tensor resized to 224x224 and channel-normalised.

    Args:
        data_path: Directory that holds the PNG files.
        prefix: String placed in front of the index in each file name.
        net_name: ``'clip'`` selects the first set of normalisation constants
            below; any other value selects the second set.
        num_test: Number of images reported by ``len()`` (defaults to 982).
    """

    def __init__(self, data_path='', prefix='', net_name='clip', num_test=982):
        self.data_path = data_path
        self.prefix = prefix
        self.net_name = net_name

        if self.net_name == 'clip':
            self.normalize = transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            )
        else:
            # NOTE(review): presumably the usual ImageNet statistics - confirm.
            self.normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
        self.num_test = num_test

    def __getitem__(self, idx):
        """Return image ``idx`` as a normalised 3-channel 224x224 tensor."""
        # The context manager closes the file handle; convert('RGB') guarantees
        # the three channels that the 3-value mean/std above require.
        with Image.open(f'{self.data_path}/{self.prefix}{idx}.png') as raw:
            img = raw.convert('RGB')
        img = T.functional.resize(img, (224, 224))
        img = T.functional.to_tensor(img).float()
        img = self.normalize(img)
        return img

    def __len__(self):
        """Return the configured number of images."""
        return self.num_test


# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Filled by the forward hook `fn`; rebound to a fresh list for every network.
feat_list = []


def fn(module, inputs, outputs):
    """Forward hook: append the hooked module's output to ``feat_list``."""
    feat_list.append(outputs.cpu().numpy())


def _build_network(net_name, layer):
    """Load ``net_name``, hook ``layer`` with ``fn`` and return the network.

    The returned network is in eval mode and already moved to ``device``.

    Raises:
        ValueError: If ``net_name`` is unknown or ``layer`` has no hook target
            for that network.
    """
    if net_name == 'inceptionv3':
        net = tvmodels.inception_v3(weights='IMAGENET1K_V1')
        targets = {'avgpool': net.avgpool, 'lastconv': net.Mixed_7c}
    elif net_name == 'alexnet':
        net = tvmodels.alexnet(weights='IMAGENET1K_V1')
        targets = {
            2: net.features[4],
            5: net.features[11],
            7: net.classifier[5],
        }
    elif net_name == 'clip':
        model, _ = clip.load("ViT-L/14", device=device)
        net = model.visual.to(torch.float32)
        targets = {
            7: net.transformer.resblocks[7],
            12: net.transformer.resblocks[12],
            'final': net,
        }
    elif net_name == 'efficientnet':
        net = tvmodels.efficientnet_b1(weights='IMAGENET1K_V1')
        # Always hooked at avgpool, whatever `layer` says.
        targets = {layer: net.avgpool}
    elif net_name == 'swav':
        net = torch.hub.load('facebookresearch/swav:main', 'resnet50')
        # Always hooked at avgpool, whatever `layer` says.
        targets = {layer: net.avgpool}
    else:
        raise ValueError(f'Unknown network: {net_name!r}')

    if layer not in targets:
        # Without a hook nothing is captured and the later concatenate would
        # fail with an unrelated-looking error, so fail here instead.
        raise ValueError(f'Unsupported layer {layer!r} for network {net_name!r}')
    targets[layer].register_forward_hook(fn)

    net.eval()
    return net.to(device)


net_list = [
    ('inceptionv3', 'avgpool'),
    ('clip', 'final'),
    ('alexnet', 2),
    ('alexnet', 5),
    ('efficientnet', 'avgpool'),
    ('swav', 'avgpool'),
]

batchsize = 64

for (net_name, layer) in net_list:
    # Fresh list per network so activations from different models never mix.
    feat_list = []
    print(net_name, layer)
    dataset = batch_generator_external_images(
        data_path=images_dir, net_name=net_name, prefix='')
    loader = DataLoader(dataset, batchsize, shuffle=False)

    net = _build_network(net_name, layer)

    with torch.no_grad():
        for i, x in enumerate(loader):
            print(i * batchsize)  # index of the first image in this batch
            x = x.to(device)
            _ = net(x)  # output unused; the hook captures the activations

    if net_name == 'clip' and layer in (7, 12):
        # Batches are joined on axis 1 and the first two axes swapped.
        # NOTE(review): presumably the resblocks emit (tokens, batch, width)
        # so the saved array becomes batch-first - confirm.
        feats = np.concatenate(feat_list, axis=1).transpose((1, 0, 2))
    else:
        feats = np.concatenate(feat_list)

    file_name = f'{feats_dir}/{net_name}_{layer}.npy'
    np.save(file_name, feats)