# brain-diffuser / scripts / eval_extract_features.py
# Uploaded by dineshsai07 ("Add files using upload-large-folder tool"),
# commit 46a8d8a (verified).
import os
import sys
import numpy as np
import h5py
import scipy.io as spio
import nibabel as nib
import torch
import torchvision
import torchvision.models as tvmodels
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from PIL import Image
import clip
import skimage.io as sio
from skimage import data, img_as_float
from skimage.transform import resize as imresize
from skimage.metrics import structural_similarity as ssim
import scipy as sp
import argparse
# Command-line interface: which subject's images to extract features for.
# Subject 0 reads the original test stimuli from data/nsddata_stimuli; subjects
# 1/2/5/7 read per-subject images from results/versatile_diffusion.
# `choices` rejects other values with a usage error (unlike `assert`, which
# is stripped when Python runs with -O).
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument("-sub", "--sub", help="Subject Number", type=int, default=1,
                    choices=[0, 1, 2, 5, 7])
args = parser.parse_args()
sub = args.sub

images_dir = 'data/nsddata_stimuli/test_images'
feats_dir = 'data/eval_features/test_images'
if sub in [1, 2, 5, 7]:
    feats_dir = f'data/eval_features/subj{sub:02d}'
    images_dir = f'results/versatile_diffusion/subj{sub:02d}'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(feats_dir, exist_ok=True)
class batch_generator_external_images(Dataset):
    """Map-style dataset over ``{data_path}/{prefix}{idx}.png`` images.

    Each item is resized to 224x224, converted to a float tensor, and
    normalized with CLIP statistics when ``net_name == 'clip'`` and with
    ImageNet statistics otherwise.

    Args:
        data_path: directory containing the PNG files.
        prefix: filename prefix placed before the integer index.
        net_name: name of the network the features are for; only selects
            the normalization statistics.
        num_images: number of images, indexed ``0 .. num_images - 1``.
            Defaults to 982, the count this script previously hard-coded.
    """

    def __init__(self, data_path='', prefix='', net_name='clip', num_images=982):
        self.data_path = data_path
        self.prefix = prefix
        self.net_name = net_name
        if self.net_name == 'clip':
            self.normalize = transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711]
            )
        else:
            self.normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        self.num_test = num_images

    def __getitem__(self, idx):
        """Return the normalized 3-channel float tensor for image ``idx``."""
        # The context manager closes the underlying file handle. convert('RGB')
        # forces three channels so the 3-channel Normalize also works for
        # grayscale, palette or RGBA PNGs.
        with Image.open(f'{self.data_path}/{self.prefix}{idx}.png') as src:
            img = src.convert('RGB')
        img = T.functional.resize(img, (224, 224))
        img = T.functional.to_tensor(img).float()
        img = self.normalize(img)
        return img

    def __len__(self):
        """Return the number of images in the dataset."""
        return self.num_test
# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Module-level buffer that the forward hook `fn` appends to. The extraction
# loop rebinds this name to a fresh list for every network; `fn` looks the
# name up at call time, so it always writes into the current list.
feat_list = []


def fn(module, inputs, outputs):
    """Forward hook: copy a layer's output to host memory and record it.

    Args:
        module: the hooked module (unused).
        inputs: the module's inputs (unused).
        outputs: the module's output tensor; appended to ``feat_list`` as a
            NumPy array.
    """
    # detach() keeps .numpy() from raising if the hook fires while autograd
    # is tracking the output (i.e. outside torch.no_grad()).
    feat_list.append(outputs.detach().cpu().numpy())
net_list = [
    ('inceptionv3', 'avgpool'),
    ('clip', 'final'),
    ('alexnet', 2),
    ('alexnet', 5),
    ('efficientnet', 'avgpool'),
    ('swav', 'avgpool')
]
batchsize = 64


def _load_hooked_net(net_name, layer):
    """Load ``net_name``, attach ``fn`` as a forward hook on ``layer``, and
    return the network in eval mode on ``device``.

    Raises:
        ValueError: if ``net_name`` is unknown or ``layer`` is not a supported
            hook point for it. Without this check no hook would be registered
            and the failure would only surface later in ``np.concatenate``.
    """
    if net_name == 'inceptionv3':
        # weights='IMAGENET1K_V1' is the same checkpoint as the deprecated
        # pretrained=True, and matches the efficientnet call below.
        net = tvmodels.inception_v3(weights='IMAGENET1K_V1')
        hook_points = {'avgpool': net.avgpool, 'lastconv': net.Mixed_7c}
    elif net_name == 'alexnet':
        net = tvmodels.alexnet(weights='IMAGENET1K_V1')
        hook_points = {2: net.features[4], 5: net.features[11], 7: net.classifier[5]}
    elif net_name == 'clip':
        # Only the visual tower is kept; the full CLIP model goes out of scope
        # when this function returns.
        model, _ = clip.load("ViT-L/14", device=device)
        net = model.visual.to(torch.float32)
        hook_points = {
            7: net.transformer.resblocks[7],
            12: net.transformer.resblocks[12],
            'final': net,
        }
    elif net_name == 'efficientnet':
        net = tvmodels.efficientnet_b1(weights='IMAGENET1K_V1')
        hook_points = {'avgpool': net.avgpool}
    elif net_name == 'swav':
        net = torch.hub.load('facebookresearch/swav:main', 'resnet50')
        hook_points = {'avgpool': net.avgpool}
    else:
        raise ValueError(f'Unknown network: {net_name!r}')
    if layer not in hook_points:
        raise ValueError(f'Unsupported layer {layer!r} for network {net_name!r}')
    hook_points[layer].register_forward_hook(fn)
    net.eval()
    return net.to(device)


def _stack_features(chunks, net_name, layer):
    """Concatenate per-batch hook outputs into a single batch-first array."""
    if net_name == 'clip' and layer in (7, 12):
        # CLIP resblock outputs are joined on axis 1 and that axis is moved
        # first, i.e. they are treated as sequence-first (seq, batch, dim).
        return np.concatenate(chunks, axis=1).transpose((1, 0, 2))
    return np.concatenate(chunks)


for (net_name, layer) in net_list:
    # Rebinding the module-level name gives `fn` a fresh buffer per network.
    feat_list = []
    print(net_name, layer)
    dataset = batch_generator_external_images(data_path=images_dir, net_name=net_name, prefix='')
    loader = DataLoader(dataset, batchsize, shuffle=False)
    net = _load_hooked_net(net_name, layer)
    with torch.no_grad():
        for i, x in enumerate(loader):
            print(i * batchsize)  # progress: index of first image in batch
            net(x.to(device))  # output unused; features are captured by `fn`
    feat_list = _stack_features(feat_list, net_name, layer)
    file_name = f'{feats_dir}/{net_name}_{layer}.npy'
    np.save(file_name, feat_list)