Spaces:
Build error
Build error
File size: 2,803 Bytes
6d314be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import argparse
import os
import sys
from pathlib import Path
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics.image.fid import FrechetInceptionDistance
from tqdm.auto import tqdm
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from models.Encoders import ClipModel
from utils.seed import set_seed
from utils.train import parallel_load_images
from utils.image_utils import list_image_files
def name_path(pair):
    """Parse a ``NAME,PATH`` command-line value into ``(name, Path(path))``.

    Only the first comma separates the two parts, so the path itself may
    contain commas. Used as the argparse ``type=`` for ``--methods_dataset``.

    Args:
        pair: Raw command-line string of the form ``NAME,PATH``.

    Returns:
        Tuple ``(name, Path(path))``.

    Raises:
        argparse.ArgumentTypeError: If ``pair`` has no comma or either part is
            empty; argparse turns this into a readable usage error.
    """
    # partition() splits on the first comma only, unlike split(',') which
    # fails to unpack when the path contains a comma.
    name, sep, path = pair.partition(',')
    if not sep or not name or not path:
        raise argparse.ArgumentTypeError(f"expected NAME,PATH but got {pair!r}")
    return name, Path(path)
@torch.inference_mode()
def compute_fid_datasets(datasets, target='celeba', device=None, CLIP=False, seed=3407, batch_size=128):
    """Compute the FID of every dataset in ``datasets`` against ``datasets[target]``.

    Real-image statistics are accumulated once from the target dataset and kept
    across resets (``reset_real_features=False``). For each other entry, only
    the fake-image statistics are reset, refilled and compared.

    Args:
        datasets: Mapping from dataset name to an image tensor. Each tensor is
            wrapped in a ``TensorDataset`` and batched along its first dimension.
            The metric is built with ``normalize=False``, under which the default
            Inception extractor expects uint8 images in [0, 255]. This is
            presumably what the loader produces (TODO confirm).
        target: Key of the reference (real) dataset in ``datasets``.
        device: Device to run the metric on. ``None`` (default) selects CUDA
            when available and falls back to CPU otherwise.
        CLIP: If True, extract features with ``ClipModel`` instead of the
            default InceptionV3 network.
        seed: Seed passed to ``set_seed`` before the metric is built.
        batch_size: Number of images per ``fid.update`` call.

    Returns:
        dict mapping every non-target dataset name to its FID as a Python float.

    Raises:
        KeyError: If ``target`` is not a key of ``datasets``.
    """
    if device is None:
        # The previous hard-coded CUDA default crashed on CPU-only machines.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    set_seed(seed)

    # reset_real_features=False lets the real statistics survive fid.reset(),
    # so the target set is only processed once.
    if CLIP:
        fid = FrechetInceptionDistance(feature=ClipModel(), reset_real_features=False, normalize=False)
    else:
        fid = FrechetInceptionDistance(reset_real_features=False, normalize=False)
    fid.to(device).eval()

    def _update(images, real, desc):
        """Stream ``images`` through the metric in batches of ``batch_size``."""
        loader = DataLoader(TensorDataset(images), batch_size=batch_size)
        for (batch,) in tqdm(loader, desc=desc):
            fid.update(batch.to(device), real=real)

    _update(datasets[target], real=True, desc=f'real: {target}')

    result = {}
    for key, tensor in datasets.items():
        if key == target:
            continue
        fid.reset()  # clears only the fake statistics (reset_real_features=False)
        _update(tensor, real=False, desc=f'fake: {key}')
        result[key] = fid.compute().item()
    return result
def main(args):
    """Compute FID and CLIP-FID of each method's images against the source dataset.

    The source dataset is keyed by its directory name
    (``args.source_dataset.name``), and each method by the name given on the
    command line. The combined table is rounded to two decimals, printed, and
    written to ``args.output`` as CSV. Missing parent directories are created.

    Args:
        args: Parsed command-line namespace with ``source_dataset`` (Path),
            ``methods_dataset`` (list of ``(name, Path)`` pairs) and
            ``output`` (Path).

    Raises:
        ValueError: If a method name repeats or equals the source directory
            name. Such a name would overwrite an already-loaded dataset and
            silently drop it from the report.
    """
    source = args.source_dataset.name

    # Validate names before loading any images. Datasets are keyed by name, so
    # a clash would silently replace an entry, and a method named like the
    # source would even replace the reference set itself.
    seen = {source}
    for method, _ in args.methods_dataset:
        if method in seen:
            raise ValueError(
                f"Dataset name {method!r} is used more than once; method names must be "
                f"unique and differ from the source dataset name {source!r}."
            )
        seen.add(method)

    datasets = {source: parallel_load_images(args.source_dataset, list_image_files(args.source_dataset))}
    for method, path_dataset in args.methods_dataset:
        datasets[method] = parallel_load_images(path_dataset, list_image_files(path_dataset))

    fid_scores = compute_fid_datasets(datasets, target=source, CLIP=False)
    df_fid = pd.DataFrame.from_dict(fid_scores, orient='index', columns=['FID'])
    clip_scores = compute_fid_datasets(datasets, target=source, CLIP=True)
    df_clip = pd.DataFrame.from_dict(clip_scores, orient='index', columns=['FID_CLIP'])

    df_result = pd.concat([df_fid, df_clip], axis=1).round(2)
    print(df_result)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    df_result.to_csv(args.output, index=True)
if __name__ == '__main__':
    # Both dataset arguments are required. Without them, main() used to fail
    # later with an obscure AttributeError/TypeError on None.
    parser = argparse.ArgumentParser(description='Compute metrics')
    parser.add_argument('--source_dataset', type=Path, required=True,
                        help='Dataset with real faces')
    parser.add_argument('--methods_dataset', type=name_path, nargs='+', required=True,
                        metavar='NAME,PATH',
                        help='Datasets after applying the method, given as NAME,PATH pairs')
    parser.add_argument('--output', type=Path, default='logs/metric.csv',
                        help='CSV file for saving the metrics (default: %(default)s)')
    args = parser.parse_args()
    main(args)
|