Spaces:
Runtime error
Runtime error
File size: 4,509 Bytes
193c713 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
import ssl
from os.path import join
from pathlib import Path
from statistics import mean
# Absolute (string) path of the directory two levels above this file; the
# shared download cache is created as its "cache" subdirectory.
parent_path = Path(__file__).absolute().parent.parent
parent_path = os.path.abspath(parent_path)

# NOTE(review): the next two lines are meant to turn off TLS certificate
# verification for the whole process: an empty CA bundle for HTTP clients that
# honor CURL_CA_BUNDLE, and an unverified default HTTPS context for
# urllib/http.client. Presumably a workaround for a proxy or self-signed
# certificate when downloading model weights. This is insecure on untrusted
# networks; remove it wherever certificates validate normally.
os.environ["CURL_CA_BUNDLE"] = ""
ssl._create_default_https_context = ssl._create_unverified_context

# Redirect the HF datasets / transformers / torch caches to one shared
# directory. These are set before the third-party imports further down,
# presumably because those libraries may read them at import time.
cache_path = os.path.join(parent_path, 'cache')
os.environ["HF_DATASETS_CACHE"] = cache_path
os.environ["TRANSFORMERS_CACHE"] = cache_path
# Environment variable names are case-sensitive on POSIX. torch reads
# TORCH_HOME, so the former "torch_HOME" spelling had no effect there.
os.environ["TORCH_HOME"] = cache_path
import PIL
import numpy as np
import pandas as pd
import pyiqa
import torch
from PIL import Image
from tqdm import tqdm
# Run metrics on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Metric name -> pyiqa metric instance. The names are the strings accepted in a
# `metric_list`. Every metric is instantiated eagerly when this module is imported.
metric_dict = {}
# PSNR restricted to the Y channel of the YCbCr color space.
metric_dict['psnr-Y'] = pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr')
metric_dict['ssim'] = pyiqa.create_metric('ssim', color_space='ycbcr')
metric_dict['fid'] = pyiqa.create_metric('fid')
def load_img(path, target_size=None):
    """Load an image file as a float32 tensor of shape (1, 3, H, W) scaled to [0, 1].

    Args:
        path: Path of any PIL-readable image; it is converted to RGB.
        target_size: Optional (height, width). When truthy, the image is
            resized to exactly that size with Lanczos resampling.

    Returns:
        A float32 ``torch.Tensor`` in NCHW layout with a batch dimension of 1.
    """
    # Open through the context manager so the file handle is closed
    # deterministically; convert() returns an independent, fully loaded copy.
    with Image.open(path) as src:
        image = src.convert("RGB")
    if target_size:
        h, w = target_size
        # PIL's resize expects (width, height), the reverse of the (H, W) argument.
        image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    array = np.array(image).astype(np.float32) / 255.0
    # HWC -> NCHW: add a leading batch axis, then move channels before spatial dims.
    array = array[None].transpose(0, 3, 1, 2)
    return torch.from_numpy(array)
def _score_to_float(score):
    """Convert a metric output (a single-element tensor or a number) to a Python float."""
    if isinstance(score, torch.Tensor):
        # .item() accepts any single-element tensor: shape (), (1,), (1, 1), ...
        return float(score.item())
    return float(score)


def _mean_paired_score(iqa_metric, gt_dir, sr_dir, gt_img_list):
    """Average a full-reference metric over GT/SR image pairs; NaN if no pair was scored."""
    scores = []
    for img_name in tqdm(gt_img_list):
        # The SR counterpart is named after the GT file up to its first '.', as PNG.
        base_name = img_name.split('.')[0]
        sr_img_path = join(sr_dir, f'{base_name}.png')
        if not os.path.exists(sr_img_path):
            print(f'File not exist: {sr_img_path}')
            continue
        gt_img = load_img(join(gt_dir, img_name), target_size=None)
        # Resize SR to the GT's (H, W) so the metric compares equally sized images.
        sr_img = load_img(sr_img_path, target_size=gt_img.shape[2:])
        scores.append(_score_to_float(iqa_metric(sr_img, gt_img)))
    if not scores:
        # statistics.mean([]) would raise and abort the whole evaluation run.
        print(f'No SR/GT image pairs scored for {sr_dir}; recording NaN')
        return float('nan')
    return float(mean(scores))


def _write_results_to_excel(excel_path, exp_name, data_name, metric_list, iqa_result):
    """Upsert the row whose 'exp' equals int(exp_name) with '<data_name>-<metric>' columns, then save."""
    if os.path.exists(excel_path):
        df = pd.read_excel(excel_path)
    else:
        df = pd.DataFrame(columns=['exp'])
    exp_id = int(exp_name)
    existing_rows = df.index[df['exp'] == exp_id].tolist()
    if existing_rows:
        row = existing_rows[0]
    else:
        # Append a new row using the next positional label.
        row = len(df.index)
        df.loc[row, 'exp'] = exp_id
    for metric in metric_list:
        column = f'{data_name}-{metric}'
        if column not in df.columns:
            df[column] = ''
        df.loc[row, column] = iqa_result[metric]
    df.sort_values(by='exp', inplace=True)
    df.to_excel(excel_path, startcol=0, index=False)


def eval_img_IQA(gt_dir, sr_dir, excel_path, metric_list, exp_name, data_name):
    """Score SR images against GT images and record the scores in an Excel sheet.

    For each name in ``metric_list`` (a key of the module-level ``metric_dict``):
      * 'fid' is computed once by passing the two directories to the metric;
      * any other metric is averaged over GT/SR pairs. The SR file is
        ``<GT name up to its first '.'>.png`` in ``sr_dir``, resized to the GT
        size. GT images without an SR counterpart are skipped with a message,
        and the metric is recorded as NaN if no pair could be scored.

    Args:
        gt_dir: Directory of ground-truth images.
        sr_dir: Directory of super-resolved images.
        excel_path: Workbook to create or update.
        metric_list: Metric names to compute.
        exp_name: Experiment identifier; converted with ``int()`` and used as
            the value of the 'exp' column that keys the row.
        data_name: Dataset label used as the column prefix.
    """
    gt_img_list = os.listdir(gt_dir)
    iqa_result = {}
    for metric in metric_list:
        iqa_metric = metric_dict[metric].to(device)
        if metric == 'fid':
            # FID is called with the two directories, not per image pair.
            score = _score_to_float(iqa_metric(sr_dir, gt_dir))
        else:
            score = _mean_paired_score(iqa_metric, gt_dir, sr_dir, gt_img_list)
        iqa_result[metric] = score
        print(f'{metric}: {score}')
    _write_results_to_excel(excel_path, exp_name, data_name, metric_list, iqa_result)
def main():
    """Evaluate each configured model on each benchmark and log the scores to Excel.

    Each model type gets one workbook under its checkpoint directory. Each
    benchmark adds '<benchmark>-<metric>' columns to the row whose 'exp' value
    is ``epoch``.
    """
    epoch = 400000
    add_name = ''
    exp_root = '/home/ma-user/work/code/SRDiff-main/checkpoints'
    model_type_list = ['diffsr_df2k4x_sam-pl_qs-zero']
    metric_list = ['psnr-Y', 'ssim', 'fid']
    benchmark_name_list = ['test_Set5', 'test_Set14', 'test_Urban100', 'test_Manga109', 'test_BSDS100']
    benchmark_prefix = 'test_'
    for model_type in model_type_list:
        # Write .xlsx rather than .xls: pandas 2.0 removed the xlwt engine, so
        # DataFrame.to_excel can no longer write legacy .xls workbooks.
        excel_path = join(exp_root, model_type, f'IQA-val-{model_type}.xlsx')
        for benchmark_name in benchmark_name_list:
            exp_dir = join(exp_root, model_type, f'results_{epoch}_{add_name}', 'benchmark', benchmark_name)
            gt_img_dir = join(exp_dir, 'HR')
            sr_img_dir = join(exp_dir, 'SR')
            # Column prefix is the benchmark name without "test_", e.g. "test_Set5" -> "Set5".
            if benchmark_name.startswith(benchmark_prefix):
                data_name = benchmark_name[len(benchmark_prefix):]
            else:
                data_name = benchmark_name
            eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name)


if __name__ == '__main__':
    main()
|