|
import numpy as np |
|
import pytest |
|
import torch |
|
from PIL import Image |
|
|
|
from pytorch_fid import fid_score, inception |
|
|
|
|
|
@pytest.fixture |
|
def device(): |
|
return torch.device("cpu") |
|
|
|
|
|
def test_calculate_fid_given_statistics(mocker, tmp_path, device): |
|
dim = 2048 |
|
m1, m2 = np.zeros((dim,)), np.ones((dim,)) |
|
sigma = np.eye(dim) |
|
|
|
def dummy_statistics(path, model, batch_size, dims, device, num_workers): |
|
if path.endswith("1"): |
|
return m1, sigma |
|
elif path.endswith("2"): |
|
return m2, sigma |
|
else: |
|
raise ValueError |
|
|
|
mocker.patch( |
|
"pytorch_fid.fid_score.compute_statistics_of_path", side_effect=dummy_statistics |
|
) |
|
|
|
dir_names = ["1", "2"] |
|
paths = [] |
|
for name in dir_names: |
|
path = tmp_path / name |
|
path.mkdir() |
|
paths.append(str(path)) |
|
|
|
fid_value = fid_score.calculate_fid_given_paths( |
|
paths, batch_size=dim, device=device, dims=dim, num_workers=0 |
|
) |
|
|
|
|
|
assert fid_value == np.sum((m1 - m2) ** 2) |
|
|
|
|
|
def test_compute_statistics_of_path(mocker, tmp_path, device): |
|
model = mocker.MagicMock(inception.InceptionV3)() |
|
model.side_effect = lambda inp: [inp.mean(dim=(2, 3), keepdim=True)] |
|
|
|
size = (4, 4, 3) |
|
arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)] |
|
images = [(arr * 255).astype(np.uint8) for arr in arrays] |
|
|
|
paths = [] |
|
for idx, image in enumerate(images): |
|
paths.append(str(tmp_path / "{}.png".format(idx))) |
|
Image.fromarray(image, mode="RGB").save(paths[-1]) |
|
|
|
stats = fid_score.compute_statistics_of_path( |
|
str(tmp_path), |
|
model, |
|
batch_size=len(images), |
|
dims=3, |
|
device=device, |
|
num_workers=0, |
|
) |
|
|
|
assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3) |
|
assert np.allclose(stats[1], np.ones((3, 3)) * 0.25) |
|
|
|
|
|
def test_compute_statistics_of_path_from_file(mocker, tmp_path, device): |
|
model = mocker.MagicMock(inception.InceptionV3)() |
|
|
|
mu = np.random.randn(5) |
|
sigma = np.random.randn(5, 5) |
|
|
|
path = tmp_path / "stats.npz" |
|
with path.open("wb") as f: |
|
np.savez(f, mu=mu, sigma=sigma) |
|
|
|
stats = fid_score.compute_statistics_of_path( |
|
str(path), model, batch_size=1, dims=5, device=device, num_workers=0 |
|
) |
|
|
|
assert np.allclose(stats[0], mu) |
|
assert np.allclose(stats[1], sigma) |
|
|
|
|
|
def test_image_types(tmp_path): |
|
in_arr = np.ones((24, 24, 3), dtype=np.uint8) * 255 |
|
in_image = Image.fromarray(in_arr, mode="RGB") |
|
|
|
paths = [] |
|
for ext in fid_score.IMAGE_EXTENSIONS: |
|
paths.append(str(tmp_path / "img.{}".format(ext))) |
|
in_image.save(paths[-1]) |
|
|
|
dataset = fid_score.ImagePathDataset(paths) |
|
|
|
for img in dataset: |
|
assert np.allclose(np.array(img), in_arr) |
|
|