|
import os |
|
import sys |
|
import numpy as np |
|
import h5py |
|
import scipy.io as spio |
|
import nibabel as nib |
|
import scipy as sp |
|
from PIL import Image |
|
|
|
|
|
|
|
import argparse |
|
parser = argparse.ArgumentParser(description='Argument Parser') |
|
parser.add_argument("-sub", "--sub",help="Subject Number",default=1) |
|
args = parser.parse_args() |
|
sub=int(args.sub) |
|
assert sub in [1,2,5,7] |
|
|
|
|
|
from scipy.stats import pearsonr,binom,linregress |
|
import numpy as np |
|
def pairwise_corr_all(ground_truth, predictions): |
|
r = np.corrcoef(ground_truth, predictions) |
|
r = r[:len(ground_truth), len(ground_truth):] |
|
|
|
|
|
congruents = np.diag(r) |
|
|
|
|
|
|
|
success = r < congruents |
|
success_cnt = np.sum(success, 0) |
|
|
|
|
|
perf = np.mean(success_cnt) / (len(ground_truth)-1) |
|
p = 1 - binom.cdf(perf*len(ground_truth)*(len(ground_truth)-1), len(ground_truth)*(len(ground_truth)-1), 0.5) |
|
|
|
return perf, p |
|
|
|
|
|
net_list = [ |
|
('inceptionv3','avgpool'), |
|
('clip','final'), |
|
('alexnet',2), |
|
('alexnet',5), |
|
('efficientnet','avgpool'), |
|
('swav','avgpool') |
|
] |
|
|
|
feats_dir = 'data/eval_features/subj{:02d}'.format(sub) |
|
test_dir = 'data/eval_features/test_images' |
|
num_test = 982 |
|
distance_fn = sp.spatial.distance.correlation |
|
pairwise_corrs = [] |
|
for (net_name,layer) in net_list: |
|
file_name = '{}/{}_{}.npy'.format(test_dir,net_name,layer) |
|
gt_feat = np.load(file_name) |
|
|
|
file_name = '{}/{}_{}.npy'.format(feats_dir,net_name,layer) |
|
eval_feat = np.load(file_name) |
|
|
|
gt_feat = gt_feat.reshape((len(gt_feat),-1)) |
|
eval_feat = eval_feat.reshape((len(eval_feat),-1)) |
|
|
|
print(net_name,layer) |
|
if net_name in ['efficientnet','swav']: |
|
print('distance: ',np.array([distance_fn(gt_feat[i],eval_feat[i]) for i in range(num_test)]).mean()) |
|
else: |
|
pairwise_corrs.append(pairwise_corr_all(gt_feat[:num_test],eval_feat[:num_test])[0]) |
|
print('pairwise corr: ',pairwise_corrs[-1]) |
|
|
|
from skimage.color import rgb2gray |
|
from skimage.metrics import structural_similarity as ssim |
|
|
|
ssim_list = [] |
|
pixcorr_list = [] |
|
for i in range(982): |
|
gen_image = Image.open('results/versatile_diffusion/subj{:02d}/{}.png'.format(sub,i)).resize((425,425)) |
|
gt_image = Image.open('data/nsddata_stimuli/test_images/{}.png'.format(i)) |
|
gen_image = np.array(gen_image)/255.0 |
|
gt_image = np.array(gt_image)/255.0 |
|
pixcorr_res = np.corrcoef(gt_image.reshape(1,-1), gen_image.reshape(1,-1))[0,1] |
|
pixcorr_list.append(pixcorr_res) |
|
gen_image = rgb2gray(gen_image) |
|
gt_image = rgb2gray(gt_image) |
|
ssim_res = ssim(gen_image, gt_image, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0) |
|
ssim_list.append(ssim_res) |
|
|
|
ssim_list = np.array(ssim_list) |
|
pixcorr_list = np.array(pixcorr_list) |
|
print('PixCorr: {}'.format(pixcorr_list.mean())) |
|
print('SSIM: {}'.format(ssim_list.mean())) |
|
|
|
|