__version__ = '0.3.0'

import torch
from einops import rearrange, repeat

from .inception import InceptionV3
from .fid_score import calculate_frechet_distance


class PytorchFIDFactory(torch.nn.Module):
    """

   Args:
       channels:
       inception_block_idx:

    Examples:
    >>> fid_factory =  PytorchFIDFactory()
    >>> fid_score = fid_factory.score(real_samples=data, fake_samples=all_images)
    >>> print(fid_score)
   """

    def __init__(self, channels: int = 3, inception_block_idx: int = 2048):
        super().__init__()
        self.channels = channels

        # load the InceptionV3 feature extractor for the requested block
        assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM, \
            f'inception_block_idx must be one of {tuple(InceptionV3.BLOCK_INDEX_BY_DIM)}'
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
        self.inception_v3 = InceptionV3([block_idx])
        self.inception_v3.eval()  # fixed BatchNorm statistics for deterministic features

    @torch.no_grad()
    def calculate_activation_statistics(self, samples):
        # InceptionV3 activations of the chosen block, squeezed from (b, d, 1, 1) to (b, d)
        features = self.inception_v3(samples)[0]
        features = rearrange(features, '... 1 1 -> ...')

        # torch.cov treats rows as variables, so the (batch, dim) feature matrix is
        # transposed to obtain a (dim, dim) covariance. The statistics are returned
        # as NumPy arrays, which the pytorch-fid style calculate_frechet_distance
        # helper operates on.
        mu = torch.mean(features, dim=0).cpu().numpy()
        sigma = torch.cov(features.t()).cpu().numpy()
        return mu, sigma

    def score(self, real_samples, fake_samples):
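        # InceptionV3 expects 3-channel inputs; tile grayscale images across the channel dim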
        if self.channels == 1:
            real_samples, fake_samples = map(
                lambda t: repeat(t, 'b 1 ... -> b c ...', c=3), (real_samples, fake_samples)
            )

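        # truncate both sets to a common size so the two Gaussians are fit on equal sample counts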
        min_batch = min(real_samples.shape[0], fake_samples.shape[0])
        real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))

        m1, s1 = self.calculate_activation_statistics(real_samples)
        m2, s2 = self.calculate_activation_statistics(fake_samples)

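        # Frechet distance between the two fitted Gaussians:
        #   FID = ||m1 - m2||^2 + Tr(s1 + s2 - 2 * sqrtm(s1 @ s2))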
        fid_value = calculate_frechet_distance(m1, s1, m2, s2)
        return fid_value
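

# Minimal usage sketch (illustrative only, not part of the original module).
# It runs the activation-statistics step on a tiny random batch just to show the
# expected shapes; a meaningful FID requires thousands of samples per
# distribution, since rank-deficient covariances make the matrix square root in
# calculate_frechet_distance unreliable.
if __name__ == '__main__':
    fid_factory = PytorchFIDFactory(channels=3)

    mu, sigma = fid_factory.calculate_activation_statistics(torch.rand(4, 3, 64, 64))
    print(mu.shape, sigma.shape)  # (2048,) (2048, 2048)

    # With real data (images as float tensors in [0, 1]):
    #     fid = fid_factory.score(real_samples=real_images, fake_samples=generated_images)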