SAM-DiffSR / sam_diffsr /tools /caculate_iqa.py
Traly's picture
init
193c713
import os
import ssl
from os.path import join
from pathlib import Path
from statistics import mean
# Package root: the directory two levels above this file, as an absolute string.
parent_path = os.path.abspath(Path(__file__).absolute().parent.parent)

# NOTE(security): the next two statements turn off TLS certificate checking for
# downloads made through requests/curl and through urllib in this process.
# Kept as in the original script; only run it on a trusted network.
os.environ["CURL_CA_BUNDLE"] = ""
ssl._create_default_https_context = ssl._create_unverified_context

# Point every download cache at <package root>/cache. These variables are set
# before the third-party imports below so the libraries see them at import time.
cache_path = os.path.join(parent_path, 'cache')
os.environ["HF_DATASETS_CACHE"] = cache_path
os.environ["TRANSFORMERS_CACHE"] = cache_path
# PyTorch looks up "TORCH_HOME"; environment variable names are case-sensitive
# on POSIX systems, so the previous spelling "torch_HOME" was never picked up.
os.environ["TORCH_HOME"] = cache_path
import PIL
import numpy as np
import pandas as pd
import pyiqa
import torch
from PIL import Image
from tqdm import tqdm
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Registry of pyiqa metric objects, keyed by the names accepted in a metric
# list. All three metrics are constructed once, when this module is imported.
metric_dict = {}
metric_dict['psnr-Y'] = pyiqa.create_metric(
    'psnr',
    test_y_channel=True,
    color_space='ycbcr',
)
metric_dict['ssim'] = pyiqa.create_metric('ssim', color_space='ycbcr')
metric_dict['fid'] = pyiqa.create_metric('fid')
def load_img(path, target_size=None):
    """Load an image file as a float32 tensor of shape (1, 3, H, W) in [0, 1].

    Args:
        path: Path of the image file to read.
        target_size: Optional ``(height, width)`` pair. When given (and
            non-empty) the image is resized to that size with Lanczos
            resampling before conversion.

    Returns:
        A ``torch.Tensor`` holding the RGB image, channel-first, with a
        leading batch dimension of size 1.
    """
    # Open inside a context manager so the underlying file is closed as soon
    # as the pixel data has been converted, instead of waiting for garbage
    # collection (this function is called once per image in a long loop).
    with Image.open(path) as raw:
        image = raw.convert("RGB")
    if target_size:
        h, w = target_size
        # PIL expects (width, height), the reverse of the (H, W) given here.
        image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    # uint8 HxWx3 -> float32 in [0, 1]
    image = np.array(image).astype(np.float32) / 255.0
    # Add a batch axis and move channels first: (H, W, 3) -> (1, 3, H, W).
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image
def _score_to_float(score):
    """Collapse a single-valued metric output (tensor or number) to a float."""
    if isinstance(score, torch.Tensor):
        # .item() accepts any one-element tensor regardless of its shape.
        return float(score.detach().cpu().item())
    return float(score)


def _mean_pairwise_score(iqa_metric, gt_dir, sr_dir, gt_img_list):
    """Average a full-reference metric over every GT image that has an SR twin."""
    scores = []
    for img_name in tqdm(gt_img_list):
        # splitext keeps dots inside the stem ("a.b.png" -> "a.b"), whereas
        # splitting on the first dot would truncate such names.
        base_name = os.path.splitext(img_name)[0]
        gt_img_path = join(gt_dir, img_name)
        sr_img_path = join(sr_dir, f'{base_name}.png')
        if not os.path.exists(sr_img_path):
            print(f'File not exist: {sr_img_path}')
            continue
        gt_img = load_img(gt_img_path, target_size=None)
        # Resize the SR image to the GT's (H, W) so both tensors match.
        sr_img = load_img(sr_img_path, target_size=gt_img.shape[2:])
        scores.append(_score_to_float(iqa_metric(sr_img, gt_img)))
    if not scores:
        # statistics.mean raises on an empty sequence; report NaN instead so
        # the remaining metrics and the Excel update can still run.
        print(f'No image pair could be scored in: {sr_dir}')
        return float('nan')
    return float(mean(scores))


def _write_results(excel_path, exp_name, data_name, metric_list, iqa_result):
    """Insert or update the row for ``exp_name`` in the Excel sheet."""
    if os.path.exists(excel_path):
        df = pd.read_excel(excel_path)
    else:
        df = pd.DataFrame(columns=['exp'])
    exp_name = int(exp_name)
    if exp_name in df['exp'].to_list():
        # Reuse the first existing row of this experiment.
        row = df[df['exp'] == exp_name].index.tolist()[0]
    else:
        row = len(df.index)
        df.loc[row, 'exp'] = exp_name
    for metric in metric_list:
        df_metric = f'{data_name}-{metric}'
        if df_metric not in df.columns.tolist():
            df[df_metric] = ''
        df.loc[row, df_metric] = iqa_result[metric]
    df.sort_values(by='exp', inplace=True)
    df.to_excel(excel_path, startcol=0, index=False)


def eval_img_IQA(gt_dir, sr_dir, excel_path, metric_list, exp_name, data_name):
    """Score the SR images in ``sr_dir`` against ``gt_dir`` and log to Excel.

    Args:
        gt_dir: Directory with the ground-truth images.
        sr_dir: Directory with the super-resolved images; each one is looked
            up as ``<ground-truth stem>.png``.
        excel_path: Spreadsheet to create or update. One row per experiment,
            one column per ``<data_name>-<metric>``.
        metric_list: Keys of ``metric_dict`` to evaluate.
        exp_name: Experiment identifier; converted with ``int()`` and stored
            in the ``exp`` column.
        data_name: Dataset label used as the column-name prefix.

    'fid' is computed once on the two directories; every other metric is
    computed per image pair and averaged. If no pair can be scored, NaN is
    recorded for that metric.
    """
    gt_img_list = os.listdir(gt_dir)
    iqa_result = {}
    for metric in metric_list:
        iqa_metric = metric_dict[metric].to(device)
        if metric == 'fid':
            # FID compares the two image folders as a whole.
            value = float(iqa_metric(sr_dir, gt_dir))
        else:
            value = _mean_pairwise_score(iqa_metric, gt_dir, sr_dir, gt_img_list)
        iqa_result[metric] = value
        print(f'{metric}: {value}')
    _write_results(excel_path, exp_name, data_name, metric_list, iqa_result)
def main():
    """Evaluate every configured model on every configured benchmark.

    For each model a single spreadsheet is kept inside the model's checkpoint
    directory; each benchmark contributes its own set of metric columns to
    the row of the evaluated training step.
    """
    step = 400000
    suffix = ''
    checkpoints_root = '/home/ma-user/work/code/SRDiff-main/checkpoints'
    models = ['diffsr_df2k4x_sam-pl_qs-zero']
    metrics = ['psnr-Y', 'ssim', 'fid']
    benchmarks = ['test_Set5', 'test_Set14', 'test_Urban100', 'test_Manga109', 'test_BSDS100']
    prefix_len = len('test_')

    for model_type in models:
        # NOTE(review): writing '.xls' through DataFrame.to_excel depends on
        # an Excel engine that supports the legacy format; confirm for the
        # installed pandas version.
        report_path = join(checkpoints_root, model_type, f'IQA-val-{model_type}.xls')
        for benchmark_name in benchmarks:
            result_dir = join(
                checkpoints_root,
                f'{model_type}/results_{step}_{suffix}/benchmark/{benchmark_name}',
            )
            eval_img_IQA(
                join(result_dir, 'HR'),
                join(result_dir, 'SR'),
                report_path,
                metrics,
                step,
                # Drop the leading 'test_' to get the dataset label.
                benchmark_name[prefix_len:],
            )


if __name__ == '__main__':
    main()