# newTryOn/scripts/fid_metric.py

import argparse
import os
import sys
from pathlib import Path

import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics.image.fid import FrechetInceptionDistance
from tqdm.auto import tqdm
# Make the repository root importable so the local packages below resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from models.Encoders import ClipModel
from utils.seed import set_seed
from utils.train import parallel_load_images
from utils.image_utils import list_image_files


def name_path(pair):
    """Parse a 'name,path' command-line argument into a (name, Path) tuple."""
    name, path = pair.split(',')
    return name, Path(path)


@torch.inference_mode()
def compute_fid_datasets(datasets, target='celeba', device=torch.device('cuda'), CLIP=False, seed=3407):
    """Compute FID between the `target` (real) dataset and every other dataset in `datasets`."""
    set_seed(seed)
    result = {}

    # Use CLIP features instead of the default InceptionV3 features when requested.
    if CLIP:
        fid = FrechetInceptionDistance(feature=ClipModel(), reset_real_features=False, normalize=False)
    else:
        fid = FrechetInceptionDistance(reset_real_features=False, normalize=False)
    fid.to(device).eval()

    # Accumulate the real-image statistics once; they are preserved across resets.
    real_dataloader = DataLoader(TensorDataset(datasets[target]), batch_size=128)
    for batch in tqdm(real_dataloader):
        batch = batch[0].to(device)
        fid.update(batch, real=True)

    # Compare every generated dataset against the real statistics.
    for key, tensor in datasets.items():
        if key == target:
            continue
        fid.reset()  # clears only the fake statistics (reset_real_features=False)
        fake_dataloader = DataLoader(TensorDataset(tensor), batch_size=128)
        for batch in tqdm(fake_dataloader):
            batch = batch[0].to(device)
            fid.update(batch, real=False)
        result[key] = fid.compute().item()

    return result


def main(args):
    # Load the real (source) images and each method's generated images as tensors.
    datasets = {}
    source = args.source_dataset.name
    datasets[source] = parallel_load_images(args.source_dataset, list_image_files(args.source_dataset))
    for method, path_dataset in args.methods_dataset:
        datasets[method] = parallel_load_images(path_dataset, list_image_files(path_dataset))

    # Compute standard FID (InceptionV3 features) and CLIP-based FID.
    FIDs = compute_fid_datasets(datasets, target=source, CLIP=False)
    df_fid = pd.DataFrame.from_dict(FIDs, orient='index', columns=['FID'])
    FIDs_CLIP = compute_fid_datasets(datasets, target=source, CLIP=True)
    df_clip = pd.DataFrame.from_dict(FIDs_CLIP, orient='index', columns=['FID_CLIP'])

    df_result = pd.concat([df_fid, df_clip], axis=1).round(2)
    print(df_result)

    os.makedirs(args.output.parent, exist_ok=True)
    df_result.to_csv(args.output, index=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Compute FID metrics')
    parser.add_argument('--source_dataset', type=Path, help='Dataset with real faces')
    parser.add_argument('--methods_dataset', type=name_path, nargs='+', help='Datasets produced by each method, given as name,path pairs')
    parser.add_argument('--output', type=Path, default=Path('logs/metric.csv'), help='Path to the output CSV file')
    args = parser.parse_args()
    main(args)
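
# Example invocation (sketch only; the dataset paths and method names below are
# illustrative placeholders, not paths taken from the repository):
#
#   python scripts/fid_metric.py \
#       --source_dataset data/celeba \
#       --methods_dataset ours,results/ours baseline,results/baseline \
#       --output logs/metric.csv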