|
__version__ = '0.3.0'

import torch
from einops import rearrange, repeat

from .inception import InceptionV3
from .fid_score import calculate_frechet_distance

|
class PytorchFIDFactory(torch.nn.Module):
    """Computes the Frechet Inception Distance (FID) between real and generated samples.

    Args:
        channels: Number of channels in the input images. Grayscale (1-channel)
            inputs are repeated to 3 channels before being passed to InceptionV3.
        inception_block_idx: Feature dimensionality of the InceptionV3 block whose
            activations are used; must be a key of InceptionV3.BLOCK_INDEX_BY_DIM
            (the default 2048 selects the final pooling features).

    Examples:
        >>> fid_factory = PytorchFIDFactory()
        >>> fid_score = fid_factory.score(real_samples=data, fake_samples=all_images)
        >>> print(fid_score)
    """
|
    def __init__(self, channels: int = 3, inception_block_idx: int = 2048):
        super().__init__()
        self.channels = channels

        assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
        self.inception_v3 = InceptionV3([block_idx])
|
    @torch.no_grad()
    def calculate_activation_statistics(self, samples):
        # Inception activations for the selected block, shape (batch, dims, 1, 1)
        features = self.inception_v3(samples)[0]
        features = rearrange(features, '... 1 1 -> ...')

        mu = torch.mean(features, dim=0).cpu()
        # torch.cov expects variables in rows, so transpose the (batch, dims)
        # features to obtain a (dims, dims) covariance matrix.
        sigma = torch.cov(features.t()).cpu()
        return mu, sigma
|
    def score(self, real_samples, fake_samples):
        # InceptionV3 expects 3-channel input, so repeat grayscale images along the channel dim.
        if self.channels == 1:
            real_samples, fake_samples = map(
                lambda t: repeat(t, 'b 1 ... -> b c ...', c=3), (real_samples, fake_samples)
            )

        # Compare the same number of samples from both sets.
        min_batch = min(real_samples.shape[0], fake_samples.shape[0])
        real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))

        m1, s1 = self.calculate_activation_statistics(real_samples)
        m2, s2 = self.calculate_activation_statistics(fake_samples)

        # FID = ||m1 - m2||^2 + Tr(s1 + s2 - 2 * sqrt(s1 @ s2))
        fid_value = calculate_frechet_distance(m1, s1, m2, s2)
        return fid_value
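

# A minimal usage sketch, not part of the library API: it assumes the InceptionV3
# weights can be downloaded on first use, and the random tensors below are purely
# hypothetical stand-ins for batches of real and generated images in [0, 1].
# Real evaluations should use many more samples than this for a stable estimate.
if __name__ == '__main__':
    fid_factory = PytorchFIDFactory(channels=3, inception_block_idx=2048)

    # Placeholder batches; replace with real data and model samples.
    real_batch = torch.rand(8, 3, 299, 299)
    fake_batch = torch.rand(8, 3, 299, 299)

    fid = fid_factory.score(real_samples=real_batch, fake_samples=fake_batch)
    print(f'FID: {fid:.4f}')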
|
|