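"""Compute FID and CLIP-FID between a source (real) image dataset and the
outputs of one or more methods, and save the scores as a CSV table.

Illustrative invocation (the script name and paths below are examples, not
part of the repository):

    python compute_metrics.py \
        --source_dataset data/celeba \
        --methods_dataset methodA,results/methodA methodB,results/methodB \
        --output logs/metric.csv
"""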
import argparse
import os
import sys
from pathlib import Path

import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics.image.fid import FrechetInceptionDistance
from tqdm.auto import tqdm

# Make the repository root importable so the local packages below resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from models.Encoders import ClipModel
from utils.seed import set_seed
from utils.train import parallel_load_images
from utils.image_utils import list_image_files


def name_path(pair):
    """Parse a 'name,path' argument pair, e.g. 'methodA,results/methodA'."""
    name, path = pair.split(',', 1)  # split on the first comma only, so paths may contain commas
    return name, Path(path)


@torch.inference_mode()
def compute_fid_datasets(datasets, target='celeba', device=None, CLIP=False, seed=3407):
    """Compute FID between the target (real) dataset and every other entry in `datasets`.

    Each value in `datasets` is an image tensor of shape (N, 3, H, W);
    with normalize=False, torchmetrics expects uint8 pixel values in [0, 255].
    """
    set_seed(seed)
    if device is None:  # fall back to CPU when no GPU is available
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    result = {}

    # CLIP-FID swaps the default InceptionV3 feature extractor for a CLIP image encoder.
    if CLIP:
        fid = FrechetInceptionDistance(feature=ClipModel(), reset_real_features=False, normalize=False)
    else:
        fid = FrechetInceptionDistance(reset_real_features=False, normalize=False)
    fid.to(device).eval()

    # Accumulate the real-image statistics once; reset_real_features=False keeps
    # them intact across the fid.reset() calls below.
    real_dataloader = DataLoader(TensorDataset(datasets[target]), batch_size=128)
    for batch in tqdm(real_dataloader, desc=f'real: {target}'):
        fid.update(batch[0].to(device), real=True)

    for key, tensor in datasets.items():
        if key == target:
            continue
        fid.reset()  # clears only the fake-image statistics

        fake_dataloader = DataLoader(TensorDataset(tensor), batch_size=128)
        for batch in tqdm(fake_dataloader, desc=f'fake: {key}'):
            fid.update(batch[0].to(device), real=False)
        result[key] = fid.compute().item()
    return result


def main(args):
    datasets = {}

    # Load the real (source) images, keyed by the dataset's directory name.
    source = args.source_dataset.name
    datasets[source] = parallel_load_images(args.source_dataset, list_image_files(args.source_dataset))

    # Load each method's generated images, keyed by the method name.
    for method, path_dataset in args.methods_dataset:
        datasets[method] = parallel_load_images(path_dataset, list_image_files(path_dataset))

    # Standard FID on InceptionV3 features.
    FIDs = compute_fid_datasets(datasets, target=source, CLIP=False)
    df_fid = pd.DataFrame.from_dict(FIDs, orient='index', columns=['FID'])

    # CLIP-FID on CLIP image-encoder features.
    FIDs_CLIP = compute_fid_datasets(datasets, target=source, CLIP=True)
    df_clip = pd.DataFrame.from_dict(FIDs_CLIP, orient='index', columns=['FID_CLIP'])

    df_result = pd.concat([df_fid, df_clip], axis=1).round(2)
    print(df_result)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    df_result.to_csv(args.output, index=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Compute FID and CLIP-FID metrics')
    parser.add_argument('--source_dataset', type=Path, required=True, help='Directory with real face images')
    parser.add_argument('--methods_dataset', type=name_path, nargs='+', required=True,
                        help='One or more name,path pairs of image directories produced by each method')
    parser.add_argument('--output', type=Path, default=Path('logs/metric.csv'), help='Path to the output CSV file')
    args = parser.parse_args()

    main(args)