# brain-diffuser / scripts / evaluate_reconstruction.py
# dineshsai07's picture
# Add files using upload-large-folder tool
# 46a8d8a verified
import os
import sys
import numpy as np
import h5py
import scipy.io as spio
import nibabel as nib
import scipy as sp
from PIL import Image
import argparse
# Command-line interface: choose which subject's reconstructions to evaluate.
parser = argparse.ArgumentParser(description='Argument Parser')
parser.add_argument(
    "-sub", "--sub",
    help="Subject Number",
    type=int,              # convert once here instead of a later int() call
    choices=[1, 2, 5, 7],  # the only subject numbers this script accepts
    default=1,
)
args = parser.parse_args()
# Validation is done by argparse (`choices`) rather than by `assert`, because
# assertions are removed when Python runs with -O and an unsupported subject
# would then slip through to the file-path formatting further down.
sub = args.sub
from scipy.stats import pearsonr,binom,linregress
import numpy as np
def pairwise_corr_all(ground_truth, predictions):
    """Score how well each prediction identifies its own ground-truth row.

    Row ``i`` of ``ground_truth`` and row ``i`` of ``predictions`` form a
    congruent pair.  For every prediction ``j`` the function counts the
    ground-truth rows whose Pearson correlation with that prediction is
    strictly lower than the correlation of the congruent pair ``(j, j)``.

    Args:
        ground_truth: 2-D array-like, one feature vector per row.
        predictions: 2-D array-like with the same number of rows.

    Returns:
        ``(perf, p)`` where ``perf`` is the fraction of the ``n * (n - 1)``
        incongruent comparisons that were won (1.0 = every prediction
        correlates best with its own ground truth) and ``p`` is the binomial
        upper-tail probability of exceeding the observed number of wins when
        each comparison is treated as a trial with success rate 0.5.

    Raises:
        ValueError: if the inputs have different numbers of rows, or fewer
            than two rows (no incongruent pair exists to compare against).
    """
    n = len(ground_truth)
    if n != len(predictions):
        raise ValueError(
            'ground_truth and predictions must have the same number of rows, '
            'got {} and {}'.format(n, len(predictions))
        )
    if n < 2:
        raise ValueError('at least two samples are required, got {}'.format(n))

    # np.corrcoef stacks both inputs row-wise; the upper-right n x n quadrant
    # holds corr(ground_truth[i], predictions[j]).
    r = np.corrcoef(ground_truth, predictions)
    r = r[:n, n:]  # rows: ground truth, columns: predictions

    # Congruent pairs sit on the diagonal.
    congruents = np.diag(r)

    # Broadcasting compares every column j against congruents[j]:
    # success[i, j] is True when ground-truth row i correlates with
    # prediction j less than prediction j's own ground truth does.
    success = r < congruents
    success_cnt = np.sum(success, 0)

    # The diagonal of `success` is always False (a value is never strictly
    # lower than itself), so each column has n - 1 real comparisons.
    perf = np.mean(success_cnt) / (n - 1)

    # Use the exact integer number of wins rather than rebuilding it from
    # `perf` in floating point (which can land just below an integer and be
    # floored to one success too few), and use the survival function, which
    # keeps precision for tiny p-values where `1 - cdf` cancels to 0.0.
    n_comparisons = n * (n - 1)
    p = binom.sf(int(success_cnt.sum()), n_comparisons, 0.5)
    return perf, p
# (network name, layer) pairs; each pair names one pre-extracted feature file
# of the form "<network>_<layer>.npy" in the feature directories below.
net_list = [
    ('inceptionv3', 'avgpool'),
    ('clip', 'final'),
    ('alexnet', 2),
    ('alexnet', 5),
    ('efficientnet', 'avgpool'),
    ('swav', 'avgpool'),
]

# Features of this subject's reconstructions and of the test images.
feats_dir = 'data/eval_features/subj{:02d}'.format(sub)
test_dir = 'data/eval_features/test_images'

# Number of leading rows of each feature file that are evaluated.
num_test = 982

# Import the submodule explicitly: `import scipy as sp` on its own does not
# guarantee that `sp.spatial` is loaded, so the attribute lookup below would
# otherwise only work when some other import has pulled it in already.
import scipy.spatial.distance
distance_fn = sp.spatial.distance.correlation

# Two-way identification scores, appended in `net_list` order for the
# networks that are scored with pairwise correlation.
pairwise_corrs = []
# Score every feature space: load the test-image and reconstruction features,
# flatten each sample to a vector, then report either the two-way
# identification score or the mean correlation distance.
for net_name, layer in net_list:
    file_name = '{}/{}_{}.npy'.format(test_dir, net_name, layer)
    gt_feat = np.load(file_name)

    file_name = '{}/{}_{}.npy'.format(feats_dir, net_name, layer)
    eval_feat = np.load(file_name)

    # One row per sample, all remaining dimensions flattened.
    gt_feat = gt_feat.reshape((len(gt_feat), -1))
    eval_feat = eval_feat.reshape((len(eval_feat), -1))

    print(net_name, layer)

    if net_name not in ['efficientnet', 'swav']:
        # Pairwise identification; only the score is kept, the p-value is
        # discarded.
        score = pairwise_corr_all(gt_feat[:num_test], eval_feat[:num_test])[0]
        pairwise_corrs.append(score)
        print('pairwise corr: ', pairwise_corrs[-1])
        continue

    # efficientnet / swav: mean correlation distance over matching rows.
    per_sample = [distance_fn(gt_feat[i], eval_feat[i]) for i in range(num_test)]
    print('distance: ', np.array(per_sample).mean())
from skimage.color import rgb2gray
from skimage.metrics import structural_similarity as ssim
# Low-level image metrics between each reconstruction and its test image:
# pixel-wise correlation (PixCorr) and structural similarity (SSIM).
ssim_list = []
pixcorr_list = []
for i in range(num_test):
    # Context managers close the image files promptly instead of leaving the
    # handles open until garbage collection.
    with Image.open('results/versatile_diffusion/subj{:02d}/{}.png'.format(sub, i)) as img:
        # Reconstructions are resized to 425x425 and scaled to [0, 1].
        gen_image = np.array(img.resize((425, 425))) / 255.0
    with Image.open('data/nsddata_stimuli/test_images/{}.png'.format(i)) as img:
        # NOTE(review): the test image is not resized, so it is presumably
        # already 425x425 -- confirm; a size mismatch would make the
        # correlation below fail.
        gt_image = np.array(img) / 255.0

    # PixCorr: Pearson correlation between the two flattened images.
    pixcorr_res = np.corrcoef(gt_image.reshape(1, -1), gen_image.reshape(1, -1))[0, 1]
    pixcorr_list.append(pixcorr_res)

    # SSIM is computed on the grayscale versions of both images.
    gen_image = rgb2gray(gen_image)
    gt_image = rgb2gray(gt_image)
    # NOTE(review): the inputs are grayscale at this point, so
    # `multichannel=True` looks unintended, and newer scikit-image releases
    # replace that keyword with `channel_axis`.  The call is left unchanged
    # so scores stay comparable with earlier runs -- confirm against the
    # installed scikit-image version before changing it.
    ssim_res = ssim(gen_image, gt_image, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0)
    ssim_list.append(ssim_res)

ssim_list = np.array(ssim_list)
pixcorr_list = np.array(pixcorr_list)
print('PixCorr: {}'.format(pixcorr_list.mean()))
print('SSIM: {}'.format(ssim_list.mean()))