Spaces: Running
antoinedelplace committed
Commit · 207ef6f · 0 Parent(s)
First commit
Browse files
- .gitignore +5 -0
- README.md +6 -0
- asp/data/__init__.py +98 -0
- asp/data/aligned_dataset.py +96 -0
- asp/data/base_dataset.py +230 -0
- asp/data/image_folder.py +66 -0
- asp/experiments/__init__.py +54 -0
- asp/experiments/__main__.py +87 -0
- asp/experiments/mist_launcher.py +66 -0
- asp/experiments/pretrained_launcher.py +61 -0
- asp/experiments/tmux_launcher.py +215 -0
- asp/models/__init__.py +67 -0
- asp/models/asp_loss.py +97 -0
- asp/models/base_model.py +258 -0
- asp/models/cpt_model.py +261 -0
- asp/models/cut_model.py +214 -0
- asp/models/gauss_pyramid.py +42 -0
- asp/models/networks.py +1422 -0
- asp/models/patchnce.py +55 -0
- asp/options/__init__.py +1 -0
- asp/options/base_options.py +167 -0
- asp/options/test_options.py +21 -0
- asp/options/train_options.py +44 -0
- asp/util/__init__.py +2 -0
- asp/util/fdlutil.py +422 -0
- asp/util/fid.py +288 -0
- asp/util/general_utils.py +73 -0
- asp/util/get_data.py +110 -0
- asp/util/html.py +86 -0
- asp/util/image_pool.py +54 -0
- asp/util/inception.py +328 -0
- asp/util/kid_score.py +450 -0
- asp/util/perceptual.py +347 -0
- asp/util/util.py +220 -0
- asp/util/visualizer.py +242 -0
- main.py +90 -0
- requirements.txt +6 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
# Byte-compiled
__pycache__/

# Environment
/venv/
README.md
ADDED
@@ -0,0 +1,6 @@
# H&E to IHC translation
Based on Adaptive Supervised PatchNCE Loss for Learning H&E-to-IHC Stain Translation with Inconsistent Groundtruth Image Pairs (MICCAI 2023)

Original folder: [lifangda01/AdaptiveSupervisedPatchNCE](https://github.com/lifangda01/AdaptiveSupervisedPatchNCE)

Original paper: [](https://arxiv.org/pdf/2303.06193)
asp/data/__init__.py
ADDED
@@ -0,0 +1,98 @@
"""This package includes all the modules related to data loading and preprocessing

To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point from data loader.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from asp.data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt)
    dataset = data_loader.load_data()
    return dataset


class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
            drop_last=True if opt.isTrain else False,
        )

    def set_epoch(self, epoch):
        self.dataset.current_epoch = epoch

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
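Not part of the commit: as a sketch of the naming contract that find_dataset_using_name enforces, a hypothetical data/dummy_dataset.py could look like the snippet below. The class name, file name, and placeholder data are all assumptions for illustration only.

# Hypothetical data/dummy_dataset.py -- illustration only, not part of this commit.
from asp.data.base_dataset import BaseDataset


class DummyDataset(BaseDataset):
    """Minimal dataset discovered via '--dataset_mode dummy'."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        # Optionally add dataset-specific flags here.
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.items = list(range(10))  # placeholder data

    def __getitem__(self, index):
        return {'A': self.items[index], 'A_paths': str(index)}

    def __len__(self):
        return len(self.items)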
asp/data/aligned_dataset.py
ADDED
@@ -0,0 +1,96 @@
import os.path
import numpy as np
import torch
import json

from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util


class AlignedDataset(BaseDataset):
    """
    This dataset class can load aligned/paired datasets.

    It requires two directories to host training images from domain A '/path/to/data/trainA'
    and from domain B '/path/to/data/trainB' respectively.
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    Similarly, you need to prepare two directories:
    '/path/to/data/testA' and '/path/to/data/testB' during test time.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')  # create a path '/path/to/data/trainA'
        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')  # create a path '/path/to/data/trainB'

        if opt.phase == "test" and not os.path.exists(self.dir_A) \
           and os.path.exists(os.path.join(opt.dataroot, "valA")):
            self.dir_A = os.path.join(opt.dataroot, "valA")
            self.dir_B = os.path.join(opt.dataroot, "valB")

        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))  # load images from '/path/to/data/trainA'
        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))  # load images from '/path/to/data/trainB'

        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B
        assert self.A_size == self.B_size

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor) -- an image in the input domain
            B (tensor) -- its corresponding image in the target domain
            A_paths (str) -- image paths
            B_paths (str) -- image paths
        """
        if self.opt.serial_batches:   # make sure index is within the range
            index_B = index % self.B_size
        else:   # randomize the index for domain B to avoid fixed pairs.
            index = random.randint(0, self.A_size - 1)
            index_B = index % self.B_size

        A_path = self.A_paths[index]  # make sure index is within the range
        B_path = self.B_paths[index_B]

        assert A_path == B_path.replace('trainB', 'trainA').replace('valB', 'valA').replace('testB', 'testA')

        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')

        # Apply image transformation
        # For CUT/FastCUT mode, if in finetuning phase (learning rate is decaying),
        # do not perform resize-crop data augmentation of CycleGAN.
        is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
        modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
        transform = get_transform(modified_opt)

        # FDL: synchronize transforms
        seed = np.random.randint(2147483647)  # make a seed with numpy generator
        random.seed(seed)  # apply this seed to img transforms
        torch.manual_seed(seed)  # needed for torchvision 0.7
        A = transform(A_img)
        random.seed(seed)  # apply this seed to target transforms
        torch.manual_seed(seed)  # needed for torchvision 0.7
        B = transform(B_img)

        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """Return the total number of images in the dataset.

        As we have two datasets with potentially different numbers of images,
        we take the maximum of the two.
        """
        return max(self.A_size, self.B_size)
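Not part of the commit: the seeding trick in __getitem__ above is what keeps the paired H&E/IHC patches aligned under random augmentation, since both images see the same crop and flip decisions. A minimal standalone sketch of the same idea, using stock torchvision transforms and a blank image as stand-ins:

# Illustration only: reseeding before each call makes a random transform
# repeat the exact same augmentation, so paired images stay aligned.
import random
import torch
import torchvision.transforms as T
from PIL import Image

transform = T.Compose([T.RandomCrop(8), T.RandomHorizontalFlip()])
img = Image.new('RGB', (16, 16))

seed = 12345
random.seed(seed); torch.manual_seed(seed)
a = transform(img)
random.seed(seed); torch.manual_seed(seed)
b = transform(img)
# a and b received the same crop position and the same flip decision.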
asp/data/base_dataset.py
ADDED
@@ -0,0 +1,230 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        self.root = opt.dataroot
        self.current_epoch = 0

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass


def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.preprocess == 'resize_and_crop':
        new_h = new_w = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        new_w = opt.load_size
        new_h = opt.load_size * h // w

    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))

    flip = random.random() > 0.5

    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'fixsize' in opt.preprocess:
        transform_list.append(transforms.Resize(params["size"], method))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        if "gta2cityscapes" in opt.dataroot:
            osize[0] = opt.load_size // 2
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
    elif 'scale_shortside' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))

    if 'zoom' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))
        else:
            transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"])))

    if 'crop' in opt.preprocess:
        if params is None or 'crop_pos' not in params:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if 'patch' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size)))

    if 'trim' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))

    # if opt.preprocess == 'none':
    transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None or 'flip' not in params:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif 'flip' in params:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if h == oh and w == ow:
        return img

    return img.resize((w, h), method)


def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
    if factor is None:
        zoom_level = np.random.uniform(0.8, 1.0, size=[2])
    else:
        zoom_level = (factor[0], factor[1])
    iw, ih = img.size
    zoomw = max(crop_width, iw * zoom_level[0])
    zoomh = max(crop_width, ih * zoom_level[1])
    img = img.resize((int(round(zoomw)), int(round(zoomh))), method)
    return img


def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
    ow, oh = img.size
    shortside = min(ow, oh)
    if shortside >= target_width:
        return img
    else:
        scale = target_width / shortside
        return img.resize((round(ow * scale), round(oh * scale)), method)


def __trim(img, trim_width):
    ow, oh = img.size
    if ow > trim_width:
        xstart = np.random.randint(ow - trim_width)
        xend = xstart + trim_width
    else:
        xstart = 0
        xend = ow
    if oh > trim_width:
        ystart = np.random.randint(oh - trim_width)
        yend = ystart + trim_width
    else:
        ystart = 0
        yend = oh
    return img.crop((xstart, ystart, xend, yend))


def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
    ow, oh = img.size
    if ow == target_width and oh >= crop_width:
        return img
    w = target_width
    h = int(max(target_width * oh / ow, crop_width))
    return img.resize((w, h), method)


def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __patch(img, index, size):
    ow, oh = img.size
    nw, nh = ow // size, oh // size
    roomx = ow - nw * size
    roomy = oh - nh * size
    startx = np.random.randint(int(roomx) + 1)
    starty = np.random.randint(int(roomy) + 1)

    index = index % (nw * nh)
    ix = index // nh
    iy = index % nh
    gridx = startx + ix * size
    gridy = starty + iy * size
    return img.crop((gridx, gridy, gridx + size, gridy + size))


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def __print_size_warning(ow, oh, w, h):
    """Print warning information about image size (only print once)"""
    if not hasattr(__print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        __print_size_warning.has_printed = True
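Not part of the commit: a minimal usage sketch of get_params/get_transform, assuming the asp package is importable and using a hand-built opt namespace that carries only the fields read above (all values are placeholders):

# Illustration only: drive get_transform with shared params so the same crop
# position and flip decision can be reused across paired images.
from argparse import Namespace
from PIL import Image
from asp.data.base_dataset import get_params, get_transform

opt = Namespace(preprocess='resize_and_crop', load_size=286, crop_size=256,
                no_flip=False, dataroot='./datasets/example')

params = get_params(opt, size=(512, 512))       # random crop position and flip flag
transform = get_transform(opt, params=params)   # deterministic given params

img = Image.new('RGB', (512, 512))
tensor = transform(img)                          # 3 x 256 x 256, normalized to [-1, 1]
print(tensor.shape)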
asp/data/image_folder.py
ADDED
@@ -0,0 +1,66 @@
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
asp/experiments/__init__.py
ADDED
@@ -0,0 +1,54 @@
import os
import importlib


def find_launcher_using_name(launcher_name):
    # cur_dir = os.path.dirname(os.path.abspath(__file__))
    # pythonfiles = glob.glob(cur_dir + '/**/*.py')
    launcher_filename = "experiments.{}_launcher".format(launcher_name)
    launcherlib = importlib.import_module(launcher_filename)

    # In the file, the class called LauncherNameLauncher() will
    # be instantiated. It has to be a subclass of BaseLauncher,
    # and it is case-insensitive.
    launcher = None
    target_launcher_name = launcher_name.replace('_', '') + 'launcher'
    for name, cls in launcherlib.__dict__.items():
        if name.lower() == target_launcher_name.lower():
            launcher = cls

    if launcher is None:
        raise ValueError("In %s.py, there should be a subclass of BaseLauncher "
                         "with class name that matches %s in lowercase." %
                         (launcher_filename, target_launcher_name))

    return launcher


if __name__ == "__main__":
    import sys
    import pickle

    assert len(sys.argv) >= 3

    name = sys.argv[1]
    Launcher = find_launcher_using_name(name)

    cache = "/tmp/tmux_launcher/{}".format(name)
    if os.path.isfile(cache):
        instance = pickle.load(open(cache, 'r'))
    else:
        instance = Launcher()

    cmd = sys.argv[2]
    if cmd == "launch":
        instance.launch()
    elif cmd == "stop":
        instance.stop()
    elif cmd == "send":
        expid = int(sys.argv[3])
        cmd = int(sys.argv[4])
        instance.send_command(expid, cmd)

    os.makedirs("/tmp/tmux_launcher/", exist_ok=True)
    pickle.dump(instance, open(cache, 'w'))
asp/experiments/__main__.py
ADDED
@@ -0,0 +1,87 @@
import os
import importlib


def find_launcher_using_name(launcher_name):
    # cur_dir = os.path.dirname(os.path.abspath(__file__))
    # pythonfiles = glob.glob(cur_dir + '/**/*.py')
    launcher_filename = "experiments.{}_launcher".format(launcher_name)
    launcherlib = importlib.import_module(launcher_filename)

    # In the file, the class called LauncherNameLauncher() will
    # be instantiated. It has to be a subclass of BaseLauncher,
    # and it is case-insensitive.
    launcher = None
    # target_launcher_name = launcher_name.replace('_', '') + 'launcher'
    for name, cls in launcherlib.__dict__.items():
        if name.lower() == "launcher":
            launcher = cls

    if launcher is None:
        raise ValueError("In %s.py, there should be a class named Launcher")

    return launcher


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('name')
    parser.add_argument('cmd')
    parser.add_argument('id', nargs='+', type=str)
    parser.add_argument('--mode', default=None)
    parser.add_argument('--which_epoch', default=None)
    parser.add_argument('--continue_train', action='store_true')
    parser.add_argument('--subdir', default='')
    parser.add_argument('--title', default='')
    parser.add_argument('--gpu_id', default=None, type=int)
    parser.add_argument('--phase', default='test')

    opt = parser.parse_args()

    name = opt.name
    Launcher = find_launcher_using_name(name)

    instance = Launcher()

    cmd = opt.cmd
    ids = 'all' if 'all' in opt.id else [int(i) for i in opt.id]
    if cmd == "launch":
        instance.launch(ids, continue_train=opt.continue_train)
    elif cmd == "stop":
        instance.stop()
    elif cmd == "send":
        assert False
    elif cmd == "close":
        instance.close()
    elif cmd == "dry":
        instance.dry()
    elif cmd == "relaunch":
        instance.close()
        instance.launch(ids, continue_train=opt.continue_train)
    elif cmd == "run" or cmd == "train":
        assert len(ids) == 1, '%s is invalid for run command' % (' '.join(opt.id))
        expid = ids[0]
        instance.run_command(instance.commands(), expid,
                             continue_train=opt.continue_train,
                             gpu_id=opt.gpu_id)
    elif cmd == 'launch_test':
        instance.launch(ids, test=True)
    elif cmd == "run_test" or cmd == "test":
        test_commands = instance.test_commands()
        if ids == "all":
            ids = list(range(len(test_commands)))
        for expid in ids:
            instance.run_command(test_commands, expid, opt.which_epoch,
                                 gpu_id=opt.gpu_id)
            if expid < len(ids) - 1:
                os.system("sleep 5s")
    elif cmd == "print_names":
        instance.print_names(ids, test=False)
    elif cmd == "print_test_names":
        instance.print_names(ids, test=True)
    elif cmd == "create_comparison_html":
        instance.create_comparison_html(name, ids, opt.subdir, opt.title, opt.phase)
    else:
        raise ValueError("Command not recognized")
asp/experiments/mist_launcher.py
ADDED
@@ -0,0 +1,66 @@
from .tmux_launcher import Options, TmuxLauncher


class Launcher(TmuxLauncher):
    def common_options(self):
        return [
            # Command 0
            Options(
                dataroot="../data/BCI_dataset",
                name="mist_her2_lambda_linear",
                checkpoints_dir='../checkpoints',
                model='cpt',
                CUT_mode="FastCUT",

                n_epochs=30,  # number of epochs with the initial learning rate
                n_epochs_decay=10,  # number of epochs to linearly decay learning rate to zero

                netD='n_layers',  # ['basic' | 'n_layers' | 'pixel' | 'patch'], specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator
                ndf=32,
                netG='resnet_6blocks',  # ['resnet_9blocks' | 'resnet_6blocks' | 'unet_256' | 'unet_128' | 'stylegan2' | 'smallstylegan2' | 'resnet_cat'], specify generator architecture
                n_layers_D=5,  # only used if netD==n_layers
                normG='instance',  # ['instance' | 'batch' | 'none'], instance normalization or batch normalization for G
                normD='instance',  # ['instance' | 'batch' | 'none'], instance normalization or batch normalization for D
                weight_norm='spectral',

                lambda_GAN=1.0,  # weight for GAN loss: GAN(G(X))
                lambda_NCE=10.0,  # weight for NCE loss: NCE(G(X), X)
                nce_layers='0,4,8,12,16',
                nce_T=0.07,
                num_patches=256,

                # FDL:
                lambda_gp=10.0,
                gp_weights='[0.015625,0.03125,0.0625,0.125,0.25,1.0]',
                lambda_asp=10.0,  # weight for NCE loss: NCE(G(X), X)
                asp_loss_mode='lambda_linear',

                dataset_mode='aligned',  # chooses how datasets are loaded. [unaligned | aligned | single | colorization]
                direction='AtoB',
                # serial_batches='',  # if true, takes images in order to make batches, otherwise takes them randomly
                num_threads=15,  # number of threads for loading data
                batch_size=1,  # input batch size
                load_size=1024,  # scale images to this size
                crop_size=512,  # then crop to this size
                preprocess='crop',  # scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]
                # no_flip='',
                flip_equivariance=False,
                display_winsize=512,  # display window size for both visdom and HTML
                # display_id=0,
                update_html_freq=100,
                save_epoch_freq=5,
                # print_freq=10,
            ),
        ]

    def commands(self):
        return ["python train.py " + str(opt) for opt in self.common_options()]

    def test_commands(self):
        opts = self.common_options()
        phase = 'val'
        for opt in opts:
            opt.set(crop_size=1024, num_test=1000)
            opt.remove('n_epochs', 'n_epochs_decay', 'update_html_freq',
                       'save_epoch_freq', 'continue_train', 'epoch_count')
        return ["python test.py " + str(opt.set(phase=phase)) for opt in opts]
asp/experiments/pretrained_launcher.py
ADDED
@@ -0,0 +1,61 @@
from .tmux_launcher import Options, TmuxLauncher


class Launcher(TmuxLauncher):
    def common_options(self):
        return [
            # Command 0
            Options(
                # NOTE: download the resized (and compressed) val set from
                # http://efrosgans.eecs.berkeley.edu/CUT/datasets/cityscapes_val_for_CUT.tar
                dataroot="datasets/cityscapes/cityscapes_val/",
                direction="BtoA",
                phase="val",
                name="cityscapes_cut_pretrained",
                CUT_mode="CUT",
            ),

            # Command 1
            Options(
                dataroot="./datasets/cityscapes_unaligned/cityscapes/",
                direction="BtoA",
                name="cityscapes_fastcut_pretrained",
                CUT_mode="FastCUT",
            ),

            # Command 2
            Options(
                dataroot="./datasets/horse2zebra/",
                name="horse2zebra_cut_pretrained",
                CUT_mode="CUT"
            ),

            # Command 3
            Options(
                dataroot="./datasets/horse2zebra/",
                name="horse2zebra_fastcut_pretrained",
                CUT_mode="FastCUT",
            ),

            # Command 4
            Options(
                dataroot="/mnt/cloudNAS3/fangda/CycleGANData/dog2cat",
                name="cat2dog_cut_pretrained",
                CUT_mode="CUT"
            ),

            # Command 5
            Options(
                dataroot="./datasets/afhq/cat2dog/",
                name="cat2dog_fastcut_pretrained",
                CUT_mode="FastCUT",
            ),
        ]

    def commands(self):
        return ["python train.py " + str(opt) for opt in self.common_options()]

    def test_commands(self):
        return ["python test.py " + str(opt.set(num_test=500)) for opt in self.common_options()]
asp/experiments/tmux_launcher.py
ADDED
@@ -0,0 +1,215 @@
"""
experiment launcher using tmux panes
"""
import os
import math
import GPUtil
import re

available_gpu_devices = None


class Options():
    def __init__(self, *args, **kwargs):
        self.args = []
        self.kvs = {"gpu_ids": "0"}
        self.set(*args, **kwargs)

    def set(self, *args, **kwargs):
        for a in args:
            self.args.append(a)
        for k, v in kwargs.items():
            self.kvs[k] = v

        return self

    def remove(self, *args):
        for a in args:
            if a in self.args:
                self.args.remove(a)
            if a in self.kvs:
                del self.kvs[a]

        return self

    def update(self, opt):
        self.args += opt.args
        self.kvs.update(opt.kvs)
        return self

    def __str__(self):
        final = " ".join(self.args)
        for k, v in self.kvs.items():
            final += " --{} {}".format(k, v)

        return final

    def clone(self):
        opt = Options()
        opt.args = self.args.copy()
        opt.kvs = self.kvs.copy()
        return opt


def grab_pattern(pattern, text):
    found = re.search(pattern, text)
    if found is not None:
        return found[1]
    else:
        None


# http://code.activestate.com/recipes/252177-find-the-common-beginning-in-a-list-of-strings/
def findcommonstart(strlist):
    prefix_len = ([min([x[0] == elem for elem in x])
                   for x in zip(*strlist)] + [0]).index(0)
    prefix_len = max(1, prefix_len - 4)
    return strlist[0][:prefix_len]


class TmuxLauncher():
    def __init__(self):
        super().__init__()
        self.tmux_prepared = False

    def prepare_tmux_panes(self, num_experiments, dry=False):
        self.pane_per_window = 1
        self.n_windows = int(math.ceil(num_experiments / self.pane_per_window))
        print('preparing {} tmux panes'.format(num_experiments))
        for w in range(self.n_windows):
            if dry:
                continue
            window_name = "experiments_{}".format(w)
            os.system("tmux new-window -n {}".format(window_name))
        self.tmux_prepared = True

    def refine_command(self, command, which_epoch, continue_train, gpu_id=None):
        command = str(command)
        if "--gpu_ids" in command:
            gpu_ids = re.search(r'--gpu_ids ([\d,?]+)', command)[1]
        else:
            gpu_ids = "0"

        gpu_ids = gpu_ids.split(",")
        num_gpus = len(gpu_ids)
        global available_gpu_devices
        if available_gpu_devices is None and gpu_id is None:
            available_gpu_devices = [str(g) for g in GPUtil.getAvailable(limit=8, maxMemory=0.5)]
        if gpu_id is not None:
            available_gpu_devices = [i for i in str(gpu_id)]
        if len(available_gpu_devices) < num_gpus:
            raise ValueError("{} GPU(s) required for the command {} is not available".format(num_gpus, command))
        active_devices = ",".join(available_gpu_devices[:num_gpus])
        if which_epoch is not None:
            which_epoch = " --epoch %s " % which_epoch
        else:
            which_epoch = ""
        command = "CUDA_VISIBLE_DEVICES={} {} {}".format(active_devices, command, which_epoch)
        if continue_train:
            command += " --continue_train "

        # available_gpu_devices = [str(g) for g in GPUtil.getAvailable(limit=8, maxMemory=0.8)]
        available_gpu_devices = available_gpu_devices[num_gpus:]

        return command

    def send_command(self, exp_id, command, dry=False, continue_train=False):
        command = self.refine_command(command, None, continue_train=continue_train)
        pane_name = "experiments_{windowid}.{paneid}".format(windowid=exp_id // self.pane_per_window,
                                                             paneid=exp_id % self.pane_per_window)
        if dry is False:
            os.system("tmux send-keys -t {} \"{}\" Enter".format(pane_name, command))

        print("{}: {}".format(pane_name, command))
        return pane_name

    def run_command(self, command, ids, which_epoch=None, continue_train=False, gpu_id=None):
        if type(command) is not list:
            command = [command]
        if ids is None:
            ids = list(range(len(command)))
        if type(ids) is not list:
            ids = [ids]

        for id in ids:
            this_command = command[id]
            refined_command = self.refine_command(this_command, which_epoch, continue_train=continue_train, gpu_id=gpu_id)
            print(refined_command)
            os.system(refined_command)

    def commands(self):
        return []

    def launch(self, ids, test=False, dry=False, continue_train=False):
        commands = self.test_commands() if test else self.commands()
        if type(ids) is list:
            commands = [commands[i] for i in ids]
        if not self.tmux_prepared:
            self.prepare_tmux_panes(len(commands), dry)
            assert self.tmux_prepared

        for i, command in enumerate(commands):
            self.send_command(i, command, dry, continue_train=continue_train)

    def dry(self):
        self.launch(dry=True)

    def stop(self):
        num_experiments = len(self.commands())
        self.pane_per_window = 4
        self.n_windows = int(math.ceil(num_experiments / self.pane_per_window))
        for w in range(self.n_windows):
            window_name = "experiments_{}".format(w)
            for i in range(self.pane_per_window):
                os.system("tmux send-keys -t {window}.{pane} C-c".format(window=window_name, pane=i))

    def close(self):
        num_experiments = len(self.commands())
        self.pane_per_window = 1
        self.n_windows = int(math.ceil(num_experiments / self.pane_per_window))
        for w in range(self.n_windows):
            window_name = "experiments_{}".format(w)
            os.system("tmux kill-window -t {}".format(window_name))

    def print_names(self, ids, test=False):
        if test:
            cmds = self.test_commands()
        else:
            cmds = self.commands()
        if type(ids) is list:
            cmds = [cmds[i] for i in ids]

        for cmdid, cmd in enumerate(cmds):
            name = grab_pattern(r'--name ([^ ]+)', cmd)
            print(name)

    def create_comparison_html(self, expr_name, ids, subdir, title, phase):
        cmds = self.test_commands()
        if type(ids) is list:
            cmds = [cmds[i] for i in ids]

        no_easy_label = True
        dirs = []
        labels = []
        for cmdid, cmd in enumerate(cmds):
            name = grab_pattern(r'--name ([^ ]+)', cmd)
            which_epoch = grab_pattern(r'--epoch ([^ ]+)', cmd)
            if which_epoch is None:
                which_epoch = "latest"
            label = grab_pattern(r'--easy_label "([^"]+)"', cmd)
            if label is None:
                label = name
            else:
                no_easy_label = False
            labels.append(label)
            dir = "results/%s/%s_%s/%s/" % (name, phase, which_epoch, subdir)
            dirs.append(dir)

        commonprefix = findcommonstart(labels) if no_easy_label else ""
        labels = ['"' + label[len(commonprefix):] + '"' for label in labels]
        dirstr = ' '.join(dirs)
        labelstr = ' '.join(labels)

        command = "python ~/tools/html.py --web_dir_prefix results/comparison_ --name %s --dirs %s --labels %s --image_width 256" % (expr_name + '_' + title, dirstr, labelstr)
        print(command)
        os.system(command)
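Not part of the commit: the launchers above build their training commands by concatenating "python train.py " with str(opt), where Options.__str__ flattens the stored key/value pairs into --key value flags (gpu_ids defaults to "0"). A small sketch of that round trip, using a subset of Command 0 from mist_launcher.py:

# Illustration only: how an Options instance becomes a CLI string.
from asp.experiments.tmux_launcher import Options

opt = Options(dataroot="../data/BCI_dataset", name="mist_her2_lambda_linear",
              model='cpt', CUT_mode="FastCUT")
print("python train.py" + str(opt))
# -> python train.py --gpu_ids 0 --dataroot ../data/BCI_dataset --name mist_her2_lambda_linear --model cpt --CUT_mode FastCUT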
asp/models/__init__.py
ADDED
@@ -0,0 +1,67 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>: unpack data from dataset and apply preprocessing.
    -- <forward>: produce intermediate results.
    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list): specify the training losses that you want to plot and save.
    -- self.model_names (str list): define networks used in our training.
    -- self.visual_names (str list): specify the images that you want to display and save.
    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""

import importlib
from asp.models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
asp/models/asp_loss.py
ADDED
@@ -0,0 +1,97 @@
import time
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F


class AdaptiveSupervisedPatchNCELoss(nn.Module):

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self.mask_dtype = torch.bool
        self.total_epochs = opt.n_epochs + opt.n_epochs_decay

    def forward(self, feat_q, feat_k, current_epoch=-1):
        num_patches = feat_q.shape[0]
        dim = feat_q.shape[1]
        feat_k = feat_k.detach()

        # pos logit
        l_pos = torch.bmm(
            feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
        l_pos = l_pos.view(num_patches, 1)

        # neg logit

        # Should the negatives from the other samples of a minibatch be utilized?
        # In CUT and FastCUT, we found that it's best to only include negatives
        # from the same image. Therefore, we set
        # --nce_includes_all_negatives_from_minibatch as False
        # However, for single-image translation, the minibatch consists of
        # crops from the "same" high-resolution image.
        # Therefore, we will include the negatives from the entire minibatch.
        if self.opt.nce_includes_all_negatives_from_minibatch:
            # reshape features as if they are all negatives of minibatch of size 1.
            batch_dim_for_bmm = 1
        else:
            batch_dim_for_bmm = self.opt.batch_size

        # reshape features to batch size
        feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
        feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))

        # diagonal entries are similarity between same features, and hence meaningless.
        # just fill the diagonal with very small number, which is exp(-10) and almost zero
        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
        l_neg_curbatch.masked_fill_(diagonal, -10.0)
        l_neg = l_neg_curbatch.view(-1, npatches)

        out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T

        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
                                                        device=feat_q.device))

        if self.opt.asp_loss_mode == 'none':
            return loss

        scheduler, lookup = self.opt.asp_loss_mode.split('_')[:2]
        # Compute scheduling
        t = (current_epoch - 1) / self.total_epochs
        if scheduler == 'sigmoid':
            p = 1 / (1 + np.exp((t - 0.5) * 10))
        elif scheduler == 'linear':
            p = 1 - t
        elif scheduler == 'lambda':
            k = 1 - self.opt.n_epochs_decay / self.total_epochs
            m = 1 / (1 - k)
            p = m - m * t if t >= k else 1.0
        elif scheduler == 'zero':
            p = 1.0
        else:
            raise ValueError(f"Unrecognized scheduler: {scheduler}")
        # Weight lookups
        w0 = 1.0
        x = l_pos.squeeze().detach()
        if lookup == 'top':
            x = torch.where(x > 0.0, x, torch.zeros_like(x))
            w1 = torch.sqrt(1 - (x - 1) ** 2)
        elif lookup == 'linear':
            w1 = torch.relu(x)
        elif lookup == 'bell':
            sigma, mu, sc = 1, 0, 4
            w1 = 1 / (sigma * np.sqrt(2 * torch.pi)) * torch.exp(-((x - 0.5) * sc - mu) ** 2 / (2 * sigma ** 2))
        elif lookup == 'uniform':
            w1 = torch.ones_like(x)
        else:
            raise ValueError(f"Unrecognized lookup: {lookup}")
        # Apply weights with schedule
        w = p * w0 + (1 - p) * w1
        # Normalize
        w = w / w.sum() * len(w)
        loss = loss * w
        return loss
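Not part of the commit: for the 'lambda_linear' asp_loss_mode configured in mist_launcher.py, the 'lambda' scheduler branch above keeps the mixing coefficient p at 1.0 (uniform patch weighting) through the fixed-learning-rate epochs and then decays it linearly, so the similarity-based weights w1 only take over late in training via w = p * w0 + (1 - p) * w1. A standalone sketch of that schedule, assuming n_epochs=30 and n_epochs_decay=10 as in the launcher:

# Illustration only: the 'lambda' schedule from AdaptiveSupervisedPatchNCELoss.forward,
# evaluated outside the module with the launcher's epoch settings.
n_epochs, n_epochs_decay = 30, 10            # values taken from mist_launcher.py
total_epochs = n_epochs + n_epochs_decay

def schedule_p(current_epoch):
    t = (current_epoch - 1) / total_epochs
    k = 1 - n_epochs_decay / total_epochs    # start of the decay phase (0.75 here)
    m = 1 / (1 - k)
    return m - m * t if t >= k else 1.0

for epoch in (1, 30, 31, 35, 40):
    print(epoch, round(schedule_p(epoch), 3))
# p stays at 1.0 until roughly epoch 30, then falls linearly toward 0 by the last epoch.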
asp/models/base_model.py
ADDED
@@ -0,0 +1,258 @@
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from . import networks


class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): specify the images that you want to display and save.
            -- self.visual_names (str list): define networks used in our training.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def dict_grad_hook_factory(add_func=lambda x: x):
        saved_dict = dict()

        def hook_gen(name):
            def grad_hook(grad):
                saved_vals = add_func(grad)
                saved_dict[name] = saved_vals
            return grad_hook
        return hook_gen, saved_dict

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = opt.epoch
            self.load_networks(load_suffix)

        self.print_networks(opt.verbose)

    def parallelize(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                setattr(self, 'net' + name, torch.nn.DataParallel(net, self.opt.gpu_ids))

    def data_dependent_initialize(self, data):
        pass

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
169 |
+
"""
|
170 |
+
for name in self.model_names:
|
171 |
+
if isinstance(name, str):
|
172 |
+
save_filename = '%s_net_%s.pth' % (epoch, name)
|
173 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
174 |
+
net = getattr(self, 'net' + name)
|
175 |
+
|
176 |
+
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
177 |
+
torch.save(net.module.cpu().state_dict(), save_path)
|
178 |
+
net.cuda(self.gpu_ids[0])
|
179 |
+
else:
|
180 |
+
torch.save(net.cpu().state_dict(), save_path)
|
181 |
+
|
182 |
+
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
|
183 |
+
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
|
184 |
+
key = keys[i]
|
185 |
+
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
|
186 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
187 |
+
(key == 'running_mean' or key == 'running_var'):
|
188 |
+
if getattr(module, key) is None:
|
189 |
+
state_dict.pop('.'.join(keys))
|
190 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
191 |
+
(key == 'num_batches_tracked'):
|
192 |
+
state_dict.pop('.'.join(keys))
|
193 |
+
else:
|
194 |
+
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
|
195 |
+
|
196 |
+
def load_networks(self, epoch):
|
197 |
+
"""Load all the networks from the disk.
|
198 |
+
|
199 |
+
Parameters:
|
200 |
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
201 |
+
"""
|
202 |
+
for name in self.model_names:
|
203 |
+
if isinstance(name, str):
|
204 |
+
load_filename = '%s_net_%s.pth' % (epoch, name)
|
205 |
+
if self.opt.isTrain and self.opt.pretrained_name is not None:
|
206 |
+
load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
|
207 |
+
else:
|
208 |
+
load_dir = self.save_dir
|
209 |
+
|
210 |
+
load_path = os.path.join(load_dir, load_filename)
|
211 |
+
net = getattr(self, 'net' + name)
|
212 |
+
if isinstance(net, torch.nn.DataParallel):
|
213 |
+
net = net.module
|
214 |
+
print('loading the model from %s' % load_path)
|
215 |
+
# if you are using PyTorch newer than 0.4 (e.g., built from
|
216 |
+
# GitHub source), you can remove str() on self.device
|
217 |
+
state_dict = torch.load(load_path, map_location=str(self.device))
|
218 |
+
if hasattr(state_dict, '_metadata'):
|
219 |
+
del state_dict._metadata
|
220 |
+
|
221 |
+
# patch InstanceNorm checkpoints prior to 0.4
|
222 |
+
# for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
|
223 |
+
# self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
|
224 |
+
net.load_state_dict(state_dict)
|
225 |
+
|
226 |
+
def print_networks(self, verbose):
|
227 |
+
"""Print the total number of parameters in the network and (if verbose) network architecture
|
228 |
+
|
229 |
+
Parameters:
|
230 |
+
verbose (bool) -- if verbose: print the network architecture
|
231 |
+
"""
|
232 |
+
print('---------- Networks initialized -------------')
|
233 |
+
for name in self.model_names:
|
234 |
+
if isinstance(name, str):
|
235 |
+
net = getattr(self, 'net' + name)
|
236 |
+
num_params = 0
|
237 |
+
for param in net.parameters():
|
238 |
+
num_params += param.numel()
|
239 |
+
if verbose:
|
240 |
+
print(net)
|
241 |
+
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
|
242 |
+
print('-----------------------------------------------')
|
243 |
+
|
244 |
+
def set_requires_grad(self, nets, requires_grad=False):
|
245 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
246 |
+
Parameters:
|
247 |
+
nets (network list) -- a list of networks
|
248 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
249 |
+
"""
|
250 |
+
if not isinstance(nets, list):
|
251 |
+
nets = [nets]
|
252 |
+
for net in nets:
|
253 |
+
if net is not None:
|
254 |
+
for param in net.parameters():
|
255 |
+
param.requires_grad = requires_grad
|
256 |
+
|
257 |
+
def generate_visuals_for_evaluation(self, data, mode):
|
258 |
+
return {}
|
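For orientation, here is a minimal sketch of how a concrete model plugs into the BaseModel lifecycle above. The class below is hypothetical (it is not part of this commit); it only exercises the documented hooks, and the define_G call mirrors the one used by CPTModel further down.

import torch
from asp.models.base_model import BaseModel
from asp.models import networks


class IdentityModel(BaseModel):
    """Hypothetical subclass: maps A to B with a single generator, exercising every hook."""

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['G']                     # read back by get_current_losses() as self.loss_G
        self.visual_names = ['real_A', 'fake_B']    # read back by get_current_visuals()
        self.model_names = ['G']                    # saved/loaded as '<epoch>_net_G.pth'
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.normG, not opt.no_dropout, opt.init_type,
                                      opt.init_gain, opt.no_antialias, opt.no_antialias_up,
                                      self.gpu_ids, opt)
        if self.isTrain:
            self.criterion = torch.nn.L1Loss()
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr)
            self.optimizers = [self.optimizer_G]

    def set_input(self, input):
        self.real_A = input['A'].to(self.device)
        self.real_B = input['B'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def optimize_parameters(self):
        self.forward()
        self.optimizer_G.zero_grad()
        self.loss_G = self.criterion(self.fake_B, self.real_B)
        self.loss_G.backward()
        self.optimizer_G.step()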
asp/models/cpt_model.py
ADDED
@@ -0,0 +1,261 @@
import numpy as np
import torch

from asp.models.asp_loss import AdaptiveSupervisedPatchNCELoss
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
from .gauss_pyramid import Gauss_Pyramid_Conv
import asp.util.util as util


class CPTModel(BaseModel):
    """ Contrastive Paired Translation (CPT).
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """ Configures options specific for CPT model
        """
        parser.add_argument('--CUT_mode', type=str, default="CUT", choices='(CUT, cut, FastCUT, fastcut)')

        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
        parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
        parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
        parser.add_argument('--nce_includes_all_negatives_from_minibatch',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
        parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
        parser.add_argument('--netF_nc', type=int, default=256)
        parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
        parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
        parser.add_argument('--flip_equivariance',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
        parser.set_defaults(pool_size=0)  # no image pooling

        # FDL:
        parser.add_argument('--lambda_gp', type=float, default=1.0, help='weight for Gaussian Pyramid reconstruction loss')
        parser.add_argument('--gp_weights', type=str, default='uniform', help='weights for reconstruction pyramids.')
        parser.add_argument('--lambda_asp', type=float, default=0.0, help='weight for ASP loss')
        parser.add_argument('--asp_loss_mode', type=str, default='none', help='"scheduler_lookup" options for the ASP loss. Options for both are listed in Fig. 3 of the paper.')
        parser.add_argument('--n_downsampling', type=int, default=2, help='# of downsample in G')

        opt, _ = parser.parse_known_args()

        # Set default parameters for CUT and FastCUT
        if opt.CUT_mode.lower() == "cut":
            parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
        elif opt.CUT_mode.lower() == "fastcut":
            parser.set_defaults(
                nce_idt=False, lambda_NCE=10.0, flip_equivariance=False,
                n_epochs=20, n_epochs_decay=10
            )
        else:
            raise ValueError(opt.CUT_mode)

        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)

        # specify the training losses you want to print out.
        # The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE']
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]

        if opt.nce_idt and self.isTrain:
            self.loss_names += ['NCE_Y']
            self.visual_names += ['idt_B']

        if self.isTrain:
            self.model_names = ['G', 'F', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']

        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
        self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

        if self.isTrain:
            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionNCE = PatchNCELoss(opt).to(self.device)
            self.criterionIdt = torch.nn.L1Loss().to(self.device)

            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            if self.opt.lambda_gp > 0:
                self.P = Gauss_Pyramid_Conv(num_high=5)
                self.criterionGP = torch.nn.L1Loss().to(self.device)
                if self.opt.gp_weights == 'uniform':
                    self.gp_weights = [1.0] * 6
                else:
                    self.gp_weights = eval(self.opt.gp_weights)
                self.loss_names += ['GP']

            if self.opt.lambda_asp > 0:
                self.criterionASP = AdaptiveSupervisedPatchNCELoss(self.opt).to(self.device)
                self.loss_names += ['ASP']

    def data_dependent_initialize(self, data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
        """
        bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
        self.set_input(data)
        self.real_A = self.real_A[:bs_per_gpu]
        self.real_B = self.real_B[:bs_per_gpu]
        self.forward()                                  # compute fake images: G(A)
        if self.opt.isTrain:
            self.compute_D_loss().backward()            # calculate gradients for D
            self.compute_G_loss().backward()            # calculate gradients for G
            if self.opt.lambda_NCE > 0.0 or self.opt.lambda_asp > 0.0:
                self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
                self.optimizers.append(self.optimizer_F)

    def optimize_parameters(self):
        # forward
        self.forward()

        # update D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D.step()
        # update G
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        if self.opt.netF == 'mlp_sample':
            self.optimizer_F.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()
        if self.opt.netF == 'mlp_sample':
            self.optimizer_F.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.
        Parameters:
            input (dict): include the data itself and its metadata information.
        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

        if 'current_epoch' in input:
            self.current_epoch = input['current_epoch']
        if 'current_iter' in input:
            self.current_iter = input['current_iter']

    def forward(self):
        # self.netG.print()
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
        if self.opt.flip_equivariance:
            self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
            if self.flipped_for_equivariance:
                self.real = torch.flip(self.real, [3])

        self.fake = self.netG(self.real, layers=[])
        self.fake_B = self.fake[:self.real_A.size(0)]
        if self.opt.nce_idt:
            self.idt_B = self.fake[self.real_A.size(0):]

    def compute_D_loss(self):
        """Calculate GAN loss for the discriminator"""
        fake = self.fake_B.detach()
        # Fake; stop backprop to the generator by detaching fake_B
        pred_fake = self.netD(fake)
        self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
        # Real
        self.pred_real = self.netD(self.real_B)
        loss_D_real = self.criterionGAN(self.pred_real, True)
        self.loss_D_real = loss_D_real.mean()

        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        return self.loss_D

    def compute_G_loss(self):
        """Calculate GAN and NCE loss for the generator"""
        fake = self.fake_B

        feat_real_A = self.netG(self.real_A, self.nce_layers, encode_only=True)
        feat_fake_B = self.netG(self.fake_B, self.nce_layers, encode_only=True)
        feat_real_B = self.netG(self.real_B, self.nce_layers, encode_only=True)
        if self.opt.nce_idt:
            feat_idt_B = self.netG(self.idt_B, self.nce_layers, encode_only=True)

        # First, G(A) should fake the discriminator
        if self.opt.lambda_GAN > 0.0:
            pred_fake = self.netD(fake)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
        else:
            self.loss_G_GAN = 0.0

        if self.opt.lambda_NCE > 0.0:
            self.loss_NCE = self.calculate_NCE_loss(feat_real_A, feat_fake_B, self.netF, self.nce_layers)
        else:
            self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0
        loss_NCE_all = self.loss_NCE

        if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
            self.loss_NCE_Y = self.calculate_NCE_loss(feat_real_B, feat_idt_B, self.netF, self.nce_layers)
        else:
            self.loss_NCE_Y = 0.0
        loss_NCE_all += self.loss_NCE_Y

        # FDL: NCE between the noisy pairs (fake_B and real_B)
        if self.opt.lambda_asp > 0:
            self.loss_ASP = self.calculate_NCE_loss(feat_real_B, feat_fake_B, self.netF, self.nce_layers, paired=True)
        else:
            self.loss_ASP = 0.0
        loss_NCE_all += self.loss_ASP

        # FDL: compute loss on Gaussian pyramids
        if self.opt.lambda_gp > 0:
            p_fake_B = self.P(self.fake_B)
            p_real_B = self.P(self.real_B)
            loss_pyramid = [self.criterionGP(pf, pr) for pf, pr in zip(p_fake_B, p_real_B)]
            weights = self.gp_weights
            loss_pyramid = [l * w for l, w in zip(loss_pyramid, weights)]
            self.loss_GP = torch.mean(torch.stack(loss_pyramid)) * self.opt.lambda_gp
        else:
            self.loss_GP = 0

        self.loss_G = self.loss_G_GAN + loss_NCE_all + self.loss_GP
        return self.loss_G

    def calculate_NCE_loss(self, feat_src, feat_tgt, netF, nce_layers, paired=False):
        n_layers = len(feat_src)
        feat_q = feat_tgt

        if self.opt.flip_equivariance and self.flipped_for_equivariance:
            feat_q = [torch.flip(fq, [3]) for fq in feat_q]
        feat_k = feat_src
        feat_k_pool, sample_ids = netF(feat_k, self.opt.num_patches, None)
        feat_q_pool, _ = netF(feat_q, self.opt.num_patches, sample_ids)

        total_nce_loss = 0.0
        for f_q, f_k in zip(feat_q_pool, feat_k_pool):
            if paired:
                loss = self.criterionASP(f_q, f_k, self.current_epoch) * self.opt.lambda_asp
            else:
                loss = self.criterionNCE(f_q, f_k) * self.opt.lambda_NCE
            total_nce_loss += loss.mean()

        return total_nce_loss / n_layers
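To make the Gaussian-pyramid term in compute_G_loss above concrete, here is a small standalone sketch of the same weighting. Tensor sizes are arbitrary, and it assumes a CUDA device, since Gauss_Pyramid_Conv builds its kernel on 'cuda' by default.

import torch
from asp.models.gauss_pyramid import Gauss_Pyramid_Conv

P = Gauss_Pyramid_Conv(num_high=5)                    # 5 blurred levels + 1 low-resolution tail = 6 outputs
criterionGP = torch.nn.L1Loss()

fake_B = torch.rand(1, 3, 256, 256, device='cuda')
real_B = torch.rand(1, 3, 256, 256, device='cuda')

gp_weights = [1.0] * 6                                # what '--gp_weights uniform' expands to; a Python
                                                      # list literal passed on the command line is eval()'d instead
per_level = [criterionGP(pf, pr) * w for pf, pr, w in zip(P(fake_B), P(real_B), gp_weights)]
loss_GP = torch.mean(torch.stack(per_level))          # scaled by --lambda_gp in compute_G_loss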
asp/models/cut_model.py
ADDED
@@ -0,0 +1,214 @@
import numpy as np
import torch
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import asp.util.util as util


class CUTModel(BaseModel):
    """ This class implements CUT and FastCUT model, described in the paper
    Contrastive Learning for Unpaired Image-to-Image Translation
    Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
    ECCV, 2020

    The code borrows heavily from the PyTorch implementation of CycleGAN
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """ Configures options specific for CUT model
        """
        parser.add_argument('--CUT_mode', type=str, default="CUT", choices='(CUT, cut, FastCUT, fastcut)')

        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
        parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
        parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
        parser.add_argument('--nce_includes_all_negatives_from_minibatch',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
        parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
        parser.add_argument('--netF_nc', type=int, default=256)
        parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
        parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
        parser.add_argument('--flip_equivariance',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")

        parser.set_defaults(pool_size=0)  # no image pooling

        opt, _ = parser.parse_known_args()

        # Set default parameters for CUT and FastCUT
        if opt.CUT_mode.lower() == "cut":
            parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
        elif opt.CUT_mode.lower() == "fastcut":
            parser.set_defaults(
                nce_idt=False, lambda_NCE=10.0, flip_equivariance=True,
                n_epochs=150, n_epochs_decay=50
            )
        else:
            raise ValueError(opt.CUT_mode)

        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)

        # specify the training losses you want to print out.
        # The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE']
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]

        if opt.nce_idt and self.isTrain:
            self.loss_names += ['NCE_Y']
            self.visual_names += ['idt_B']

        if self.isTrain:
            self.model_names = ['G', 'F', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']

        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
        self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

        if self.isTrain:
            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionNCE = []

            for nce_layer in self.nce_layers:
                self.criterionNCE.append(PatchNCELoss(opt).to(self.device))

            self.criterionIdt = torch.nn.L1Loss().to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def data_dependent_initialize(self, data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
        """
        bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
        self.set_input(data)
        self.real_A = self.real_A[:bs_per_gpu]
        self.real_B = self.real_B[:bs_per_gpu]
        self.forward()                                  # compute fake images: G(A)
        if self.opt.isTrain:
            self.compute_D_loss().backward()            # calculate gradients for D
            self.compute_G_loss().backward()            # calculate gradients for G
            if self.opt.lambda_NCE > 0.0:
                self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
                self.optimizers.append(self.optimizer_F)

    def optimize_parameters(self):
        # forward
        self.forward()

        # update D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D.step()

        # update G
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        if self.opt.netF == 'mlp_sample':
            self.optimizer_F.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()
        if self.opt.netF == 'mlp_sample':
            self.optimizer_F.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.
        Parameters:
            input (dict): include the data itself and its metadata information.
        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
        if self.opt.flip_equivariance:
            self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
            if self.flipped_for_equivariance:
                self.real = torch.flip(self.real, [3])

        self.fake = self.netG(self.real)
        self.fake_B = self.fake[:self.real_A.size(0)]
        if self.opt.nce_idt:
            self.idt_B = self.fake[self.real_A.size(0):]

    def compute_D_loss(self):
        """Calculate GAN loss for the discriminator"""
        fake = self.fake_B.detach()
        # Fake; stop backprop to the generator by detaching fake_B
        pred_fake = self.netD(fake)
        self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
        # Real
        self.pred_real = self.netD(self.real_B)
        loss_D_real = self.criterionGAN(self.pred_real, True)
        self.loss_D_real = loss_D_real.mean()

        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        return self.loss_D

    def compute_G_loss(self):
        """Calculate GAN and NCE loss for the generator"""
        fake = self.fake_B
        # First, G(A) should fake the discriminator
        if self.opt.lambda_GAN > 0.0:
            pred_fake = self.netD(fake)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
        else:
            self.loss_G_GAN = 0.0

        if self.opt.lambda_NCE > 0.0:
            self.loss_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B)
        else:
            self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0

        if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
            self.loss_NCE_Y = self.calculate_NCE_loss(self.real_B, self.idt_B)
            loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y) * 0.5
        else:
            loss_NCE_both = self.loss_NCE

        self.loss_G = self.loss_G_GAN + loss_NCE_both
        return self.loss_G

    def calculate_NCE_loss(self, src, tgt):
        n_layers = len(self.nce_layers)
        feat_q = self.netG(tgt, self.nce_layers, encode_only=True)

        if self.opt.flip_equivariance and self.flipped_for_equivariance:
            feat_q = [torch.flip(fq, [3]) for fq in feat_q]

        feat_k = self.netG(src, self.nce_layers, encode_only=True)
        feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
        feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)

        total_nce_loss = 0.0
        for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
            loss = crit(f_q, f_k) * self.opt.lambda_NCE
            total_nce_loss += loss.mean()

        return total_nce_loss / n_layers
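The per-layer bookkeeping in calculate_NCE_loss reduces to averaging one PatchNCE term per entry of --nce_layers. A reduced sketch with stand-in tensors and a toy criterion (the real features come from netG(..., encode_only=True) followed by PatchSampleF, and the real criterion is PatchNCELoss):

import torch

nce_layers = [0, 4, 8, 12, 16]                       # default --nce_layers
num_patches, feat_dim = 256, 256
feat_q_pool = [torch.randn(num_patches, feat_dim) for _ in nce_layers]   # stand-ins for netF outputs
feat_k_pool = [torch.randn(num_patches, feat_dim) for _ in nce_layers]

def toy_patch_loss(f_q, f_k):
    # placeholder criterion: one value per sampled patch, like PatchNCELoss
    return (f_q - f_k).pow(2).mean(dim=1)

total_nce_loss = 0.0
for f_q, f_k in zip(feat_q_pool, feat_k_pool):
    total_nce_loss += toy_patch_loss(f_q, f_k).mean()
loss_NCE = total_nce_loss / len(nce_layers)          # same normalization as calculate_NCE_loss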
asp/models/gauss_pyramid.py
ADDED
@@ -0,0 +1,42 @@
import torch
from torch import nn

class Gauss_Pyramid_Conv(nn.Module):
    """
    Code borrowed from: https://github.com/csjliang/LPTN
    """
    def __init__(self, num_high=3):
        super(Gauss_Pyramid_Conv, self).__init__()

        self.num_high = num_high
        self.kernel = self.gauss_kernel()

    def gauss_kernel(self, device=torch.device('cuda'), channels=3):
        kernel = torch.tensor([[1., 4., 6., 4., 1.],
                               [4., 16., 24., 16., 4.],
                               [6., 24., 36., 24., 6.],
                               [4., 16., 24., 16., 4.],
                               [1., 4., 6., 4., 1.]])
        kernel /= 256.
        kernel = kernel.repeat(channels, 1, 1, 1)
        kernel = kernel.to(device)
        return kernel

    def downsample(self, x):
        return x[:, :, ::2, ::2]

    def conv_gauss(self, img, kernel):
        img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
        out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
        return out

    def forward(self, img):
        current = img
        pyr = []
        for _ in range(self.num_high):
            filtered = self.conv_gauss(current, self.kernel)
            pyr.append(filtered)
            down = self.downsample(filtered)
            current = down
        pyr.append(current)
        return pyr
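A quick shape check for the module above (again assuming a CUDA device, since gauss_kernel defaults to one): each level keeps the blurred map before downsampling, and the final entry is the coarsest image.

import torch
from asp.models.gauss_pyramid import Gauss_Pyramid_Conv

pyr = Gauss_Pyramid_Conv(num_high=3)(torch.rand(1, 3, 256, 256, device='cuda'))
print([tuple(t.shape[-2:]) for t in pyr])
# expected: [(256, 256), (128, 128), (64, 64), (32, 32)]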
asp/models/networks.py
ADDED
@@ -0,0 +1,1422 @@
1 |
+
from copy import copy
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.nn import init
|
6 |
+
import functools
|
7 |
+
from torch.optim import lr_scheduler
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
###############################################################################
|
11 |
+
# Helper Functions
|
12 |
+
###############################################################################
|
13 |
+
|
14 |
+
|
15 |
+
def get_filter(filt_size=3):
|
16 |
+
if(filt_size == 1):
|
17 |
+
a = np.array([1., ])
|
18 |
+
elif(filt_size == 2):
|
19 |
+
a = np.array([1., 1.])
|
20 |
+
elif(filt_size == 3):
|
21 |
+
a = np.array([1., 2., 1.])
|
22 |
+
elif(filt_size == 4):
|
23 |
+
a = np.array([1., 3., 3., 1.])
|
24 |
+
elif(filt_size == 5):
|
25 |
+
a = np.array([1., 4., 6., 4., 1.])
|
26 |
+
elif(filt_size == 6):
|
27 |
+
a = np.array([1., 5., 10., 10., 5., 1.])
|
28 |
+
elif(filt_size == 7):
|
29 |
+
a = np.array([1., 6., 15., 20., 15., 6., 1.])
|
30 |
+
|
31 |
+
filt = torch.Tensor(a[:, None] * a[None, :])
|
32 |
+
filt = filt / torch.sum(filt)
|
33 |
+
|
34 |
+
return filt
|
35 |
+
|
36 |
+
|
37 |
+
class Downsample(nn.Module):
|
38 |
+
def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
|
39 |
+
super(Downsample, self).__init__()
|
40 |
+
self.filt_size = filt_size
|
41 |
+
self.pad_off = pad_off
|
42 |
+
self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)), int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
|
43 |
+
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
|
44 |
+
self.stride = stride
|
45 |
+
self.off = int((self.stride - 1) / 2.)
|
46 |
+
self.channels = channels
|
47 |
+
|
48 |
+
filt = get_filter(filt_size=self.filt_size)
|
49 |
+
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
|
50 |
+
|
51 |
+
self.pad = get_pad_layer(pad_type)(self.pad_sizes)
|
52 |
+
|
53 |
+
def forward(self, inp):
|
54 |
+
if(self.filt_size == 1):
|
55 |
+
if(self.pad_off == 0):
|
56 |
+
return inp[:, :, ::self.stride, ::self.stride]
|
57 |
+
else:
|
58 |
+
return self.pad(inp)[:, :, ::self.stride, ::self.stride]
|
59 |
+
else:
|
60 |
+
return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
|
61 |
+
|
62 |
+
|
63 |
+
class Upsample2(nn.Module):
|
64 |
+
def __init__(self, scale_factor, mode='nearest'):
|
65 |
+
super().__init__()
|
66 |
+
self.factor = scale_factor
|
67 |
+
self.mode = mode
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)
|
71 |
+
|
72 |
+
|
73 |
+
class Upsample(nn.Module):
|
74 |
+
def __init__(self, channels, pad_type='repl', filt_size=4, stride=2):
|
75 |
+
super(Upsample, self).__init__()
|
76 |
+
self.filt_size = filt_size
|
77 |
+
self.filt_odd = np.mod(filt_size, 2) == 1
|
78 |
+
self.pad_size = int((filt_size - 1) / 2)
|
79 |
+
self.stride = stride
|
80 |
+
self.off = int((self.stride - 1) / 2.)
|
81 |
+
self.channels = channels
|
82 |
+
|
83 |
+
filt = get_filter(filt_size=self.filt_size) * (stride**2)
|
84 |
+
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
|
85 |
+
|
86 |
+
self.pad = get_pad_layer(pad_type)([1, 1, 1, 1])
|
87 |
+
|
88 |
+
def forward(self, inp):
|
89 |
+
ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride, padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:]
|
90 |
+
if(self.filt_odd):
|
91 |
+
return ret_val
|
92 |
+
else:
|
93 |
+
return ret_val[:, :, :-1, :-1]
|
94 |
+
|
95 |
+
|
96 |
+
def get_pad_layer(pad_type):
|
97 |
+
if(pad_type in ['refl', 'reflect']):
|
98 |
+
PadLayer = nn.ReflectionPad2d
|
99 |
+
elif(pad_type in ['repl', 'replicate']):
|
100 |
+
PadLayer = nn.ReplicationPad2d
|
101 |
+
elif(pad_type == 'zero'):
|
102 |
+
PadLayer = nn.ZeroPad2d
|
103 |
+
else:
|
104 |
+
print('Pad type [%s] not recognized' % pad_type)
|
105 |
+
return PadLayer
|
106 |
+
|
107 |
+
|
108 |
+
class Identity(nn.Module):
|
109 |
+
def forward(self, x):
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
def get_norm_layer(norm_type='instance'):
|
114 |
+
"""Return a normalization layer
|
115 |
+
|
116 |
+
Parameters:
|
117 |
+
norm_type (str) -- the name of the normalization layer: batch | instance | none
|
118 |
+
|
119 |
+
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
|
120 |
+
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
|
121 |
+
"""
|
122 |
+
if norm_type == 'batch':
|
123 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
124 |
+
elif norm_type == 'instance':
|
125 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
126 |
+
elif norm_type == 'none':
|
127 |
+
def norm_layer(x):
|
128 |
+
return Identity()
|
129 |
+
else:
|
130 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
131 |
+
return norm_layer
|
132 |
+
|
133 |
+
|
134 |
+
def get_scheduler(optimizer, opt):
|
135 |
+
"""Return a learning rate scheduler
|
136 |
+
|
137 |
+
Parameters:
|
138 |
+
optimizer -- the optimizer of the network
|
139 |
+
opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
|
140 |
+
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
|
141 |
+
|
142 |
+
For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
|
143 |
+
and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
|
144 |
+
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
|
145 |
+
See https://pytorch.org/docs/stable/optim.html for more details.
|
146 |
+
"""
|
147 |
+
if opt.lr_policy == 'linear':
|
148 |
+
def lambda_rule(epoch):
|
149 |
+
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
|
150 |
+
return lr_l
|
151 |
+
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
|
152 |
+
elif opt.lr_policy == 'step':
|
153 |
+
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
|
154 |
+
elif opt.lr_policy == 'plateau':
|
155 |
+
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
|
156 |
+
elif opt.lr_policy == 'cosine':
|
157 |
+
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
|
158 |
+
else:
|
159 |
+
return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
|
160 |
+
return scheduler
|
161 |
+
|
162 |
+
|
163 |
+
def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
|
164 |
+
"""Initialize network weights.
|
165 |
+
|
166 |
+
Parameters:
|
167 |
+
net (network) -- network to be initialized
|
168 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
169 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
170 |
+
|
171 |
+
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
|
172 |
+
work better for some applications. Feel free to try yourself.
|
173 |
+
"""
|
174 |
+
def init_func(m): # define the initialization function
|
175 |
+
classname = m.__class__.__name__
|
176 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
177 |
+
if debug:
|
178 |
+
print(classname)
|
179 |
+
if init_type == 'normal':
|
180 |
+
init.normal_(m.weight.data, 0.0, init_gain)
|
181 |
+
elif init_type == 'xavier':
|
182 |
+
init.xavier_normal_(m.weight.data, gain=init_gain)
|
183 |
+
elif init_type == 'kaiming':
|
184 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
185 |
+
elif init_type == 'orthogonal':
|
186 |
+
init.orthogonal_(m.weight.data, gain=init_gain)
|
187 |
+
else:
|
188 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
189 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
190 |
+
init.constant_(m.bias.data, 0.0)
|
191 |
+
elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
|
192 |
+
init.normal_(m.weight.data, 1.0, init_gain)
|
193 |
+
init.constant_(m.bias.data, 0.0)
|
194 |
+
|
195 |
+
net.apply(init_func) # apply the initialization function <init_func>
|
196 |
+
|
197 |
+
|
198 |
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
|
199 |
+
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
|
200 |
+
Parameters:
|
201 |
+
net (network) -- the network to be initialized
|
202 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
203 |
+
gain (float) -- scaling factor for normal, xavier and orthogonal.
|
204 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
205 |
+
|
206 |
+
Return an initialized network.
|
207 |
+
"""
|
208 |
+
if len(gpu_ids) > 0:
|
209 |
+
assert(torch.cuda.is_available())
|
210 |
+
net.to(gpu_ids[0])
|
211 |
+
# if not amp:
|
212 |
+
# net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training
|
213 |
+
if initialize_weights:
|
214 |
+
init_weights(net, init_type, init_gain=init_gain, debug=debug)
|
215 |
+
return net
|
216 |
+
|
217 |
+
|
218 |
+
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal',
|
219 |
+
init_gain=0.02, no_antialias=False, no_antialias_up=False, gpu_ids=[], opt=None):
|
220 |
+
"""Create a generator
|
221 |
+
|
222 |
+
Parameters:
|
223 |
+
input_nc (int) -- the number of channels in input images
|
224 |
+
output_nc (int) -- the number of channels in output images
|
225 |
+
ngf (int) -- the number of filters in the last conv layer
|
226 |
+
netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
|
227 |
+
norm (str) -- the name of normalization layers used in the network: batch | instance | none
|
228 |
+
use_dropout (bool) -- if use dropout layers.
|
229 |
+
init_type (str) -- the name of our initialization method.
|
230 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
231 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
232 |
+
|
233 |
+
Returns a generator
|
234 |
+
|
235 |
+
Our current implementation provides two types of generators:
|
236 |
+
U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
|
237 |
+
The original U-Net paper: https://arxiv.org/abs/1505.04597
|
238 |
+
|
239 |
+
Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
|
240 |
+
Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
|
241 |
+
We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
|
242 |
+
|
243 |
+
|
244 |
+
The generator has been initialized by <init_net>. It uses RELU for non-linearity.
|
245 |
+
"""
|
246 |
+
net = None
|
247 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
248 |
+
|
249 |
+
if netG == 'resnet_9blocks':
|
250 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt)
|
251 |
+
elif netG == 'resnet_6blocks':
|
252 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=6, opt=opt)
|
253 |
+
elif netG == 'resnet_4blocks':
|
254 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=4, opt=opt)
|
255 |
+
elif netG == 'unet_128':
|
256 |
+
net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
257 |
+
elif netG == 'unet_256':
|
258 |
+
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
259 |
+
elif netG == 'resnet_cat':
|
260 |
+
n_blocks = 8
|
261 |
+
net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
|
262 |
+
else:
|
263 |
+
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
|
264 |
+
return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG))
|
265 |
+
|
266 |
+
|
267 |
+
def define_F(input_nc, netF, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
|
268 |
+
if netF == 'global_pool':
|
269 |
+
net = PoolingF()
|
270 |
+
elif netF == 'reshape':
|
271 |
+
net = ReshapeF()
|
272 |
+
elif netF == 'sample':
|
273 |
+
net = PatchSampleF(use_mlp=False, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc, opt=opt)
|
274 |
+
elif netF == 'mlp_sample':
|
275 |
+
net = PatchSampleF(use_mlp=True, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc, opt=opt)
|
276 |
+
elif netF == 'strided_conv':
|
277 |
+
net = StridedConvF(init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
|
278 |
+
else:
|
279 |
+
raise NotImplementedError('projection model name [%s] is not recognized' % netF)
|
280 |
+
return init_net(net, init_type, init_gain, gpu_ids)
|
281 |
+
|
282 |
+
|
283 |
+
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
|
284 |
+
"""Create a discriminator
|
285 |
+
|
286 |
+
Parameters:
|
287 |
+
input_nc (int) -- the number of channels in input images
|
288 |
+
ndf (int) -- the number of filters in the first conv layer
|
289 |
+
netD (str) -- the architecture's name: basic | n_layers | pixel
|
290 |
+
n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
|
291 |
+
norm (str) -- the type of normalization layers used in the network.
|
292 |
+
init_type (str) -- the name of the initialization method.
|
293 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
294 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
295 |
+
|
296 |
+
Returns a discriminator
|
297 |
+
|
298 |
+
Our current implementation provides three types of discriminators:
|
299 |
+
[basic]: 'PatchGAN' classifier described in the original pix2pix paper.
|
300 |
+
It can classify whether 70x70 overlapping patches are real or fake.
|
301 |
+
Such a patch-level discriminator architecture has fewer parameters
|
302 |
+
than a full-image discriminator and can work on arbitrarily-sized images
|
303 |
+
in a fully convolutional fashion.
|
304 |
+
|
305 |
+
[n_layers]: With this mode, you cna specify the number of conv layers in the discriminator
|
306 |
+
with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
|
307 |
+
|
308 |
+
[pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
|
309 |
+
It encourages greater color diversity but has no effect on spatial statistics.
|
310 |
+
|
311 |
+
The discriminator has been initialized by <init_net>. It uses Leaky RELU for non-linearity.
|
312 |
+
"""
|
313 |
+
net = None
|
314 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
315 |
+
|
316 |
+
if netD == 'basic': # default PatchGAN classifier
|
317 |
+
net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, no_antialias=no_antialias, opt=opt)
|
318 |
+
elif netD == 'n_layers': # more options
|
319 |
+
net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias, opt=opt)
|
320 |
+
elif netD == 'pixel': # classify if each pixel is real or fake
|
321 |
+
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
|
322 |
+
else:
|
323 |
+
raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
|
324 |
+
return init_net(net, init_type, init_gain, gpu_ids,
|
325 |
+
initialize_weights=('stylegan2' not in netD))
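# Illustrative usage sketch (editor's addition, not part of the original file).
# Building the default 70x70 PatchGAN discriminator described above; `opt` is
# assumed to be a parsed options namespace providing at least `weight_norm`:
#     >>> netD = define_D(input_nc=3, ndf=64, netD='basic', norm='instance',
#     ...                 init_type='xavier', init_gain=0.02, gpu_ids=[], opt=opt)
#     >>> patch_logits = netD(torch.randn(1, 3, 256, 256))  # one logit per patch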
|
326 |
+
|
327 |
+
|
328 |
+
##############################################################################
|
329 |
+
# Classes
|
330 |
+
##############################################################################
|
331 |
+
class GANLoss(nn.Module):
|
332 |
+
"""Define different GAN objectives.
|
333 |
+
|
334 |
+
The GANLoss class abstracts away the need to create the target label tensor
|
335 |
+
that has the same size as the input.
|
336 |
+
"""
|
337 |
+
|
338 |
+
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
|
339 |
+
""" Initialize the GANLoss class.
|
340 |
+
|
341 |
+
Parameters:
|
342 |
+
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
|
343 |
+
target_real_label (float) - - label for a real image
|
344 |
+
target_fake_label (float) - - label for a fake image
|
345 |
+
|
346 |
+
Note: Do not use sigmoid as the last layer of Discriminator.
|
347 |
+
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
|
348 |
+
"""
|
349 |
+
super(GANLoss, self).__init__()
|
350 |
+
self.register_buffer('real_label', torch.tensor(target_real_label))
|
351 |
+
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
352 |
+
self.gan_mode = gan_mode
|
353 |
+
if gan_mode == 'lsgan':
|
354 |
+
self.loss = nn.MSELoss()
|
355 |
+
elif gan_mode == 'vanilla':
|
356 |
+
self.loss = nn.BCEWithLogitsLoss()
|
357 |
+
elif gan_mode in ['wgangp', 'nonsaturating']:
|
358 |
+
self.loss = None
|
359 |
+
else:
|
360 |
+
raise NotImplementedError('gan mode %s not implemented' % gan_mode)
|
361 |
+
|
362 |
+
def get_target_tensor(self, prediction, target_is_real):
|
363 |
+
"""Create label tensors with the same size as the input.
|
364 |
+
|
365 |
+
Parameters:
|
366 |
+
prediction (tensor) - - typically the prediction from a discriminator
|
367 |
+
target_is_real (bool) - - if the ground truth label is for real images or fake images
|
368 |
+
|
369 |
+
Returns:
|
370 |
+
A label tensor filled with the ground truth label, with the same size as the input
|
371 |
+
"""
|
372 |
+
|
373 |
+
if target_is_real:
|
374 |
+
target_tensor = self.real_label
|
375 |
+
else:
|
376 |
+
target_tensor = self.fake_label
|
377 |
+
return target_tensor.expand_as(prediction)
|
378 |
+
|
379 |
+
def __call__(self, prediction, target_is_real):
"""Calculate loss given the discriminator's output and ground truth labels.
|
381 |
+
|
382 |
+
Parameters:
|
383 |
+
prediction (tensor) - - typically the prediction output from a discriminator
|
384 |
+
target_is_real (bool) - - if the ground truth label is for real images or fake images
|
385 |
+
|
386 |
+
Returns:
|
387 |
+
the calculated loss.
|
388 |
+
"""
|
389 |
+
bs = prediction.size(0)
|
390 |
+
if self.gan_mode in ['lsgan', 'vanilla']:
|
391 |
+
target_tensor = self.get_target_tensor(prediction, target_is_real)
|
392 |
+
loss = self.loss(prediction, target_tensor)
|
393 |
+
elif self.gan_mode == 'wgangp':
|
394 |
+
if target_is_real:
|
395 |
+
loss = -prediction.mean()
|
396 |
+
else:
|
397 |
+
loss = prediction.mean()
|
398 |
+
elif self.gan_mode == 'nonsaturating':
|
399 |
+
if target_is_real:
|
400 |
+
loss = F.softplus(-prediction).view(bs, -1).mean(dim=1)
|
401 |
+
else:
|
402 |
+
loss = F.softplus(prediction).view(bs, -1).mean(dim=1)
|
403 |
+
return loss
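# Illustrative usage sketch (editor's addition, not part of the original file).
# A typical least-squares discriminator update with GANLoss; the tensor names
# below (real_B, fake_B) are assumptions for the example:
#     >>> criterion_gan = GANLoss('lsgan')
#     >>> loss_D_real = criterion_gan(netD(real_B), True)
#     >>> loss_D_fake = criterion_gan(netD(fake_B.detach()), False)
#     >>> loss_D = 0.5 * (loss_D_real + loss_D_fake)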
|
404 |
+
|
405 |
+
|
406 |
+
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
|
407 |
+
"""Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
|
408 |
+
|
409 |
+
Arguments:
|
410 |
+
netD (network) -- discriminator network
|
411 |
+
real_data (tensor array) -- real images
|
412 |
+
fake_data (tensor array) -- generated images from the generator
|
413 |
+
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
|
414 |
+
type (str) -- if we mix real and fake data or not [real | fake | mixed].
|
415 |
+
constant (float) -- the constant used in the formula (||gradient||_2 - constant)^2
|
416 |
+
lambda_gp (float) -- weight for this loss
|
417 |
+
|
418 |
+
Returns the gradient penalty loss
|
419 |
+
"""
|
420 |
+
if lambda_gp > 0.0:
|
421 |
+
if type == 'real': # either use real images, fake images, or a linear interpolation of the two
|
422 |
+
interpolatesv = real_data
|
423 |
+
elif type == 'fake':
|
424 |
+
interpolatesv = fake_data
|
425 |
+
elif type == 'mixed':
|
426 |
+
alpha = torch.rand(real_data.shape[0], 1, device=device)
|
427 |
+
alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
|
428 |
+
interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
|
429 |
+
else:
|
430 |
+
raise NotImplementedError('{} not implemented'.format(type))
|
431 |
+
interpolatesv.requires_grad_(True)
|
432 |
+
disc_interpolates = netD(interpolatesv)
|
433 |
+
gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
|
434 |
+
grad_outputs=torch.ones(disc_interpolates.size()).to(device),
|
435 |
+
create_graph=True, retain_graph=True, only_inputs=True)
|
436 |
+
gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
|
437 |
+
gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
|
438 |
+
return gradient_penalty, gradients
|
439 |
+
else:
|
440 |
+
return 0.0, None
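# Illustrative usage sketch (editor's addition, not part of the original file).
# Adding the WGAN-GP term to a discriminator loss; `real_B`/`fake_B` are assumed
# mini-batches already on `device`:
#     >>> gp, _ = cal_gradient_penalty(netD, real_B, fake_B.detach(), device,
#     ...                              type='mixed', constant=1.0, lambda_gp=10.0)
#     >>> loss_D = loss_D_gan + gp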
|
441 |
+
|
442 |
+
|
443 |
+
class Normalize(nn.Module):
|
444 |
+
|
445 |
+
def __init__(self, power=2):
|
446 |
+
super(Normalize, self).__init__()
|
447 |
+
self.power = power
|
448 |
+
|
449 |
+
def forward(self, x, dim=1):
|
450 |
+
# norm = x.pow(self.power).sum(dim, keepdim=True).pow(1. / self.power)
|
451 |
+
# out = x.div(norm + 1e-7)
|
452 |
+
# FDL: to avoid taking the square root of 0, which causes NaNs in the gradient
|
453 |
+
norm = (x + 1e-7).pow(self.power).sum(dim, keepdim=True).pow(1. / self.power)
|
454 |
+
out = x.div(norm)
|
455 |
+
return out
|
456 |
+
|
457 |
+
|
458 |
+
class PoolingF(nn.Module):
|
459 |
+
def __init__(self):
|
460 |
+
super(PoolingF, self).__init__()
|
461 |
+
model = [nn.AdaptiveMaxPool2d(1)]
|
462 |
+
self.model = nn.Sequential(*model)
|
463 |
+
self.l2norm = Normalize(2)
|
464 |
+
|
465 |
+
def forward(self, x):
|
466 |
+
return self.l2norm(self.model(x))
|
467 |
+
|
468 |
+
|
469 |
+
class ReshapeF(nn.Module):
|
470 |
+
def __init__(self):
|
471 |
+
super(ReshapeF, self).__init__()
|
472 |
+
model = [nn.AdaptiveAvgPool2d(4)]
|
473 |
+
self.model = nn.Sequential(*model)
|
474 |
+
self.l2norm = Normalize(2)
|
475 |
+
|
476 |
+
def forward(self, x):
|
477 |
+
x = self.model(x)
|
478 |
+
x_reshape = x.permute(0, 2, 3, 1).flatten(0, 2)
|
479 |
+
return self.l2norm(x_reshape)
|
480 |
+
|
481 |
+
|
482 |
+
class StridedConvF(nn.Module):
|
483 |
+
def __init__(self, init_type='normal', init_gain=0.02, gpu_ids=[]):
|
484 |
+
super().__init__()
|
485 |
+
# self.conv1 = nn.Conv2d(256, 128, 3, stride=2)
|
486 |
+
# self.conv2 = nn.Conv2d(128, 64, 3, stride=1)
|
487 |
+
self.l2_norm = Normalize(2)
|
488 |
+
self.mlps = {}
|
489 |
+
self.moving_averages = {}
|
490 |
+
self.init_type = init_type
|
491 |
+
self.init_gain = init_gain
|
492 |
+
self.gpu_ids = gpu_ids
|
493 |
+
|
494 |
+
def create_mlp(self, x):
|
495 |
+
C, H = x.shape[1], x.shape[2]
|
496 |
+
n_down = int(np.rint(np.log2(H / 32)))
|
497 |
+
mlp = []
|
498 |
+
for i in range(n_down):
|
499 |
+
mlp.append(nn.Conv2d(C, max(C // 2, 64), 3, stride=2))
|
500 |
+
mlp.append(nn.ReLU())
|
501 |
+
C = max(C // 2, 64)
|
502 |
+
mlp.append(nn.Conv2d(C, 64, 3))
|
503 |
+
mlp = nn.Sequential(*mlp)
|
504 |
+
init_net(mlp, self.init_type, self.init_gain, self.gpu_ids)
|
505 |
+
return mlp
|
506 |
+
|
507 |
+
def update_moving_average(self, key, x):
|
508 |
+
if key not in self.moving_averages:
|
509 |
+
self.moving_averages[key] = x.detach()
|
510 |
+
|
511 |
+
self.moving_averages[key] = self.moving_averages[key] * 0.999 + x.detach() * 0.001
|
512 |
+
|
513 |
+
def forward(self, x, use_instance_norm=False):
|
514 |
+
C, H = x.shape[1], x.shape[2]
|
515 |
+
key = '%d_%d' % (C, H)
|
516 |
+
if key not in self.mlps:
|
517 |
+
self.mlps[key] = self.create_mlp(x)
|
518 |
+
self.add_module("child_%s" % key, self.mlps[key])
|
519 |
+
mlp = self.mlps[key]
|
520 |
+
x = mlp(x)
|
521 |
+
self.update_moving_average(key, x)
|
522 |
+
x = x - self.moving_averages[key]
|
523 |
+
if use_instance_norm:
|
524 |
+
x = F.instance_norm(x)
|
525 |
+
return self.l2_norm(x)
|
526 |
+
|
527 |
+
|
528 |
+
class PatchSampleF(nn.Module):
|
529 |
+
def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[], opt=None):
|
530 |
+
# potential issues: currently, we use the same patch_ids for multiple images in the batch
|
531 |
+
super(PatchSampleF, self).__init__()
|
532 |
+
self.l2norm = Normalize(2)
|
533 |
+
self.use_mlp = use_mlp
|
534 |
+
self.nc = nc # hard-coded
|
535 |
+
self.mlp_init = False
|
536 |
+
self.init_type = init_type
|
537 |
+
self.init_gain = init_gain
|
538 |
+
self.gpu_ids = gpu_ids
|
539 |
+
self.opt = opt
|
540 |
+
|
541 |
+
def create_mlp(self, feats):
|
542 |
+
for mlp_id, feat in enumerate(feats):
|
543 |
+
input_nc = feat.shape[1]
|
544 |
+
mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
|
545 |
+
if len(self.gpu_ids) > 0:
|
546 |
+
mlp.cuda()
|
547 |
+
setattr(self, 'mlp_%d' % mlp_id, mlp)
|
548 |
+
init_net(self, self.init_type, self.init_gain, self.gpu_ids)
|
549 |
+
self.mlp_init = True
|
550 |
+
|
551 |
+
def forward(self, feats, num_patches=64, patch_ids=None):
|
552 |
+
return_ids = []
|
553 |
+
return_feats = []
|
554 |
+
if self.use_mlp and not self.mlp_init:
|
555 |
+
self.create_mlp(feats)
|
556 |
+
for feat_id, feat in enumerate(feats):
|
557 |
+
B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
|
558 |
+
feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
|
559 |
+
if num_patches > 0:
|
560 |
+
if patch_ids is not None:
|
561 |
+
patch_id = patch_ids[feat_id]
|
562 |
+
else:
|
563 |
+
# torch.randperm produces cudaErrorIllegalAddress for newer versions of PyTorch. https://github.com/taesungp/contrastive-unpaired-translation/issues/83
|
564 |
+
#patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
|
565 |
+
patch_id = np.random.permutation(feat_reshape.shape[1])
|
566 |
+
patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
|
567 |
+
patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)
|
568 |
+
x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)
|
569 |
+
else:
|
570 |
+
x_sample = feat_reshape.flatten(0, 1)
|
571 |
+
patch_id = []
|
572 |
+
if self.use_mlp:
|
573 |
+
mlp = getattr(self, 'mlp_%d' % feat_id)
|
574 |
+
x_sample = mlp(x_sample)
|
575 |
+
return_ids.append(patch_id)
|
576 |
+
x_sample = self.l2norm(x_sample)
|
577 |
+
if num_patches == 0:
|
578 |
+
x_sample = x_sample.reshape([B, H, W, x_sample.shape[-1]]).permute(0, 3, 1, 2)
|
579 |
+
return_feats.append(x_sample)
|
580 |
+
return return_feats, return_ids
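# Illustrative usage sketch (editor's addition, not part of the original file).
# In contrastive training the patch ids sampled from one feature set are reused
# for the other, so that query and key patches correspond spatially; `feat_k`
# and `feat_q` are assumed to be lists of encoder activations:
#     >>> feat_k_pool, sample_ids = netF(feat_k, num_patches=256, patch_ids=None)
#     >>> feat_q_pool, _          = netF(feat_q, num_patches=256, patch_ids=sample_ids)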
|
581 |
+
|
582 |
+
|
583 |
+
class G_Resnet(nn.Module):
|
584 |
+
def __init__(self, input_nc, output_nc, nz, num_downs, n_res, ngf=64,
|
585 |
+
norm=None, nl_layer=None):
|
586 |
+
super(G_Resnet, self).__init__()
|
587 |
+
n_downsample = num_downs
|
588 |
+
pad_type = 'reflect'
|
589 |
+
self.enc_content = ContentEncoder(n_downsample, n_res, input_nc, ngf, norm, nl_layer, pad_type=pad_type)
|
590 |
+
if nz == 0:
|
591 |
+
self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
|
592 |
+
else:
|
593 |
+
self.dec = Decoder_all(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
|
594 |
+
|
595 |
+
def decode(self, content, style=None):
|
596 |
+
return self.dec(content, style)
|
597 |
+
|
598 |
+
def forward(self, image, style=None, nce_layers=[], encode_only=False):
|
599 |
+
content, feats = self.enc_content(image, nce_layers=nce_layers, encode_only=encode_only)
|
600 |
+
if encode_only:
|
601 |
+
return feats
|
602 |
+
else:
|
603 |
+
images_recon = self.decode(content, style)
|
604 |
+
if len(nce_layers) > 0:
|
605 |
+
return images_recon, feats
|
606 |
+
else:
|
607 |
+
return images_recon
|
608 |
+
|
609 |
+
##################################################################################
|
610 |
+
# Encoder and Decoders
|
611 |
+
##################################################################################
|
612 |
+
|
613 |
+
|
614 |
+
class E_adaIN(nn.Module):
|
615 |
+
def __init__(self, input_nc, output_nc=1, nef=64, n_layers=4,
|
616 |
+
norm=None, nl_layer=None, vae=False):
|
617 |
+
# style encoder
|
618 |
+
super(E_adaIN, self).__init__()
|
619 |
+
self.enc_style = StyleEncoder(n_layers, input_nc, nef, output_nc, norm='none', activ='relu', vae=vae)
|
620 |
+
|
621 |
+
def forward(self, image):
|
622 |
+
style = self.enc_style(image)
|
623 |
+
return style
|
624 |
+
|
625 |
+
|
626 |
+
class StyleEncoder(nn.Module):
|
627 |
+
def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, vae=False):
|
628 |
+
super(StyleEncoder, self).__init__()
|
629 |
+
self.vae = vae
|
630 |
+
self.model = []
|
631 |
+
self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
|
632 |
+
for i in range(2):
|
633 |
+
self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
|
634 |
+
dim *= 2
|
635 |
+
for i in range(n_downsample - 2):
|
636 |
+
self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
|
637 |
+
self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
|
638 |
+
if self.vae:
|
639 |
+
self.fc_mean = nn.Linear(dim, style_dim) # , 1, 1, 0)
|
640 |
+
self.fc_var = nn.Linear(dim, style_dim) # , 1, 1, 0)
|
641 |
+
else:
|
642 |
+
self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
|
643 |
+
|
644 |
+
self.model = nn.Sequential(*self.model)
|
645 |
+
self.output_dim = dim
|
646 |
+
|
647 |
+
def forward(self, x):
|
648 |
+
if self.vae:
|
649 |
+
output = self.model(x)
|
650 |
+
output = output.view(x.size(0), -1)
|
651 |
+
output_mean = self.fc_mean(output)
|
652 |
+
output_var = self.fc_var(output)
|
653 |
+
return output_mean, output_var
|
654 |
+
else:
|
655 |
+
return self.model(x).view(x.size(0), -1)
|
656 |
+
|
657 |
+
|
658 |
+
class ContentEncoder(nn.Module):
|
659 |
+
def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type='zero'):
|
660 |
+
super(ContentEncoder, self).__init__()
|
661 |
+
self.model = []
|
662 |
+
self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
|
663 |
+
# downsampling blocks
|
664 |
+
for i in range(n_downsample):
|
665 |
+
self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
|
666 |
+
dim *= 2
|
667 |
+
# residual blocks
|
668 |
+
self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
|
669 |
+
self.model = nn.Sequential(*self.model)
|
670 |
+
self.output_dim = dim
|
671 |
+
|
672 |
+
def forward(self, x, nce_layers=[], encode_only=False):
|
673 |
+
if len(nce_layers) > 0:
|
674 |
+
feat = x
|
675 |
+
feats = []
|
676 |
+
for layer_id, layer in enumerate(self.model):
|
677 |
+
feat = layer(feat)
|
678 |
+
if layer_id in nce_layers:
|
679 |
+
feats.append(feat)
|
680 |
+
if layer_id == nce_layers[-1] and encode_only:
|
681 |
+
return None, feats
|
682 |
+
return feat, feats
|
683 |
+
else:
|
684 |
+
return self.model(x), None
|
685 |
+
|
686 |
+
|
687 |
+
|
688 |
+
class Decoder_all(nn.Module):
|
689 |
+
def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
|
690 |
+
super(Decoder_all, self).__init__()
|
691 |
+
# AdaIN residual blocks
|
692 |
+
self.resnet_block = ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)
|
693 |
+
self.n_blocks = 0
|
694 |
+
# upsampling blocks
|
695 |
+
for i in range(n_upsample):
|
696 |
+
block = [Upsample2(scale_factor=2), Conv2dBlock(dim + nz, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
|
697 |
+
setattr(self, 'block_{:d}'.format(self.n_blocks), nn.Sequential(*block))
|
698 |
+
self.n_blocks += 1
|
699 |
+
dim //= 2
|
700 |
+
# use reflection padding in the last conv layer
|
701 |
+
setattr(self, 'block_{:d}'.format(self.n_blocks), Conv2dBlock(dim + nz, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect'))
|
702 |
+
self.n_blocks += 1
|
703 |
+
|
704 |
+
def forward(self, x, y=None):
|
705 |
+
if y is not None:
|
706 |
+
output = self.resnet_block(cat_feature(x, y))
|
707 |
+
for n in range(self.n_blocks):
|
708 |
+
block = getattr(self, 'block_{:d}'.format(n))
|
709 |
+
if n > 0:
|
710 |
+
output = block(cat_feature(output, y))
|
711 |
+
else:
|
712 |
+
output = block(output)
|
713 |
+
return output
|
714 |
+
|
715 |
+
|
716 |
+
class Decoder(nn.Module):
|
717 |
+
def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
|
718 |
+
super(Decoder, self).__init__()
|
719 |
+
|
720 |
+
self.model = []
|
721 |
+
# AdaIN residual blocks
|
722 |
+
self.model += [ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)]
|
723 |
+
# upsampling blocks
|
724 |
+
for i in range(n_upsample):
|
725 |
+
if i == 0:
|
726 |
+
input_dim = dim + nz
|
727 |
+
else:
|
728 |
+
input_dim = dim
|
729 |
+
self.model += [Upsample2(scale_factor=2), Conv2dBlock(input_dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
|
730 |
+
dim //= 2
|
731 |
+
# use reflection padding in the last conv layer
|
732 |
+
self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect')]
|
733 |
+
self.model = nn.Sequential(*self.model)
|
734 |
+
|
735 |
+
def forward(self, x, y=None):
|
736 |
+
if y is not None:
|
737 |
+
return self.model(cat_feature(x, y))
|
738 |
+
else:
|
739 |
+
return self.model(x)
|
740 |
+
|
741 |
+
##################################################################################
|
742 |
+
# Sequential Models
|
743 |
+
##################################################################################
|
744 |
+
|
745 |
+
|
746 |
+
class ResBlocks(nn.Module):
|
747 |
+
def __init__(self, num_blocks, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
|
748 |
+
super(ResBlocks, self).__init__()
|
749 |
+
self.model = []
|
750 |
+
for i in range(num_blocks):
|
751 |
+
self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type, nz=nz)]
|
752 |
+
self.model = nn.Sequential(*self.model)
|
753 |
+
|
754 |
+
def forward(self, x):
|
755 |
+
return self.model(x)
|
756 |
+
|
757 |
+
|
758 |
+
##################################################################################
|
759 |
+
# Basic Blocks
|
760 |
+
##################################################################################
|
761 |
+
def cat_feature(x, y):
|
762 |
+
y_expand = y.view(y.size(0), y.size(1), 1, 1).expand(
|
763 |
+
y.size(0), y.size(1), x.size(2), x.size(3))
|
764 |
+
x_cat = torch.cat([x, y_expand], 1)
|
765 |
+
return x_cat
|
766 |
+
|
767 |
+
|
768 |
+
class ResBlock(nn.Module):
|
769 |
+
def __init__(self, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
|
770 |
+
super(ResBlock, self).__init__()
|
771 |
+
|
772 |
+
model = []
|
773 |
+
model += [Conv2dBlock(dim + nz, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
|
774 |
+
model += [Conv2dBlock(dim, dim + nz, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
|
775 |
+
self.model = nn.Sequential(*model)
|
776 |
+
|
777 |
+
def forward(self, x):
|
778 |
+
residual = x
|
779 |
+
out = self.model(x)
|
780 |
+
out += residual
|
781 |
+
return out
|
782 |
+
|
783 |
+
|
784 |
+
class Conv2dBlock(nn.Module):
|
785 |
+
def __init__(self, input_dim, output_dim, kernel_size, stride,
|
786 |
+
padding=0, norm='none', activation='relu', pad_type='zero'):
|
787 |
+
super(Conv2dBlock, self).__init__()
|
788 |
+
self.use_bias = True
|
789 |
+
# initialize padding
|
790 |
+
if pad_type == 'reflect':
|
791 |
+
self.pad = nn.ReflectionPad2d(padding)
|
792 |
+
elif pad_type == 'zero':
|
793 |
+
self.pad = nn.ZeroPad2d(padding)
|
794 |
+
else:
|
795 |
+
assert 0, "Unsupported padding type: {}".format(pad_type)
|
796 |
+
|
797 |
+
# initialize normalization
|
798 |
+
norm_dim = output_dim
|
799 |
+
if norm == 'batch':
|
800 |
+
self.norm = nn.BatchNorm2d(norm_dim)
|
801 |
+
elif norm == 'inst':
|
802 |
+
self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=False)
|
803 |
+
elif norm == 'ln':
|
804 |
+
self.norm = LayerNorm(norm_dim)
|
805 |
+
elif norm == 'none':
|
806 |
+
self.norm = None
|
807 |
+
else:
|
808 |
+
assert 0, "Unsupported normalization: {}".format(norm)
|
809 |
+
|
810 |
+
# initialize activation
|
811 |
+
if activation == 'relu':
|
812 |
+
self.activation = nn.ReLU(inplace=True)
|
813 |
+
elif activation == 'lrelu':
|
814 |
+
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
815 |
+
elif activation == 'prelu':
|
816 |
+
self.activation = nn.PReLU()
|
817 |
+
elif activation == 'selu':
|
818 |
+
self.activation = nn.SELU(inplace=True)
|
819 |
+
elif activation == 'tanh':
|
820 |
+
self.activation = nn.Tanh()
|
821 |
+
elif activation == 'none':
|
822 |
+
self.activation = None
|
823 |
+
else:
|
824 |
+
assert 0, "Unsupported activation: {}".format(activation)
|
825 |
+
|
826 |
+
# initialize convolution
|
827 |
+
self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
|
828 |
+
|
829 |
+
def forward(self, x):
|
830 |
+
x = self.conv(self.pad(x))
|
831 |
+
if self.norm:
|
832 |
+
x = self.norm(x)
|
833 |
+
if self.activation:
|
834 |
+
x = self.activation(x)
|
835 |
+
return x
|
836 |
+
|
837 |
+
|
838 |
+
class LinearBlock(nn.Module):
|
839 |
+
def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
|
840 |
+
super(LinearBlock, self).__init__()
|
841 |
+
use_bias = True
|
842 |
+
# initialize fully connected layer
|
843 |
+
self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
|
844 |
+
|
845 |
+
# initialize normalization
|
846 |
+
norm_dim = output_dim
|
847 |
+
if norm == 'batch':
|
848 |
+
self.norm = nn.BatchNorm1d(norm_dim)
|
849 |
+
elif norm == 'inst':
|
850 |
+
self.norm = nn.InstanceNorm1d(norm_dim)
|
851 |
+
elif norm == 'ln':
|
852 |
+
self.norm = LayerNorm(norm_dim)
|
853 |
+
elif norm == 'none':
|
854 |
+
self.norm = None
|
855 |
+
else:
|
856 |
+
assert 0, "Unsupported normalization: {}".format(norm)
|
857 |
+
|
858 |
+
# initialize activation
|
859 |
+
if activation == 'relu':
|
860 |
+
self.activation = nn.ReLU(inplace=True)
|
861 |
+
elif activation == 'lrelu':
|
862 |
+
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
863 |
+
elif activation == 'prelu':
|
864 |
+
self.activation = nn.PReLU()
|
865 |
+
elif activation == 'selu':
|
866 |
+
self.activation = nn.SELU(inplace=True)
|
867 |
+
elif activation == 'tanh':
|
868 |
+
self.activation = nn.Tanh()
|
869 |
+
elif activation == 'none':
|
870 |
+
self.activation = None
|
871 |
+
else:
|
872 |
+
assert 0, "Unsupported activation: {}".format(activation)
|
873 |
+
|
874 |
+
def forward(self, x):
|
875 |
+
out = self.fc(x)
|
876 |
+
if self.norm:
|
877 |
+
out = self.norm(out)
|
878 |
+
if self.activation:
|
879 |
+
out = self.activation(out)
|
880 |
+
return out
|
881 |
+
|
882 |
+
##################################################################################
|
883 |
+
# Normalization layers
|
884 |
+
##################################################################################
|
885 |
+
|
886 |
+
|
887 |
+
class LayerNorm(nn.Module):
|
888 |
+
def __init__(self, num_features, eps=1e-5, affine=True):
|
889 |
+
super(LayerNorm, self).__init__()
|
890 |
+
self.num_features = num_features
|
891 |
+
self.affine = affine
|
892 |
+
self.eps = eps
|
893 |
+
|
894 |
+
if self.affine:
|
895 |
+
self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
|
896 |
+
self.beta = nn.Parameter(torch.zeros(num_features))
|
897 |
+
|
898 |
+
def forward(self, x):
|
899 |
+
shape = [-1] + [1] * (x.dim() - 1)
|
900 |
+
mean = x.view(x.size(0), -1).mean(1).view(*shape)
|
901 |
+
std = x.view(x.size(0), -1).std(1).view(*shape)
|
902 |
+
x = (x - mean) / (std + self.eps)
|
903 |
+
|
904 |
+
if self.affine:
|
905 |
+
shape = [1, -1] + [1] * (x.dim() - 2)
|
906 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
907 |
+
return x
|
908 |
+
|
909 |
+
|
910 |
+
class ResnetGenerator(nn.Module):
|
911 |
+
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
|
912 |
+
|
913 |
+
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
|
914 |
+
"""
|
915 |
+
|
916 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
|
917 |
+
"""Construct a Resnet-based generator
|
918 |
+
|
919 |
+
Parameters:
|
920 |
+
input_nc (int) -- the number of channels in input images
|
921 |
+
output_nc (int) -- the number of channels in output images
|
922 |
+
ngf (int) -- the number of filters in the last conv layer
|
923 |
+
norm_layer -- normalization layer
|
924 |
+
use_dropout (bool) -- if use dropout layers
|
925 |
+
n_blocks (int) -- the number of ResNet blocks
|
926 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
927 |
+
"""
|
928 |
+
assert(n_blocks >= 0)
|
929 |
+
super(ResnetGenerator, self).__init__()
|
930 |
+
self.opt = opt
|
931 |
+
if type(norm_layer) == functools.partial:
|
932 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
933 |
+
else:
|
934 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
935 |
+
|
936 |
+
if opt.weight_norm == 'spectral':
|
937 |
+
weight_norm = nn.utils.spectral_norm
|
938 |
+
else:
|
939 |
+
def weight_norm(x): return x
|
940 |
+
|
941 |
+
model = [nn.ReflectionPad2d(3),
|
942 |
+
weight_norm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias)),
|
943 |
+
norm_layer(ngf),
|
944 |
+
nn.ReLU(True)]
|
945 |
+
|
946 |
+
n_downsampling = getattr(opt, 'n_downsampling', 2)
|
947 |
+
for i in range(n_downsampling): # add downsampling layers
|
948 |
+
mult = 2 ** i
|
949 |
+
if(no_antialias):
|
950 |
+
model += [weight_norm(nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias)),
|
951 |
+
norm_layer(ngf * mult * 2),
|
952 |
+
nn.ReLU(True)]
|
953 |
+
else:
|
954 |
+
model += [weight_norm(nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)),
|
955 |
+
norm_layer(ngf * mult * 2),
|
956 |
+
nn.ReLU(True),
|
957 |
+
Downsample(ngf * mult * 2)]
|
958 |
+
|
959 |
+
mult = 2 ** n_downsampling
|
960 |
+
for i in range(n_blocks): # add ResNet blocks
|
961 |
+
extra = None
|
962 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, opt=opt)]
|
963 |
+
|
964 |
+
for i in range(n_downsampling): # add upsampling layers
|
965 |
+
mult = 2 ** (n_downsampling - i)
|
966 |
+
if no_antialias_up:
|
967 |
+
model += [weight_norm(nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
968 |
+
kernel_size=3, stride=2,
|
969 |
+
padding=1, output_padding=1,
|
970 |
+
bias=use_bias)),
|
971 |
+
norm_layer(int(ngf * mult / 2)),
|
972 |
+
nn.ReLU(True)]
|
973 |
+
else:
|
974 |
+
model += [Upsample(ngf * mult),
|
975 |
+
weight_norm(nn.Conv2d(ngf * mult, int(ngf * mult / 2),
|
976 |
+
kernel_size=3, stride=1,
|
977 |
+
padding=1, # output_padding=1,
|
978 |
+
bias=use_bias)),
|
979 |
+
norm_layer(int(ngf * mult / 2)),
|
980 |
+
nn.ReLU(True)]
|
981 |
+
model += [nn.ReflectionPad2d(3)]
|
982 |
+
model += [weight_norm(nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0))]
|
983 |
+
model += [nn.Tanh()]
|
984 |
+
|
985 |
+
self.model = nn.Sequential(*model)
|
986 |
+
|
987 |
+
def forward(self, input, layers=[], encode_only=False):
|
988 |
+
if -1 in layers:
|
989 |
+
layers.append(len(self.model))
|
990 |
+
if len(layers) > 0:
|
991 |
+
feat = input
|
992 |
+
feats = []
|
993 |
+
for layer_id, layer in enumerate(self.model):
|
994 |
+
# print(layer_id, layer)
|
995 |
+
feat = layer(feat)
|
996 |
+
if layer_id in layers:
|
997 |
+
# print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
998 |
+
feats.append(feat)
|
999 |
+
else:
|
1000 |
+
# print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
1001 |
+
pass
|
1002 |
+
if layer_id == layers[-1] and encode_only:
|
1003 |
+
# print('encoder only return features')
|
1004 |
+
return feats # return intermediate features alone; stop in the last layers
|
1005 |
+
|
1006 |
+
return feat, feats # return both output and intermediate features
|
1007 |
+
else:
|
1008 |
+
"""Standard forward"""
|
1009 |
+
fake = self.model(input)
|
1010 |
+
return fake
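# Illustrative usage sketch (editor's addition, not part of the original file).
# The generator also serves as the encoder for patch-wise contrastive losses:
# with `encode_only=True` it returns intermediate activations instead of an
# image. The layer indices below are purely illustrative:
#     >>> fake_B = netG(real_A)                                  # full translation
#     >>> feats  = netG(real_A, layers=[0, 4, 8, 12, 16], encode_only=True)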
|
1011 |
+
|
1012 |
+
|
1013 |
+
class ResnetDecoder(nn.Module):
|
1014 |
+
"""Resnet-based decoder that consists of a few Resnet blocks + a few upsampling operations.
|
1015 |
+
"""
|
1016 |
+
|
1017 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
|
1018 |
+
"""Construct a Resnet-based decoder
|
1019 |
+
|
1020 |
+
Parameters:
|
1021 |
+
input_nc (int) -- the number of channels in input images
|
1022 |
+
output_nc (int) -- the number of channels in output images
|
1023 |
+
ngf (int) -- the number of filters in the last conv layer
|
1024 |
+
norm_layer -- normalization layer
|
1025 |
+
use_dropout (bool) -- if use dropout layers
|
1026 |
+
n_blocks (int) -- the number of ResNet blocks
|
1027 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
1028 |
+
"""
|
1029 |
+
assert(n_blocks >= 0)
|
1030 |
+
super(ResnetDecoder, self).__init__()
|
1031 |
+
if type(norm_layer) == functools.partial:
|
1032 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1033 |
+
else:
|
1034 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1035 |
+
model = []
|
1036 |
+
n_downsampling = 2
|
1037 |
+
mult = 2 ** n_downsampling
|
1038 |
+
for i in range(n_blocks): # add ResNet blocks
|
1039 |
+
|
1040 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
1041 |
+
|
1042 |
+
for i in range(n_downsampling): # add upsampling layers
|
1043 |
+
mult = 2 ** (n_downsampling - i)
|
1044 |
+
if(no_antialias):
|
1045 |
+
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
1046 |
+
kernel_size=3, stride=2,
|
1047 |
+
padding=1, output_padding=1,
|
1048 |
+
bias=use_bias),
|
1049 |
+
norm_layer(int(ngf * mult / 2)),
|
1050 |
+
nn.ReLU(True)]
|
1051 |
+
else:
|
1052 |
+
model += [Upsample(ngf * mult),
|
1053 |
+
nn.Conv2d(ngf * mult, int(ngf * mult / 2),
|
1054 |
+
kernel_size=3, stride=1,
|
1055 |
+
padding=1,
|
1056 |
+
bias=use_bias),
|
1057 |
+
norm_layer(int(ngf * mult / 2)),
|
1058 |
+
nn.ReLU(True)]
|
1059 |
+
model += [nn.ReflectionPad2d(3)]
|
1060 |
+
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
1061 |
+
model += [nn.Tanh()]
|
1062 |
+
|
1063 |
+
self.model = nn.Sequential(*model)
|
1064 |
+
|
1065 |
+
def forward(self, input):
|
1066 |
+
"""Standard forward"""
|
1067 |
+
return self.model(input)
|
1068 |
+
|
1069 |
+
|
1070 |
+
class ResnetEncoder(nn.Module):
|
1071 |
+
"""Resnet-based encoder that consists of a few downsampling + several Resnet blocks
|
1072 |
+
"""
|
1073 |
+
|
1074 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
|
1075 |
+
"""Construct a Resnet-based encoder
|
1076 |
+
|
1077 |
+
Parameters:
|
1078 |
+
input_nc (int) -- the number of channels in input images
|
1079 |
+
output_nc (int) -- the number of channels in output images
|
1080 |
+
ngf (int) -- the number of filters in the last conv layer
|
1081 |
+
norm_layer -- normalization layer
|
1082 |
+
use_dropout (bool) -- if use dropout layers
|
1083 |
+
n_blocks (int) -- the number of ResNet blocks
|
1084 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
1085 |
+
"""
|
1086 |
+
assert(n_blocks >= 0)
|
1087 |
+
super(ResnetEncoder, self).__init__()
|
1088 |
+
if type(norm_layer) == functools.partial:
|
1089 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1090 |
+
else:
|
1091 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1092 |
+
|
1093 |
+
model = [nn.ReflectionPad2d(3),
|
1094 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
|
1095 |
+
norm_layer(ngf),
|
1096 |
+
nn.ReLU(True)]
|
1097 |
+
|
1098 |
+
n_downsampling = 2
|
1099 |
+
for i in range(n_downsampling): # add downsampling layers
|
1100 |
+
mult = 2 ** i
|
1101 |
+
if(no_antialias):
|
1102 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
|
1103 |
+
norm_layer(ngf * mult * 2),
|
1104 |
+
nn.ReLU(True)]
|
1105 |
+
else:
|
1106 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
|
1107 |
+
norm_layer(ngf * mult * 2),
|
1108 |
+
nn.ReLU(True),
|
1109 |
+
Downsample(ngf * mult * 2)]
|
1110 |
+
|
1111 |
+
mult = 2 ** n_downsampling
|
1112 |
+
for i in range(n_blocks): # add ResNet blocks
|
1113 |
+
|
1114 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
1115 |
+
|
1116 |
+
self.model = nn.Sequential(*model)
|
1117 |
+
|
1118 |
+
def forward(self, input):
|
1119 |
+
"""Standard forward"""
|
1120 |
+
return self.model(input)
|
1121 |
+
|
1122 |
+
|
1123 |
+
class ResnetBlock(nn.Module):
|
1124 |
+
"""Define a Resnet block"""
|
1125 |
+
|
1126 |
+
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, opt=None):
|
1127 |
+
"""Initialize the Resnet block
|
1128 |
+
|
1129 |
+
A resnet block is a conv block with skip connections
|
1130 |
+
We construct a conv block with build_conv_block function,
|
1131 |
+
and implement skip connections in <forward> function.
|
1132 |
+
Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
|
1133 |
+
"""
|
1134 |
+
super(ResnetBlock, self).__init__()
|
1135 |
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, opt)
|
1136 |
+
|
1137 |
+
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, opt=None):
|
1138 |
+
"""Construct a convolutional block.
|
1139 |
+
|
1140 |
+
Parameters:
|
1141 |
+
dim (int) -- the number of channels in the conv layer.
|
1142 |
+
padding_type (str) -- the name of padding layer: reflect | replicate | zero
|
1143 |
+
norm_layer -- normalization layer
|
1144 |
+
use_dropout (bool) -- if use dropout layers.
|
1145 |
+
use_bias (bool) -- if the conv layer uses bias or not
|
1146 |
+
|
1147 |
+
Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
|
1148 |
+
"""
|
1149 |
+
conv_block = []
|
1150 |
+
p = 0
|
1151 |
+
if padding_type == 'reflect':
|
1152 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
1153 |
+
elif padding_type == 'replicate':
|
1154 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
1155 |
+
elif padding_type == 'zero':
|
1156 |
+
p = 1
|
1157 |
+
else:
|
1158 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
1159 |
+
|
1160 |
+
if opt.weight_norm == 'spectral':
|
1161 |
+
weight_norm = nn.utils.spectral_norm
|
1162 |
+
else:
|
1163 |
+
def weight_norm(x): return x
|
1164 |
+
|
1165 |
+
conv_block += [weight_norm(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)), norm_layer(dim), nn.ReLU(True)]
|
1166 |
+
if use_dropout:
|
1167 |
+
conv_block += [nn.Dropout(0.5)]
|
1168 |
+
|
1169 |
+
p = 0
|
1170 |
+
if padding_type == 'reflect':
|
1171 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
1172 |
+
elif padding_type == 'replicate':
|
1173 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
1174 |
+
elif padding_type == 'zero':
|
1175 |
+
p = 1
|
1176 |
+
else:
|
1177 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
1178 |
+
conv_block += [weight_norm(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)), norm_layer(dim)]
|
1179 |
+
|
1180 |
+
return nn.Sequential(*conv_block)
|
1181 |
+
|
1182 |
+
def forward(self, x):
|
1183 |
+
"""Forward function (with skip connections)"""
|
1184 |
+
out = x + self.conv_block(x) # add skip connections
|
1185 |
+
return out
|
1186 |
+
|
1187 |
+
|
1188 |
+
class UnetGenerator(nn.Module):
|
1189 |
+
"""Create a Unet-based generator"""
|
1190 |
+
|
1191 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
1192 |
+
"""Construct a Unet generator
|
1193 |
+
Parameters:
|
1194 |
+
input_nc (int) -- the number of channels in input images
|
1195 |
+
output_nc (int) -- the number of channels in output images
|
1196 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, if num_downs == 7,
|
1197 |
+
an image of size 128x128 will become of size 1x1 at the bottleneck
|
1198 |
+
ngf (int) -- the number of filters in the last conv layer
|
1199 |
+
norm_layer -- normalization layer
|
1200 |
+
|
1201 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
1202 |
+
It is a recursive process.
|
1203 |
+
"""
|
1204 |
+
super(UnetGenerator, self).__init__()
|
1205 |
+
# construct unet structure
|
1206 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
|
1207 |
+
for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
|
1208 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
1209 |
+
# gradually reduce the number of filters from ngf * 8 to ngf
|
1210 |
+
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
1211 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
1212 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
1213 |
+
self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
|
1214 |
+
|
1215 |
+
def forward(self, input):
|
1216 |
+
"""Standard forward"""
|
1217 |
+
return self.model(input)
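# Illustrative usage sketch (editor's addition, not part of the original file).
# With num_downs=8, a 256x256 input is reduced to 1x1 at the bottleneck and then
# upsampled back to the input resolution:
#     >>> netG = UnetGenerator(input_nc=3, output_nc=3, num_downs=8, ngf=64)
#     >>> out  = netG(torch.randn(1, 3, 256, 256))   # -> (1, 3, 256, 256)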
|
1218 |
+
|
1219 |
+
|
1220 |
+
class UnetSkipConnectionBlock(nn.Module):
|
1221 |
+
"""Defines the Unet submodule with skip connection.
|
1222 |
+
X -------------------identity----------------------
|
1223 |
+
|-- downsampling -- |submodule| -- upsampling --|
|
1224 |
+
"""
|
1225 |
+
|
1226 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
1227 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
1228 |
+
"""Construct a Unet submodule with skip connections.
|
1229 |
+
|
1230 |
+
Parameters:
|
1231 |
+
outer_nc (int) -- the number of filters in the outer conv layer
|
1232 |
+
inner_nc (int) -- the number of filters in the inner conv layer
|
1233 |
+
input_nc (int) -- the number of channels in input images/features
|
1234 |
+
submodule (UnetSkipConnectionBlock) -- previously defined submodules
|
1235 |
+
outermost (bool) -- if this module is the outermost module
|
1236 |
+
innermost (bool) -- if this module is the innermost module
|
1237 |
+
norm_layer -- normalization layer
|
1238 |
+
use_dropout (bool) -- if use dropout layers.
|
1239 |
+
"""
|
1240 |
+
super(UnetSkipConnectionBlock, self).__init__()
|
1241 |
+
self.outermost = outermost
|
1242 |
+
if type(norm_layer) == functools.partial:
|
1243 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1244 |
+
else:
|
1245 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1246 |
+
if input_nc is None:
|
1247 |
+
input_nc = outer_nc
|
1248 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
1249 |
+
stride=2, padding=1, bias=use_bias)
|
1250 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
1251 |
+
downnorm = norm_layer(inner_nc)
|
1252 |
+
uprelu = nn.ReLU(True)
|
1253 |
+
upnorm = norm_layer(outer_nc)
|
1254 |
+
|
1255 |
+
if outermost:
|
1256 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
1257 |
+
kernel_size=4, stride=2,
|
1258 |
+
padding=1)
|
1259 |
+
down = [downconv]
|
1260 |
+
up = [uprelu, upconv, nn.Tanh()]
|
1261 |
+
model = down + [submodule] + up
|
1262 |
+
elif innermost:
|
1263 |
+
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
|
1264 |
+
kernel_size=4, stride=2,
|
1265 |
+
padding=1, bias=use_bias)
|
1266 |
+
down = [downrelu, downconv]
|
1267 |
+
up = [uprelu, upconv, upnorm]
|
1268 |
+
model = down + up
|
1269 |
+
else:
|
1270 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
1271 |
+
kernel_size=4, stride=2,
|
1272 |
+
padding=1, bias=use_bias)
|
1273 |
+
down = [downrelu, downconv, downnorm]
|
1274 |
+
up = [uprelu, upconv, upnorm]
|
1275 |
+
|
1276 |
+
if use_dropout:
|
1277 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
1278 |
+
else:
|
1279 |
+
model = down + [submodule] + up
|
1280 |
+
|
1281 |
+
self.model = nn.Sequential(*model)
|
1282 |
+
|
1283 |
+
def forward(self, x):
|
1284 |
+
if self.outermost:
|
1285 |
+
return self.model(x)
|
1286 |
+
else: # add skip connections
|
1287 |
+
return torch.cat([x, self.model(x)], 1)
|
1288 |
+
|
1289 |
+
|
1290 |
+
class NLayerDiscriminator(nn.Module):
|
1291 |
+
"""Defines a PatchGAN discriminator"""
|
1292 |
+
|
1293 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False, opt=None):
|
1294 |
+
"""Construct a PatchGAN discriminator
|
1295 |
+
|
1296 |
+
Parameters:
|
1297 |
+
input_nc (int) -- the number of channels in input images
|
1298 |
+
ndf (int) -- the number of filters in the last conv layer
|
1299 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
1300 |
+
norm_layer -- normalization layer
|
1301 |
+
"""
|
1302 |
+
super(NLayerDiscriminator, self).__init__()
|
1303 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
1304 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1305 |
+
else:
|
1306 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1307 |
+
|
1308 |
+
if opt.weight_norm == 'spectral':
|
1309 |
+
weight_norm = nn.utils.spectral_norm
|
1310 |
+
else:
|
1311 |
+
def weight_norm(x): return x
|
1312 |
+
|
1313 |
+
kw = 4
|
1314 |
+
padw = 1
|
1315 |
+
if(no_antialias):
|
1316 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
1317 |
+
else:
|
1318 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True), Downsample(ndf)]
|
1319 |
+
nf_mult = 1
|
1320 |
+
nf_mult_prev = 1
|
1321 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
1322 |
+
nf_mult_prev = nf_mult
|
1323 |
+
nf_mult = min(2 ** n, 8)
|
1324 |
+
if(no_antialias):
|
1325 |
+
sequence += [
|
1326 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
1327 |
+
norm_layer(ndf * nf_mult),
|
1328 |
+
nn.LeakyReLU(0.2, True)
|
1329 |
+
]
|
1330 |
+
else:
|
1331 |
+
sequence += [
|
1332 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
1333 |
+
norm_layer(ndf * nf_mult),
|
1334 |
+
nn.LeakyReLU(0.2, True),
|
1335 |
+
Downsample(ndf * nf_mult)
|
1336 |
+
]
|
1337 |
+
|
1338 |
+
nf_mult_prev = nf_mult
|
1339 |
+
nf_mult = min(2 ** n_layers, 8)
|
1340 |
+
sequence += [
|
1341 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
1342 |
+
norm_layer(ndf * nf_mult),
|
1343 |
+
nn.LeakyReLU(0.2, True)
|
1344 |
+
]
|
1345 |
+
|
1346 |
+
for i, layer in enumerate(sequence):
|
1347 |
+
if isinstance(layer, nn.Conv2d):
|
1348 |
+
sequence[i] = weight_norm(layer)
|
1349 |
+
|
1350 |
+
self.enc = nn.Sequential(*sequence)
|
1351 |
+
# output 1 channel prediction map
|
1352 |
+
self.final_conv = weight_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw))
|
1353 |
+
|
1354 |
+
|
1355 |
+
def forward(self, input, labels=None):
|
1356 |
+
"""Standard forward."""
|
1357 |
+
final_ft = self.enc(input)
|
1358 |
+
dout = self.final_conv(final_ft)
|
1359 |
+
return dout
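# Editor's note (assumption, not part of the original file): with the default
# n_layers=3 the output is a spatial map of patch logits rather than a single
# scalar, roughly 30x30 for a 256x256 input; the exact size depends on whether
# the antialiased downsampling path is used. `opt` is an assumed options namespace.
#     >>> netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, opt=opt)
#     >>> patch_logits = netD(torch.randn(1, 3, 256, 256))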
|
1360 |
+
|
1361 |
+
|
1362 |
+
class PixelDiscriminator(nn.Module):
|
1363 |
+
"""Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
|
1364 |
+
|
1365 |
+
def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
|
1366 |
+
"""Construct a 1x1 PatchGAN discriminator
|
1367 |
+
|
1368 |
+
Parameters:
|
1369 |
+
input_nc (int) -- the number of channels in input images
|
1370 |
+
ndf (int) -- the number of filters in the last conv layer
|
1371 |
+
norm_layer -- normalization layer
|
1372 |
+
"""
|
1373 |
+
super(PixelDiscriminator, self).__init__()
|
1374 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
1375 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1376 |
+
else:
|
1377 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1378 |
+
|
1379 |
+
self.net = [
|
1380 |
+
nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
|
1381 |
+
nn.LeakyReLU(0.2, True),
|
1382 |
+
nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
|
1383 |
+
norm_layer(ndf * 2),
|
1384 |
+
nn.LeakyReLU(0.2, True),
|
1385 |
+
nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
|
1386 |
+
|
1387 |
+
self.net = nn.Sequential(*self.net)
|
1388 |
+
|
1389 |
+
def forward(self, input):
|
1390 |
+
"""Standard forward."""
|
1391 |
+
return self.net(input)
|
1392 |
+
|
1393 |
+
|
1394 |
+
class PatchDiscriminator(NLayerDiscriminator):
|
1395 |
+
"""Defines a PatchGAN discriminator"""
|
1396 |
+
|
1397 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
|
1398 |
+
super().__init__(input_nc, ndf, 2, norm_layer, no_antialias)
|
1399 |
+
|
1400 |
+
def forward(self, input):
|
1401 |
+
B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
|
1402 |
+
size = 16
|
1403 |
+
Y = H // size
|
1404 |
+
X = W // size
|
1405 |
+
input = input.view(B, C, Y, size, X, size)
|
1406 |
+
input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
|
1407 |
+
return super().forward(input)
|
1408 |
+
|
1409 |
+
|
1410 |
+
class GroupedChannelNorm(nn.Module):
|
1411 |
+
def __init__(self, num_groups):
|
1412 |
+
super().__init__()
|
1413 |
+
self.num_groups = num_groups
|
1414 |
+
|
1415 |
+
def forward(self, x):
|
1416 |
+
shape = list(x.shape)
|
1417 |
+
new_shape = [shape[0], self.num_groups, shape[1] // self.num_groups] + shape[2:]
|
1418 |
+
x = x.view(*new_shape)
|
1419 |
+
mean = x.mean(dim=2, keepdim=True)
|
1420 |
+
std = x.std(dim=2, keepdim=True)
|
1421 |
+
x_norm = (x - mean) / (std + 1e-7)
|
1422 |
+
return x_norm.view(*shape)
|
asp/models/patchnce.py
ADDED
@@ -0,0 +1,55 @@
1 |
+
from packaging import version
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class PatchNCELoss(nn.Module):
|
7 |
+
def __init__(self, opt):
|
8 |
+
super().__init__()
|
9 |
+
self.opt = opt
|
10 |
+
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
|
11 |
+
self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
|
12 |
+
|
13 |
+
def forward(self, feat_q, feat_k):
|
14 |
+
num_patches = feat_q.shape[0]
|
15 |
+
dim = feat_q.shape[1]
|
16 |
+
feat_k = feat_k.detach()
|
17 |
+
|
18 |
+
# pos logit
|
19 |
+
l_pos = torch.bmm(
|
20 |
+
feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
|
21 |
+
l_pos = l_pos.view(num_patches, 1)
|
22 |
+
|
23 |
+
# neg logit
|
24 |
+
|
25 |
+
# Should the negatives from the other samples of a minibatch be utilized?
|
26 |
+
# In CUT and FastCUT, we found that it's best to only include negatives
|
27 |
+
# from the same image. Therefore, we set
|
28 |
+
# --nce_includes_all_negatives_from_minibatch as False
|
29 |
+
# However, for single-image translation, the minibatch consists of
|
30 |
+
# crops from the "same" high-resolution image.
|
31 |
+
# Therefore, we will include the negatives from the entire minibatch.
|
32 |
+
if self.opt.nce_includes_all_negatives_from_minibatch:
|
33 |
+
# reshape features as if they are all negatives of minibatch of size 1.
|
34 |
+
batch_dim_for_bmm = 1
|
35 |
+
else:
|
36 |
+
batch_dim_for_bmm = self.opt.batch_size
|
37 |
+
|
38 |
+
# reshape features to batch size
|
39 |
+
feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
|
40 |
+
feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
|
41 |
+
npatches = feat_q.size(1)
|
42 |
+
l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
|
43 |
+
|
44 |
+
# diagonal entries are similarity between same features, and hence meaningless.
|
45 |
+
# just fill the diagonal with very small number, which is exp(-10) and almost zero
|
46 |
+
diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
|
47 |
+
l_neg_curbatch.masked_fill_(diagonal, -10.0)
|
48 |
+
l_neg = l_neg_curbatch.view(-1, npatches)
|
49 |
+
|
50 |
+
out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
|
51 |
+
|
52 |
+
loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
|
53 |
+
device=feat_q.device))
|
54 |
+
|
55 |
+
return loss
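# Illustrative usage sketch (editor's addition, not part of the original file).
# feat_q and feat_k are expected to be L2-normalized patch features of shape
# (batch_size * num_patches, dim), with matching rows taken from the same
# spatial locations; `opt` must provide nce_T, batch_size and
# nce_includes_all_negatives_from_minibatch:
#     >>> criterion_nce = PatchNCELoss(opt)
#     >>> loss = criterion_nce(feat_q, feat_k).mean()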
|
asp/options/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
|
asp/options/base_options.py
ADDED
@@ -0,0 +1,167 @@
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from util import util
|
4 |
+
import torch
|
5 |
+
import models
|
6 |
+
import data
|
7 |
+
|
8 |
+
|
9 |
+
class BaseOptions():
|
10 |
+
"""This class defines options used during both training and test time.
|
11 |
+
|
12 |
+
It also implements several helper functions such as parsing, printing, and saving the options.
|
13 |
+
It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, cmd_line=None):
|
17 |
+
"""Reset the class; indicates the class hasn't been initialized"""
|
18 |
+
self.initialized = False
|
19 |
+
self.cmd_line = None
|
20 |
+
if cmd_line is not None:
|
21 |
+
self.cmd_line = cmd_line.split()
|
22 |
+
|
23 |
+
def initialize(self, parser):
|
24 |
+
"""Define the common options that are used in both training and test."""
|
25 |
+
# basic parameters
|
26 |
+
parser.add_argument('--dataroot', default='placeholder', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
|
27 |
+
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
|
28 |
+
parser.add_argument('--easy_label', type=str, default='experiment_name', help='Interpretable name')
|
29 |
+
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
|
30 |
+
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
|
31 |
+
# model parameters
|
32 |
+
parser.add_argument('--model', type=str, default='cut', help='chooses which model to use.')
|
33 |
+
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
|
34 |
+
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
|
35 |
+
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
|
36 |
+
        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
        parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2', 'multi_d'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
        parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks', 'resnet_6blocks', 'resnet_4blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'fdlresnet', 'fdlunet'], help='specify generator architecture')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
        parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
        parser.add_argument('--init_type', type=str, default='xavier', choices=['normal', 'xavier', 'kaiming', 'orthogonal'], help='network initialization')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--no_dropout', type=util.str2bool, nargs='?', const=True, default=True,
                            help='no dropout for the generator')
        parser.add_argument('--no_antialias', action='store_true', help='if specified, use stride=2 convs instead of antialiased-downsampling (sad)')
        parser.add_argument('--no_antialias_up', action='store_true', help='if specified, use [upconv(learned filter)] instead of [upconv(hard-coded [1,3,3,1] filter), conv]')
        # dataset parameters
        parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
        parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
        parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
        parser.add_argument('--random_scale_max', type=float, default=3.0,
                            help='(used for single image translation) Randomly scale the image by the specified factor as data augmentation.')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')

        # parameters related to StyleGAN2-based networks
        parser.add_argument('--stylegan2_G_num_downsampling',
                            default=1, type=int,
                            help='Number of downsampling layers used by StyleGAN2Generator')

        # FDL:
        parser.add_argument('--weight_norm', type=str, default='none', choices=['none', 'spectral'], help='chooses which weight norm layer to use.')

        self.initialized = True
        return parser

    def gather_options(self):
        """Initialize our parser with basic options (only once).
        Add additional model-specific and dataset-specific options.
        These options are defined in the <modify_commandline_options> function
        in model and dataset classes.
        """
        if not self.initialized:  # check if it has been initialized
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # get the basic options
        if self.cmd_line is None:
            opt, _ = parser.parse_known_args()
        else:
            opt, _ = parser.parse_known_args(self.cmd_line)

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)
        if self.cmd_line is None:
            opt, _ = parser.parse_known_args()  # parse again with new defaults
        else:
            opt, _ = parser.parse_known_args(self.cmd_line)  # parse again with new defaults

        # modify dataset-related parser options
        dataset_name = opt.dataset_mode
        dataset_option_setter = data.get_option_setter(dataset_name)
        parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        if self.cmd_line is None:
            return parser.parse_args()
        else:
            return parser.parse_args(self.cmd_line)

    def print_options(self, opt):
        """Print and save options

        It will print both current options and default values (if different).
        It will save options into a text file / [checkpoints_dir] / opt.txt
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
        try:
            with open(file_name, 'wt') as opt_file:
                opt_file.write(message)
                opt_file.write('\n')
        except PermissionError as error:
            print("permission error {}".format(error))
            pass

    def parse(self):
        """Parse our options, create checkpoints directory suffix, and set up gpu device."""
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # process opt.suffix
        if opt.suffix:
            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
            opt.name = opt.name + suffix

        self.print_options(opt)

        # set gpu ids
        str_ids = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                opt.gpu_ids.append(id)
        if len(opt.gpu_ids) > 0:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt
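For reference, the --suffix option declared above is expanded in parse() with str.format over the parsed options. A minimal standalone sketch of that mechanism (the option values below are illustrative, not defaults from this repo):

# Sketch of the --suffix expansion in parse(); Namespace stands in for the real parsed options.
from argparse import Namespace

opt = Namespace(name='ASP_pretrained', model='cpt', netG='resnet_9blocks', load_size=286,
                suffix='{model}_{netG}_size{load_size}')
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
opt.name = opt.name + suffix
print(opt.name)  # -> ASP_pretrained_cpt_resnet_9blocks_size286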
asp/options/test_options.py
ADDED
@@ -0,0 +1,21 @@
from .base_options import BaseOptions


class TestOptions(BaseOptions):
    """This class includes test options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and Batchnorm have different behaviour during training and test.
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')

        # To avoid cropping, the load_size should be the same as crop_size
        parser.set_defaults(load_size=parser.get_default('crop_size'))
        self.isTrain = False
        return parser
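The set_defaults(load_size=parser.get_default('crop_size')) call is what disables random cropping at test time. A small plain-argparse sketch of the same trick, independent of the options classes above:

# Standalone illustration of the set_defaults() trick; not the repo's actual option classes.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--load_size', type=int, default=286)
parser.add_argument('--crop_size', type=int, default=256)
parser.set_defaults(load_size=parser.get_default('crop_size'))

opt = parser.parse_args([])          # no flags given
print(opt.load_size, opt.crop_size)  # 256 256 -> images are loaded at crop size, so nothing is cropped away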
asp/options/train_options.py
ADDED
@@ -0,0 +1,44 @@
from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    """This class includes training options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)
        # visdom and HTML visualization parameters
        parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
        parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        parser.add_argument('--display_id', type=int, default=None, help='window id of the web display. Default is random window id')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        # network saving and loading parameters
        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
        parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
        parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')

        # training parameters
        parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs with the initial learning rate')
        parser.add_argument('--n_epochs_decay', type=int, default=200, help='number of epochs to linearly decay learning rate to zero')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
        parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
        parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

        self.isTrain = True
        return parser
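The scheduler that consumes --lr_policy, --n_epochs and --n_epochs_decay is built in asp/models/networks.py and is not shown in this excerpt; the sketch below only illustrates the linear rule these options conventionally describe (constant lr for n_epochs, then a linear decay to zero over n_epochs_decay):

# Hedged sketch of the 'linear' lr policy implied by the options above; not the repo's own scheduler code.
import torch

n_epochs, n_epochs_decay, epoch_count, lr = 200, 200, 1, 0.0002

def lambda_rule(epoch):
    # constant for the first n_epochs, then linear decay to zero over n_epochs_decay
    return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=lr, betas=(0.5, 0.999))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

for epoch in (1, 200, 300, 400):
    print(epoch, lr * lambda_rule(epoch))  # ~2e-4, ~2e-4, ~1e-4, 0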
asp/util/__init__.py
ADDED
@@ -0,0 +1,2 @@
"""This package includes a miscellaneous collection of useful helper functions."""
from asp.util import *
asp/util/fdlutil.py
ADDED
@@ -0,0 +1,422 @@
1 |
+
import importlib.util
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from pylab import *
|
5 |
+
import matplotlib as mpl
|
6 |
+
|
7 |
+
# Use tkAgg when plotting to a window, Agg when to a file
|
8 |
+
# #### mpl.use('TkAgg') # Don't use this unless emergency. More trouble than it's worth
|
9 |
+
mpl.use('Agg')
|
10 |
+
|
11 |
+
|
12 |
+
def quick_imshow(nrows, ncols=1, images=None, titles=None, colorbar=True, colormap='jet',
|
13 |
+
vmax=None, vmin=None, figsize=None, figtitle=None, visibleaxis=True,
|
14 |
+
saveas='/home/ubuntu/tempimshow.png', tight=False, dpi=250.0):
|
15 |
+
"""-------------------------------------------------------------------------
|
16 |
+
Desc.: convenience function that make subplots of imshow
|
17 |
+
Args.: nrows - number of rows
|
18 |
+
ncols - number of cols
|
19 |
+
images - list of images
|
20 |
+
titles - list of titles
|
21 |
+
vmax - tuple of vmax for the colormap. If scalar,
|
22 |
+
the same value is used for all subplots. If one
|
23 |
+
of the entries is None, no colormap for that
|
24 |
+
subplot will be drawn.
|
25 |
+
vmin - tuple of vmin
|
26 |
+
Returns: f - the figure handle
|
27 |
+
axes - axes or array of axes objects
|
28 |
+
caxes - tuple of axes image
|
29 |
+
-------------------------------------------------------------------------"""
|
30 |
+
if isinstance(nrows, np.ndarray):
|
31 |
+
images = nrows
|
32 |
+
nrows = 1
|
33 |
+
ncols = 1
|
34 |
+
|
35 |
+
if figsize == None:
|
36 |
+
# 1.0 translates to 100 pixels of the figure
|
37 |
+
s = 5.0
|
38 |
+
if figtitle:
|
39 |
+
figsize = (s * ncols, s * nrows + 0.5)
|
40 |
+
else:
|
41 |
+
figsize = (s * ncols, s * nrows)
|
42 |
+
|
43 |
+
if nrows == ncols == 1:
|
44 |
+
if isinstance(images, list):
|
45 |
+
images = images[0]
|
46 |
+
f, ax = plt.subplots(figsize=figsize)
|
47 |
+
cax = ax.imshow(images, cmap=colormap, vmax=vmax, vmin=vmin)
|
48 |
+
if colorbar:
|
49 |
+
f.colorbar(cax, ax=ax)
|
50 |
+
if titles != None:
|
51 |
+
ax.set_title(titles)
|
52 |
+
if figtitle != None:
|
53 |
+
f.suptitle(figtitle)
|
54 |
+
cax.axes.get_xaxis().set_visible(visibleaxis)
|
55 |
+
cax.axes.get_yaxis().set_visible(visibleaxis)
|
56 |
+
if tight:
|
57 |
+
plt.tight_layout()
|
58 |
+
if len(saveas) > 0:
|
59 |
+
dirname = os.path.dirname(saveas)
|
60 |
+
if not os.path.exists(dirname):
|
61 |
+
os.makedirs(dirname)
|
62 |
+
plt.savefig(saveas)
|
63 |
+
return f, ax, cax
|
64 |
+
|
65 |
+
f, axes = plt.subplots(nrows, ncols, figsize=figsize, dpi=dpi)
|
66 |
+
caxes = []
|
67 |
+
i = 0
|
68 |
+
for ax, img in zip(axes.flat, images):
|
69 |
+
if isinstance(vmax, tuple) and isinstance(vmin, tuple):
|
70 |
+
if vmax[i] is not None and vmin[i] is not None:
|
71 |
+
cax = ax.imshow(img, cmap=colormap, vmax=vmax[i], vmin=vmin[i])
|
72 |
+
else:
|
73 |
+
cax = ax.imshow(img, cmap=colormap)
|
74 |
+
elif isinstance(vmax, tuple) and vmin is None:
|
75 |
+
if vmax[i] is not None:
|
76 |
+
cax = ax.imshow(img, cmap=colormap, vmax=vmax[i], vmin=0)
|
77 |
+
else:
|
78 |
+
cax = ax.imshow(img, cmap=colormap)
|
79 |
+
elif vmax is None and vmin is None:
|
80 |
+
cax = ax.imshow(img, cmap=colormap)
|
81 |
+
else:
|
82 |
+
cax = ax.imshow(img, cmap=colormap, vmax=vmax, vmin=vmin)
|
83 |
+
if titles != None:
|
84 |
+
ax.set_title(titles[i])
|
85 |
+
if colorbar:
|
86 |
+
f.colorbar(cax, ax=ax)
|
87 |
+
caxes.append(cax)
|
88 |
+
cax.axes.get_xaxis().set_visible(visibleaxis)
|
89 |
+
cax.axes.get_yaxis().set_visible(visibleaxis)
|
90 |
+
i = i + 1
|
91 |
+
if figtitle != None:
|
92 |
+
f.suptitle(figtitle)
|
93 |
+
if tight:
|
94 |
+
plt.tight_layout()
|
95 |
+
if len(saveas) > 0:
|
96 |
+
dirname = os.path.dirname(saveas)
|
97 |
+
if not os.path.exists(dirname):
|
98 |
+
os.makedirs(dirname)
|
99 |
+
plt.savefig(saveas)
|
100 |
+
return f, axes, tuple(caxes)
|
101 |
+
|
102 |
+
|
103 |
+
def update_subplots(images, caxes, f=None, axes=None, indices=(), vmax=None,
|
104 |
+
vmin=None):
|
105 |
+
"""-------------------------------------------------------------------------
|
106 |
+
Desc.: update subplots in a figure
|
107 |
+
Args.: images - new images to plot
|
108 |
+
caxes - caxes returned at figure creation
|
109 |
+
indices - specific indices of subplots to be updated
|
110 |
+
Returns:
|
111 |
+
-------------------------------------------------------------------------"""
|
112 |
+
for i in range(len(images)):
|
113 |
+
if len(indices) > 0:
|
114 |
+
ind = indices[i]
|
115 |
+
else:
|
116 |
+
ind = i
|
117 |
+
img = images[i]
|
118 |
+
caxes[ind].set_data(img)
|
119 |
+
cbar = caxes[ind].colorbar
|
120 |
+
if isinstance(vmax, tuple) and isinstance(vmin, tuple):
|
121 |
+
if vmax[i] is not None and vmin[i] is not None:
|
122 |
+
cbar.set_clim([vmin[i], vmax[i]])
|
123 |
+
else:
|
124 |
+
cbar.set_clim([img.min(), img.max()])
|
125 |
+
elif isinstance(vmax, tuple) and vmin is None:
|
126 |
+
if vmax[i] is not None:
|
127 |
+
cbar.set_clim([0, vmax[i]])
|
128 |
+
else:
|
129 |
+
cbar.set_clim([img.min(), img.max()])
|
130 |
+
elif vmax is None and vmin is None:
|
131 |
+
cbar.set_clim([img.min(), img.max()])
|
132 |
+
else:
|
133 |
+
cbar.set_clim([vmin, vmax])
|
134 |
+
cbar.update_normal(caxes[ind])
|
135 |
+
pause(0.01)
|
136 |
+
tight_layout()
|
137 |
+
|
138 |
+
|
139 |
+
def slide_show(image, dt=0.01, vmax=None, vmin=None):
|
140 |
+
"""
|
141 |
+
Slide show for visualizing an image volume. Image is (w, h, d)
|
142 |
+
:param image: (w, h, d), slides are 2D images along the depth axis
|
143 |
+
:param dt:
|
144 |
+
:param vmax:
|
145 |
+
:param vmin:
|
146 |
+
:return:
|
147 |
+
"""
|
148 |
+
if image.dtype == bool:
|
149 |
+
image *= 1.0
|
150 |
+
if vmax is None:
|
151 |
+
vmax = image.max()
|
152 |
+
if vmin is None:
|
153 |
+
vmin = image.min()
|
154 |
+
plt.ion()
|
155 |
+
plt.figure()
|
156 |
+
for i in range(image.shape[2]):
|
157 |
+
plt.cla()
|
158 |
+
cax = plt.imshow(image[:, :, i], cmap='jet', vmin=vmin, vmax=vmax)
|
159 |
+
plt.title(str('Slice: %i/%i' % (i, image.shape[2] - 1)))
|
160 |
+
if i == 0:
|
161 |
+
cf = plt.gcf()
|
162 |
+
ca = plt.gca()
|
163 |
+
cf.colorbar(cax, ax=ca)
|
164 |
+
plt.pause(dt)
|
165 |
+
plt.draw()
|
166 |
+
|
167 |
+
|
168 |
+
def quick_collage(images, nrows=3, ncols=2, normalize=False, figsize=(20.0, 10.0), figtitle=None, colorbar=True,
|
169 |
+
tight=True, saveas='/home/ubuntu/tempcollage.png'):
|
170 |
+
def zero_to_one(x):
|
171 |
+
if x.min() == x.max():
|
172 |
+
return x - x.min()
|
173 |
+
return (x.astype(float) - x.min()) / (x.max() - x.min())
|
174 |
+
# Normalize every image
|
175 |
+
if isinstance(images, np.ndarray):
|
176 |
+
images = [images]
|
177 |
+
# Check the shape and make sure everything is float
|
178 |
+
img_shp = images[0].shape
|
179 |
+
if normalize:
|
180 |
+
images = [zero_to_one(image) for image in images]
|
181 |
+
vmax, vmin = 1.0, 0.0
|
182 |
+
else:
|
183 |
+
vmax, vmin = max([img.max() for img in images]), min(
|
184 |
+
[img.min() for img in images])
|
185 |
+
# Highlight the boundaries
|
186 |
+
for i in range(0, len(images) - 1):
|
187 |
+
images[i] = np.hstack(
|
188 |
+
[images[i], np.full((img_shp[0], 1, img_shp[2]), np.nan)])
|
189 |
+
collage = np.hstack(images)
|
190 |
+
# Determine slice depth
|
191 |
+
depth = collage.shape[2]
|
192 |
+
n_slices = nrows * ncols
|
193 |
+
z = [int(depth / (n_slices + 1) * i - 1) for i in range(1, (n_slices + 1))]
|
194 |
+
titles = ['Slice %d/%d' % (i, depth) for i in z]
|
195 |
+
quick_imshow(
|
196 |
+
nrows, ncols,
|
197 |
+
[collage[:, :, z[i]] for i in range(n_slices)],
|
198 |
+
titles=titles,
|
199 |
+
figtitle=figtitle,
|
200 |
+
figsize=figsize,
|
201 |
+
vmax=vmax, vmin=vmin,
|
202 |
+
colorbar=colorbar, tight=tight)
|
203 |
+
if len(saveas) > 0:
|
204 |
+
plt.savefig(saveas)
|
205 |
+
plt.close()
|
206 |
+
|
207 |
+
|
208 |
+
def quick_plot(x_data, y_data=None, fmt='', color=None, xlim=None, ylim=None,
|
209 |
+
label='', legends=False, x_label='', y_label='', figtitle='', annotation=None, figsize=(20, 10),
|
210 |
+
f=None, ax=None, saveas=''):
|
211 |
+
if f is None or ax is None:
|
212 |
+
f, ax = subplots(figsize=figsize)
|
213 |
+
if y_data is None:
|
214 |
+
temp = x_data
|
215 |
+
x_data = list(range(len(temp)))
|
216 |
+
y_data = temp
|
217 |
+
ax.plot(x_data, y_data, fmt, label=label, color=color)
|
218 |
+
if xlim is not None:
|
219 |
+
ax.set_xlim(xlim)
|
220 |
+
if ylim is not None:
|
221 |
+
ax.set_ylim(ylim)
|
222 |
+
if annotation is not None:
|
223 |
+
for i in range(len(x_data)):
|
224 |
+
annotate(annotation[i], (x_data[i], y_data[i]),
|
225 |
+
textcoords='offset points', xytext=(0, 10), ha='center')
|
226 |
+
if len(x_label) > 0:
|
227 |
+
ax.set_xlabel(x_label)
|
228 |
+
if len(y_label) > 0:
|
229 |
+
ax.set_ylabel(y_label)
|
230 |
+
if len(figtitle) > 0:
|
231 |
+
f.suptitle(figtitle)
|
232 |
+
if legends:
|
233 |
+
ax.legend(loc='center left', bbox_to_anchor=(1.04, 0.5))
|
234 |
+
ax.grid()
|
235 |
+
if len(saveas) > 0:
|
236 |
+
f.savefig(saveas, bbox_inches='tight')
|
237 |
+
ax.grid()
|
238 |
+
return f, ax
|
239 |
+
|
240 |
+
|
241 |
+
def quick_scatter(x_data, y_data=None, xlim=None, ylim=None,
|
242 |
+
label='', legends=False, x_label='', y_label='', figtitle='', annotation=None,
|
243 |
+
f=None, ax=None, saveas=''):
|
244 |
+
if f is None or ax is None:
|
245 |
+
f, ax = subplots()
|
246 |
+
if y_data is None:
|
247 |
+
temp = x_data
|
248 |
+
x_data = list(range(len(temp)))
|
249 |
+
y_data = temp
|
250 |
+
ax.scatter(x_data, y_data, label=label)
|
251 |
+
if xlim is not None:
|
252 |
+
ax.set_xlim(xlim)
|
253 |
+
if ylim is not None:
|
254 |
+
ax.set_ylim(ylim)
|
255 |
+
if annotation is not None:
|
256 |
+
for i in range(len(x_data)):
|
257 |
+
annotate(annotation[i], (x_data[i], y_data[i]),
|
258 |
+
textcoords='offset points', xytext=(0, 10), ha='center')
|
259 |
+
if len(x_label) > 0:
|
260 |
+
ax.set_xlabel(x_label)
|
261 |
+
if len(y_label) > 0:
|
262 |
+
ax.set_ylabel(y_label)
|
263 |
+
if len(figtitle) > 0:
|
264 |
+
f.suptitle(figtitle)
|
265 |
+
if legends:
|
266 |
+
ax.legend()
|
267 |
+
ax.grid()
|
268 |
+
if len(saveas) > 0:
|
269 |
+
f.savefig(saveas)
|
270 |
+
return f, ax
|
271 |
+
|
272 |
+
|
273 |
+
def quick_load(file_path, fits_field=1):
|
274 |
+
if file_path.endswith('npz'):
|
275 |
+
with load(file_path, allow_pickle=True) as f:
|
276 |
+
data = f['arr_0']
|
277 |
+
# Take care of the case where a dictionary is saved in npz format
|
278 |
+
if isinstance(data, ndarray) and data.dtype == 'O':
|
279 |
+
data = data.flatten()[0]
|
280 |
+
# elif file_path.endswith(('pyc', 'pickle')):
|
281 |
+
# data = pickle_load(file_path)
|
282 |
+
# elif file_path.endswith('fits.gz'):
|
283 |
+
# data = read_fits_data(file_path, fits_field)
|
284 |
+
# elif file_path.endswith('h5'):
|
285 |
+
# data = read_hdf5_data(file_path)
|
286 |
+
else:
|
287 |
+
raise NotImplementedError(
|
288 |
+
"Only npz, pyc, h5 and fits.gz are supported!")
|
289 |
+
return data
|
290 |
+
|
291 |
+
|
292 |
+
def quick_save(file_path, data):
|
293 |
+
dir_name = os.path.dirname(file_path)
|
294 |
+
if not os.path.exists(dir_name):
|
295 |
+
os.makedirs(dir_name)
|
296 |
+
# For better disk utilization and compatibility with fits, use int32
|
297 |
+
if file_path.endswith('npz'):
|
298 |
+
savez_compressed(file_path, data)
|
299 |
+
# elif file_path.endswith(('pyc', 'pickle')):
|
300 |
+
# save_object(file_path, data)
|
301 |
+
# elif file_path.endswith('fits.gz'):
|
302 |
+
# if isinstance(data, ndarray) and data.dtype == int:
|
303 |
+
# data = data.astype(int32)
|
304 |
+
# save_fits_data(file_path, data)
|
305 |
+
# elif file_path.endswith('h5'):
|
306 |
+
# write_hdf5_data(file_path, data)
|
307 |
+
else:
|
308 |
+
raise NotImplementedError(
|
309 |
+
"Only npz, pyc, h5 and fits.gz are supported!")
|
310 |
+
|
311 |
+
|
312 |
+
def import_module(name, path):
|
313 |
+
"""
|
314 |
+
correct way of importing a module dynamically in python 3.
|
315 |
+
:param name: name given to module instance.
|
316 |
+
:param path: path to module.
|
317 |
+
:return: module: returned module instance.
|
318 |
+
"""
|
319 |
+
spec = importlib.util.spec_from_file_location(name, path)
|
320 |
+
module = importlib.util.module_from_spec(spec)
|
321 |
+
spec.loader.exec_module(module)
|
322 |
+
return module
|
323 |
+
|
324 |
+
|
325 |
+
def obj_from_dict(info, parent=None, default_args=None):
|
326 |
+
"""Initialize an object from dict.
|
327 |
+
The dict must contain the key "type", which indicates the object type, it
|
328 |
+
can be either a string or type, such as "list" or ``list``. Remaining
|
329 |
+
fields are treated as the arguments for constructing the object.
|
330 |
+
Args:
|
331 |
+
info (dict): Object types and arguments.
|
332 |
+
parent (:class:`module`): Module which may containing expected object
|
333 |
+
classes.
|
334 |
+
default_args (dict, optional): Default arguments for initializing the
|
335 |
+
object.
|
336 |
+
Returns:
|
337 |
+
any type: Object built from the dict.
|
338 |
+
"""
|
339 |
+
assert isinstance(info, dict) and 'type' in info
|
340 |
+
assert isinstance(default_args, dict) or default_args is None
|
341 |
+
args = info.copy()
|
342 |
+
obj_type = args.pop('type')
|
343 |
+
if isinstance(obj_type, str):
|
344 |
+
if parent is not None:
|
345 |
+
obj_type = getattr(parent, obj_type)
|
346 |
+
else:
|
347 |
+
obj_type = sys.modules[obj_type]
|
348 |
+
elif not isinstance(obj_type, type):
|
349 |
+
raise TypeError('type must be a str or valid type, but '
|
350 |
+
f'got {type(obj_type)}')
|
351 |
+
if default_args is not None:
|
352 |
+
for name, value in default_args.items():
|
353 |
+
args.setdefault(name, value)
|
354 |
+
return obj_type(**args)
|
355 |
+
|
356 |
+
|
357 |
+
def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
|
358 |
+
"""
|
359 |
+
one padder to pad them all. Documentation? Well okay. A little bit. by Fabian Isensee
|
360 |
+
:param image: nd image. can be anything
|
361 |
+
:param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If
|
362 |
+
len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of
|
363 |
+
the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape)
|
364 |
+
Example:
|
365 |
+
image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh?
|
366 |
+
image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768).
|
367 |
+
:param mode: see np.pad for documentation
|
368 |
+
:param return_slicer: if True then this function will also return what coords you will need to use when cropping back
|
369 |
+
to original shape
|
370 |
+
:param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is
|
371 |
+
divisibly by that number (can also be a list with an entry for each axis). Whatever is missing to match that will
|
372 |
+
be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None)
|
373 |
+
:param kwargs: see np.pad for documentation
|
374 |
+
"""
|
375 |
+
if kwargs is None:
|
376 |
+
kwargs = {}
|
377 |
+
|
378 |
+
if new_shape is not None:
|
379 |
+
old_shape = np.array(image.shape[-len(new_shape):])
|
380 |
+
else:
|
381 |
+
assert shape_must_be_divisible_by is not None
|
382 |
+
assert isinstance(shape_must_be_divisible_by,
|
383 |
+
(list, tuple, np.ndarray))
|
384 |
+
new_shape = image.shape[-len(shape_must_be_divisible_by):]
|
385 |
+
old_shape = new_shape
|
386 |
+
|
387 |
+
num_axes_nopad = len(image.shape) - len(new_shape)
|
388 |
+
|
389 |
+
new_shape = [max(new_shape[i], old_shape[i])
|
390 |
+
for i in range(len(new_shape))]
|
391 |
+
|
392 |
+
if not isinstance(new_shape, np.ndarray):
|
393 |
+
new_shape = np.array(new_shape)
|
394 |
+
|
395 |
+
if shape_must_be_divisible_by is not None:
|
396 |
+
if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
|
397 |
+
shape_must_be_divisible_by = [
|
398 |
+
shape_must_be_divisible_by] * len(new_shape)
|
399 |
+
else:
|
400 |
+
assert len(shape_must_be_divisible_by) == len(new_shape)
|
401 |
+
|
402 |
+
for i in range(len(new_shape)):
|
403 |
+
if new_shape[i] % shape_must_be_divisible_by[i] == 0:
|
404 |
+
new_shape[i] -= shape_must_be_divisible_by[i]
|
405 |
+
|
406 |
+
new_shape = np.array(
|
407 |
+
[new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in
|
408 |
+
range(len(new_shape))])
|
409 |
+
|
410 |
+
difference = new_shape - old_shape
|
411 |
+
pad_below = difference // 2
|
412 |
+
pad_above = difference // 2 + difference % 2
|
413 |
+
pad_list = [[0, 0]] * num_axes_nopad + \
|
414 |
+
list([list(i) for i in zip(pad_below, pad_above)])
|
415 |
+
res = np.pad(image, pad_list, mode, **kwargs)
|
416 |
+
if not return_slicer:
|
417 |
+
return res
|
418 |
+
else:
|
419 |
+
pad_list = np.array(pad_list)
|
420 |
+
pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1]
|
421 |
+
slicer = list(slice(*i) for i in pad_list)
|
422 |
+
return res, slicer
|
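A usage sketch for pad_nd_image() with return_slicer=True, following its docstring; the array shapes below are arbitrary examples, not values used by this repo:

# Pad H and W up to a multiple of 32, then crop back with the returned slicer.
import numpy as np
from asp.util.fdlutil import pad_nd_image

image = np.random.rand(3, 500, 467).astype(np.float32)
padded, slicer = pad_nd_image(image, shape_must_be_divisible_by=(1, 32, 32), return_slicer=True)
print(padded.shape)                    # (3, 512, 480)
restored = padded[tuple(slicer)]
print(restored.shape == image.shape)   # True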
asp/util/fid.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
1 |
+
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
|
2 |
+
|
3 |
+
The FID metric calculates the distance between two distributions of images.
|
4 |
+
Typically, we have summary statistics (mean & covariance matrix) of one
|
5 |
+
of these distributions, while the 2nd distribution is given by a GAN.
|
6 |
+
|
7 |
+
When run as a stand-alone program, it compares the distribution of
|
8 |
+
images that are stored as PNG/JPEG at a specified location with a
|
9 |
+
distribution given by summary statistics (in pickle format).
|
10 |
+
|
11 |
+
The FID is calculated by assuming that X_1 and X_2 are the activations of
|
12 |
+
the pool_3 layer of the inception net for generated samples and real world
|
13 |
+
samples respectively.
|
14 |
+
|
15 |
+
See --help to see further details.
|
16 |
+
|
17 |
+
Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
|
18 |
+
of Tensorflow
|
19 |
+
|
20 |
+
Copyright 2018 Institute of Bioinformatics, JKU Linz
|
21 |
+
|
22 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
23 |
+
you may not use this file except in compliance with the License.
|
24 |
+
You may obtain a copy of the License at
|
25 |
+
|
26 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
27 |
+
|
28 |
+
Unless required by applicable law or agreed to in writing, software
|
29 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
30 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
31 |
+
See the License for the specific language governing permissions and
|
32 |
+
limitations under the License.
|
33 |
+
"""
|
34 |
+
import os
|
35 |
+
import pathlib
|
36 |
+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
37 |
+
|
38 |
+
import numpy as np
|
39 |
+
import torch
|
40 |
+
import torchvision.transforms as TF
|
41 |
+
from PIL import Image
|
42 |
+
from scipy import linalg
|
43 |
+
from torch.nn.functional import adaptive_avg_pool2d
|
44 |
+
|
45 |
+
try:
|
46 |
+
from tqdm import tqdm
|
47 |
+
except ImportError:
|
48 |
+
# If tqdm is not available, provide a mock version of it
|
49 |
+
def tqdm(x):
|
50 |
+
return x
|
51 |
+
|
52 |
+
from util.inception import InceptionV3
|
53 |
+
|
54 |
+
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
55 |
+
parser.add_argument('--batch-size', type=int, default=50,
|
56 |
+
help='Batch size to use')
|
57 |
+
parser.add_argument('--num-workers', type=int,
|
58 |
+
help=('Number of processes to use for data loading. '
|
59 |
+
'Defaults to `min(8, num_cpus)`'))
|
60 |
+
parser.add_argument('--device', type=str, default=None,
|
61 |
+
help='Device to use. Like cuda, cuda:0 or cpu')
|
62 |
+
parser.add_argument('--dims', type=int, default=2048,
|
63 |
+
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
|
64 |
+
help=('Dimensionality of Inception features to use. '
|
65 |
+
'By default, uses pool3 features'))
|
66 |
+
parser.add_argument('path', type=str, nargs=2,
|
67 |
+
help=('Paths to the generated images or '
|
68 |
+
'to .npz statistic files'))
|
69 |
+
|
70 |
+
IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
|
71 |
+
'tif', 'tiff', 'webp'}
|
72 |
+
|
73 |
+
|
74 |
+
class ImagePathDataset(torch.utils.data.Dataset):
|
75 |
+
def __init__(self, files, transforms=None):
|
76 |
+
self.files = files
|
77 |
+
self.transforms = transforms
|
78 |
+
|
79 |
+
def __len__(self):
|
80 |
+
return len(self.files)
|
81 |
+
|
82 |
+
def __getitem__(self, i):
|
83 |
+
path = self.files[i]
|
84 |
+
img = Image.open(path).convert('RGB')
|
85 |
+
if self.transforms is not None:
|
86 |
+
img = self.transforms(img)
|
87 |
+
return img
|
88 |
+
|
89 |
+
|
90 |
+
def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
|
91 |
+
num_workers=1):
|
92 |
+
"""Calculates the activations of the pool_3 layer for all images.
|
93 |
+
|
94 |
+
Params:
|
95 |
+
-- files : List of image files paths
|
96 |
+
-- model : Instance of inception model
|
97 |
+
-- batch_size : Batch size of images for the model to process at once.
|
98 |
+
Make sure that the number of samples is a multiple of
|
99 |
+
the batch size, otherwise some samples are ignored. This
|
100 |
+
behavior is retained to match the original FID score
|
101 |
+
implementation.
|
102 |
+
-- dims : Dimensionality of features returned by Inception
|
103 |
+
-- device : Device to run calculations
|
104 |
+
-- num_workers : Number of parallel dataloader workers
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
-- A numpy array of dimension (num images, dims) that contains the
|
108 |
+
activations of the given tensor when feeding inception with the
|
109 |
+
query tensor.
|
110 |
+
"""
|
111 |
+
model.eval()
|
112 |
+
|
113 |
+
if batch_size > len(files):
|
114 |
+
print(('Warning: batch size is bigger than the data size. '
|
115 |
+
'Setting batch size to data size'))
|
116 |
+
batch_size = len(files)
|
117 |
+
|
118 |
+
dataset = ImagePathDataset(files, transforms=TF.ToTensor())
|
119 |
+
dataloader = torch.utils.data.DataLoader(dataset,
|
120 |
+
batch_size=batch_size,
|
121 |
+
shuffle=False,
|
122 |
+
drop_last=False,
|
123 |
+
num_workers=num_workers)
|
124 |
+
|
125 |
+
pred_arr = np.empty((len(files), dims))
|
126 |
+
|
127 |
+
start_idx = 0
|
128 |
+
|
129 |
+
for batch in tqdm(dataloader):
|
130 |
+
batch = batch.to(device)
|
131 |
+
|
132 |
+
with torch.no_grad():
|
133 |
+
pred = model(batch)[0]
|
134 |
+
|
135 |
+
# If model output is not scalar, apply global spatial average pooling.
|
136 |
+
# This happens if you choose a dimensionality not equal 2048.
|
137 |
+
if pred.size(2) != 1 or pred.size(3) != 1:
|
138 |
+
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
139 |
+
|
140 |
+
pred = pred.squeeze(3).squeeze(2).cpu().numpy()
|
141 |
+
|
142 |
+
pred_arr[start_idx:start_idx + pred.shape[0]] = pred
|
143 |
+
|
144 |
+
start_idx = start_idx + pred.shape[0]
|
145 |
+
|
146 |
+
return pred_arr
|
147 |
+
|
148 |
+
|
149 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
150 |
+
"""Numpy implementation of the Frechet Distance.
|
151 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
152 |
+
and X_2 ~ N(mu_2, C_2) is
|
153 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
154 |
+
|
155 |
+
Stable version by Dougal J. Sutherland.
|
156 |
+
|
157 |
+
Params:
|
158 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
159 |
+
inception net (like returned by the function 'get_predictions')
|
160 |
+
for generated samples.
|
161 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
162 |
+
representative data set.
|
163 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
164 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
165 |
+
representative data set.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
-- : The Frechet Distance.
|
169 |
+
"""
|
170 |
+
|
171 |
+
mu1 = np.atleast_1d(mu1)
|
172 |
+
mu2 = np.atleast_1d(mu2)
|
173 |
+
|
174 |
+
sigma1 = np.atleast_2d(sigma1)
|
175 |
+
sigma2 = np.atleast_2d(sigma2)
|
176 |
+
|
177 |
+
assert mu1.shape == mu2.shape, \
|
178 |
+
'Training and test mean vectors have different lengths'
|
179 |
+
assert sigma1.shape == sigma2.shape, \
|
180 |
+
'Training and test covariances have different dimensions'
|
181 |
+
|
182 |
+
diff = mu1 - mu2
|
183 |
+
|
184 |
+
# Product might be almost singular
|
185 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
186 |
+
if not np.isfinite(covmean).all():
|
187 |
+
msg = ('fid calculation produces singular product; '
|
188 |
+
'adding %s to diagonal of cov estimates') % eps
|
189 |
+
print(msg)
|
190 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
191 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
192 |
+
|
193 |
+
# Numerical error might give slight imaginary component
|
194 |
+
if np.iscomplexobj(covmean):
|
195 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
196 |
+
m = np.max(np.abs(covmean.imag))
|
197 |
+
raise ValueError('Imaginary component {}'.format(m))
|
198 |
+
covmean = covmean.real
|
199 |
+
|
200 |
+
tr_covmean = np.trace(covmean)
|
201 |
+
|
202 |
+
return (diff.dot(diff) + np.trace(sigma1)
|
203 |
+
+ np.trace(sigma2) - 2 * tr_covmean)
|
204 |
+
|
205 |
+
|
206 |
+
def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
|
207 |
+
device='cpu', num_workers=1):
|
208 |
+
"""Calculation of the statistics used by the FID.
|
209 |
+
Params:
|
210 |
+
-- files : List of image files paths
|
211 |
+
-- model : Instance of inception model
|
212 |
+
-- batch_size : The images numpy array is split into batches with
|
213 |
+
batch size batch_size. A reasonable batch size
|
214 |
+
depends on the hardware.
|
215 |
+
-- dims : Dimensionality of features returned by Inception
|
216 |
+
-- device : Device to run calculations
|
217 |
+
-- num_workers : Number of parallel dataloader workers
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
-- mu : The mean over samples of the activations of the pool_3 layer of
|
221 |
+
the inception model.
|
222 |
+
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
223 |
+
the inception model.
|
224 |
+
"""
|
225 |
+
act = get_activations(files, model, batch_size, dims, device, num_workers)
|
226 |
+
mu = np.mean(act, axis=0)
|
227 |
+
sigma = np.cov(act, rowvar=False)
|
228 |
+
return mu, sigma
|
229 |
+
|
230 |
+
|
231 |
+
def compute_statistics_of_path(path, model, batch_size, dims, device,
|
232 |
+
num_workers=1):
|
233 |
+
if path.endswith('.npz'):
|
234 |
+
with np.load(path) as f:
|
235 |
+
m, s = f['mu'][:], f['sigma'][:]
|
236 |
+
else:
|
237 |
+
path = pathlib.Path(path)
|
238 |
+
files = sorted([file for ext in IMAGE_EXTENSIONS
|
239 |
+
for file in path.glob('*.{}'.format(ext))])
|
240 |
+
m, s = calculate_activation_statistics(files, model, batch_size,
|
241 |
+
dims, device, num_workers)
|
242 |
+
|
243 |
+
return m, s
|
244 |
+
|
245 |
+
|
246 |
+
def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
|
247 |
+
"""Calculates the FID of two paths"""
|
248 |
+
for p in paths:
|
249 |
+
if not os.path.exists(p):
|
250 |
+
raise RuntimeError('Invalid path: %s' % p)
|
251 |
+
|
252 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
253 |
+
|
254 |
+
model = InceptionV3([block_idx]).to(device)
|
255 |
+
|
256 |
+
m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
|
257 |
+
dims, device, num_workers)
|
258 |
+
m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
|
259 |
+
dims, device, num_workers)
|
260 |
+
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
261 |
+
|
262 |
+
return fid_value
|
263 |
+
|
264 |
+
|
265 |
+
def main():
|
266 |
+
args = parser.parse_args()
|
267 |
+
|
268 |
+
if args.device is None:
|
269 |
+
device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
|
270 |
+
else:
|
271 |
+
device = torch.device(args.device)
|
272 |
+
|
273 |
+
if args.num_workers is None:
|
274 |
+
num_avail_cpus = len(os.sched_getaffinity(0))
|
275 |
+
num_workers = min(num_avail_cpus, 8)
|
276 |
+
else:
|
277 |
+
num_workers = args.num_workers
|
278 |
+
|
279 |
+
fid_value = calculate_fid_given_paths(args.path,
|
280 |
+
args.batch_size,
|
281 |
+
device,
|
282 |
+
args.dims,
|
283 |
+
num_workers)
|
284 |
+
print('FID: ', fid_value)
|
285 |
+
|
286 |
+
|
287 |
+
if __name__ == '__main__':
|
288 |
+
main()
|
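For intuition, the Frechet distance computed above reduces to ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrt(C1*C2)). The snippet below re-implements the formula without the eps stabilization, purely as an illustration on synthetic statistics; it is not the repo's entry point:

# Self-contained check of the Frechet distance on toy Gaussian statistics.
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    covmean = linalg.sqrtm(sigma1.dot(sigma2)).real
    diff = mu1 - mu2
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

rng = np.random.default_rng(0)
act = rng.normal(size=(1000, 8))                     # stand-in for pool_3 activations
mu, sigma = act.mean(axis=0), np.cov(act, rowvar=False)
print(frechet_distance(mu, sigma, mu, sigma))        # ~0 for identical distributions
print(frechet_distance(mu + 1.0, sigma, mu, sigma))  # ~8 (= ||mu1 - mu2||^2 with 8 dims shifted by 1)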
asp/util/general_utils.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append(".")
|
3 |
+
|
4 |
+
import time, os
|
5 |
+
from functools import wraps
|
6 |
+
import argparse
|
7 |
+
import inspect
|
8 |
+
import traceback
|
9 |
+
|
10 |
+
def time_it(func):
|
11 |
+
@wraps(func)
|
12 |
+
def wrapper(*args, **kwargs):
|
13 |
+
start_time = time.time()
|
14 |
+
result = func(*args, **kwargs)
|
15 |
+
end_time = time.time()
|
16 |
+
elapsed_time = end_time - start_time
|
17 |
+
print(f"Function '{func.__name__}' executed in {elapsed_time:.6f} seconds.")
|
18 |
+
return result
|
19 |
+
return wrapper
|
20 |
+
|
21 |
+
def try_wrapper(function, filename, log_path):
|
22 |
+
try:
|
23 |
+
return function()
|
24 |
+
except Exception as e:
|
25 |
+
error_trace = traceback.format_exc()
|
26 |
+
|
27 |
+
with open(log_path, 'a') as log_file:
|
28 |
+
log_file.write(f"{filename}: {error_trace}\n")
|
29 |
+
print(f"Error in {filename}:\n{error_trace}")
|
30 |
+
|
31 |
+
def parse_args(main_function):
|
32 |
+
parser = argparse.ArgumentParser()
|
33 |
+
|
34 |
+
used_short_versions = set("h")
|
35 |
+
|
36 |
+
signature = inspect.signature(main_function)
|
37 |
+
for param_name, param in signature.parameters.items():
|
38 |
+
short_version = param_name[0]
|
39 |
+
if short_version in used_short_versions or not short_version.isalpha():
|
40 |
+
for char in param_name[1:]:
|
41 |
+
short_version = char
|
42 |
+
if char.isalpha() and short_version not in used_short_versions:
|
43 |
+
break
|
44 |
+
else:
|
45 |
+
short_version = None
|
46 |
+
|
47 |
+
if short_version:
|
48 |
+
used_short_versions.add(short_version)
|
49 |
+
param_call = (f'-{short_version}', f'--{param_name}')
|
50 |
+
else:
|
51 |
+
param_call = (f'--{param_name}',)
|
52 |
+
|
53 |
+
if param.default is not inspect.Parameter.empty:
|
54 |
+
if param.default is not None:
|
55 |
+
param_type = type(param.default)
|
56 |
+
else:
|
57 |
+
param_type = str
|
58 |
+
parser.add_argument(*param_call, type=param_type, default=param.default,
|
59 |
+
help=f"Automatically detected argument: {param_name}, default: {param.default}")
|
60 |
+
else:
|
61 |
+
parser.add_argument(*param_call, required=True,
|
62 |
+
help=f"Required argument: {param_name}")
|
63 |
+
|
64 |
+
args = parser.parse_args()
|
65 |
+
|
66 |
+
return args
|
67 |
+
|
68 |
+
def assert_file_exist(*args):
|
69 |
+
path = os.path.join(*args)
|
70 |
+
if not os.path.exists(path):
|
71 |
+
raise Exception(f"File {path} does not exist")
|
72 |
+
|
73 |
+
return path
|
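A usage sketch for time_it and parse_args above; the main() function and its arguments are made up for illustration:

# Time a function and expose its signature as CLI flags via parse_args().
from asp.util.general_utils import time_it, parse_args

@time_it
def main(input_dir, batch_size=4, device="cpu"):
    print(f"would process {input_dir} with batch_size={batch_size} on {device}")

if __name__ == "__main__":
    args = parse_args(main)   # builds --input_dir (required), --batch_size, --device automatically
    main(**vars(args))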
asp/util/get_data.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
import tarfile
|
4 |
+
import requests
|
5 |
+
from warnings import warn
|
6 |
+
from zipfile import ZipFile
|
7 |
+
from bs4 import BeautifulSoup
|
8 |
+
from os.path import abspath, isdir, join, basename
|
9 |
+
|
10 |
+
|
11 |
+
class GetData(object):
|
12 |
+
"""A Python script for downloading CycleGAN or pix2pix datasets.
|
13 |
+
|
14 |
+
Parameters:
|
15 |
+
technique (str) -- One of: 'cyclegan' or 'pix2pix'.
|
16 |
+
verbose (bool) -- If True, print additional information.
|
17 |
+
|
18 |
+
Examples:
|
19 |
+
>>> from util.get_data import GetData
|
20 |
+
>>> gd = GetData(technique='cyclegan')
|
21 |
+
>>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
|
22 |
+
|
23 |
+
Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh'
|
24 |
+
and 'scripts/download_cyclegan_model.sh'.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, technique='cyclegan', verbose=True):
|
28 |
+
url_dict = {
|
29 |
+
'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
|
30 |
+
'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
|
31 |
+
}
|
32 |
+
self.url = url_dict.get(technique.lower())
|
33 |
+
self._verbose = verbose
|
34 |
+
|
35 |
+
def _print(self, text):
|
36 |
+
if self._verbose:
|
37 |
+
print(text)
|
38 |
+
|
39 |
+
@staticmethod
|
40 |
+
def _get_options(r):
|
41 |
+
soup = BeautifulSoup(r.text, 'lxml')
|
42 |
+
options = [h.text for h in soup.find_all('a', href=True)
|
43 |
+
if h.text.endswith(('.zip', 'tar.gz'))]
|
44 |
+
return options
|
45 |
+
|
46 |
+
def _present_options(self):
|
47 |
+
r = requests.get(self.url)
|
48 |
+
options = self._get_options(r)
|
49 |
+
print('Options:\n')
|
50 |
+
for i, o in enumerate(options):
|
51 |
+
print("{0}: {1}".format(i, o))
|
52 |
+
choice = input("\nPlease enter the number of the "
|
53 |
+
"dataset above you wish to download:")
|
54 |
+
return options[int(choice)]
|
55 |
+
|
56 |
+
def _download_data(self, dataset_url, save_path):
|
57 |
+
if not isdir(save_path):
|
58 |
+
os.makedirs(save_path)
|
59 |
+
|
60 |
+
base = basename(dataset_url)
|
61 |
+
temp_save_path = join(save_path, base)
|
62 |
+
|
63 |
+
with open(temp_save_path, "wb") as f:
|
64 |
+
r = requests.get(dataset_url)
|
65 |
+
f.write(r.content)
|
66 |
+
|
67 |
+
if base.endswith('.tar.gz'):
|
68 |
+
obj = tarfile.open(temp_save_path)
|
69 |
+
elif base.endswith('.zip'):
|
70 |
+
obj = ZipFile(temp_save_path, 'r')
|
71 |
+
else:
|
72 |
+
raise ValueError("Unknown File Type: {0}.".format(base))
|
73 |
+
|
74 |
+
self._print("Unpacking Data...")
|
75 |
+
obj.extractall(save_path)
|
76 |
+
obj.close()
|
77 |
+
os.remove(temp_save_path)
|
78 |
+
|
79 |
+
def get(self, save_path, dataset=None):
|
80 |
+
"""
|
81 |
+
|
82 |
+
Download a dataset.
|
83 |
+
|
84 |
+
Parameters:
|
85 |
+
save_path (str) -- A directory to save the data to.
|
86 |
+
dataset (str) -- (optional). A specific dataset to download.
|
87 |
+
Note: this must include the file extension.
|
88 |
+
If None, options will be presented for you
|
89 |
+
to choose from.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
save_path_full (str) -- the absolute path to the downloaded data.
|
93 |
+
|
94 |
+
"""
|
95 |
+
if dataset is None:
|
96 |
+
selected_dataset = self._present_options()
|
97 |
+
else:
|
98 |
+
selected_dataset = dataset
|
99 |
+
|
100 |
+
save_path_full = join(save_path, selected_dataset.split('.')[0])
|
101 |
+
|
102 |
+
if isdir(save_path_full):
|
103 |
+
warn("\n'{0}' already exists. Voiding Download.".format(
|
104 |
+
save_path_full))
|
105 |
+
else:
|
106 |
+
self._print('Downloading Data...')
|
107 |
+
url = "{0}/{1}".format(self.url, selected_dataset)
|
108 |
+
self._download_data(url, save_path=save_path)
|
109 |
+
|
110 |
+
return abspath(save_path_full)
|
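Passing dataset= to get() skips the interactive menu. 'facades.zip' below is only an illustrative archive name; omit the argument to be shown the options actually hosted at the download URL:

# Non-interactive download sketch for GetData.
from asp.util.get_data import GetData

gd = GetData(technique='cyclegan', verbose=True)
path = gd.get(save_path='./datasets', dataset='facades.zip')
print(path)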
asp/util/html.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
1 |
+
import dominate
|
2 |
+
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class HTML:
|
7 |
+
"""This HTML class allows us to save images and write texts into a single HTML file.
|
8 |
+
|
9 |
+
It consists of functions such as <add_header> (add a text header to the HTML file),
|
10 |
+
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
|
11 |
+
It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, web_dir, title, refresh=0):
|
15 |
+
"""Initialize the HTML classes
|
16 |
+
|
17 |
+
Parameters:
|
18 |
+
web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
|
19 |
+
title (str) -- the webpage name
|
20 |
+
refresh (int) -- how often the website refresh itself; if 0; no refreshing
|
21 |
+
"""
|
22 |
+
self.title = title
|
23 |
+
self.web_dir = web_dir
|
24 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
25 |
+
if not os.path.exists(self.web_dir):
|
26 |
+
os.makedirs(self.web_dir)
|
27 |
+
if not os.path.exists(self.img_dir):
|
28 |
+
os.makedirs(self.img_dir)
|
29 |
+
|
30 |
+
self.doc = dominate.document(title=title)
|
31 |
+
if refresh > 0:
|
32 |
+
with self.doc.head:
|
33 |
+
meta(http_equiv="refresh", content=str(refresh))
|
34 |
+
|
35 |
+
def get_image_dir(self):
|
36 |
+
"""Return the directory that stores images"""
|
37 |
+
return self.img_dir
|
38 |
+
|
39 |
+
def add_header(self, text):
|
40 |
+
"""Insert a header to the HTML file
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
text (str) -- the header text
|
44 |
+
"""
|
45 |
+
with self.doc:
|
46 |
+
h3(text)
|
47 |
+
|
48 |
+
def add_images(self, ims, txts, links, width=400):
|
49 |
+
"""add images to the HTML file
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
ims (str list) -- a list of image paths
|
53 |
+
txts (str list) -- a list of image names shown on the website
|
54 |
+
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
|
55 |
+
"""
|
56 |
+
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
|
57 |
+
self.doc.add(self.t)
|
58 |
+
with self.t:
|
59 |
+
with tr():
|
60 |
+
for im, txt, link in zip(ims, txts, links):
|
61 |
+
with td(style="word-wrap: break-word;", halign="center", valign="top"):
|
62 |
+
with p():
|
63 |
+
with a(href=os.path.join('images', link)):
|
64 |
+
img(style="width:%dpx" % width, src=os.path.join('images', im))
|
65 |
+
br()
|
66 |
+
p(txt)
|
67 |
+
|
68 |
+
def save(self):
|
69 |
+
"""save the current content to the HMTL file"""
|
70 |
+
html_file = '%s/index.html' % self.web_dir
|
71 |
+
f = open(html_file, 'wt')
|
72 |
+
f.write(self.doc.render())
|
73 |
+
f.close()
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__': # we show an example usage here.
|
77 |
+
html = HTML('web/', 'test_html')
|
78 |
+
html.add_header('hello world')
|
79 |
+
|
80 |
+
ims, txts, links = [], [], []
|
81 |
+
for n in range(4):
|
82 |
+
ims.append('image_%d.png' % n)
|
83 |
+
txts.append('text_%d' % n)
|
84 |
+
links.append('image_%d.png' % n)
|
85 |
+
html.add_images(ims, txts, links)
|
86 |
+
html.save()
|
asp/util/image_pool.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class ImagePool():
|
6 |
+
"""This class implements an image buffer that stores previously generated images.
|
7 |
+
|
8 |
+
This buffer enables us to update discriminators using a history of generated images
|
9 |
+
rather than the ones produced by the latest generators.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, pool_size):
|
13 |
+
"""Initialize the ImagePool class
|
14 |
+
|
15 |
+
Parameters:
|
16 |
+
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
|
17 |
+
"""
|
18 |
+
self.pool_size = pool_size
|
19 |
+
if self.pool_size > 0: # create an empty pool
|
20 |
+
self.num_imgs = 0
|
21 |
+
self.images = []
|
22 |
+
|
23 |
+
def query(self, images):
|
24 |
+
"""Return an image from the pool.
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
images: the latest generated images from the generator
|
28 |
+
|
29 |
+
Returns images from the buffer.
|
30 |
+
|
31 |
+
By 50/100, the buffer will return input images.
|
32 |
+
By 50/100, the buffer will return images previously stored in the buffer,
|
33 |
+
and insert the current images to the buffer.
|
34 |
+
"""
|
35 |
+
if self.pool_size == 0: # if the buffer size is 0, do nothing
|
36 |
+
return images
|
37 |
+
return_images = []
|
38 |
+
for image in images:
|
39 |
+
image = torch.unsqueeze(image.data, 0)
|
40 |
+
if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
|
41 |
+
self.num_imgs = self.num_imgs + 1
|
42 |
+
self.images.append(image)
|
43 |
+
return_images.append(image)
|
44 |
+
else:
|
45 |
+
p = random.uniform(0, 1)
|
46 |
+
if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
|
47 |
+
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
|
48 |
+
tmp = self.images[random_id].clone()
|
49 |
+
self.images[random_id] = image
|
50 |
+
return_images.append(tmp)
|
51 |
+
else: # by another 50% chance, the buffer will return the current image
|
52 |
+
return_images.append(image)
|
53 |
+
return_images = torch.cat(return_images, 0) # collect all the images and return
|
54 |
+
return return_images
|
asp/util/inception.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,    # First max pooling features
        192: 1,   # Second max pooling features
        768: 2,   # Pre-aux classifier features
        2048: 3   # Final average pooling features
    }

    def __init__(self,
                 output_blocks=(DEFAULT_BLOCK_INDEX,),
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False,
                 use_fid_inception=True):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x,
                              size=(299, 299),
                              mode='bilinear',
                              align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp


def _inception_v3(*args, **kwargs):
    """Wraps `torchvision.models.inception_v3`

    Skips default weight initialization if supported by torchvision version.
    See https://github.com/mseitzer/pytorch-fid/issues/28.
    """
    try:
        version = tuple(map(int, torchvision.__version__.split('.')[:2]))
    except ValueError:
        # Just a caution against weird version strings
        version = (0,)

    if version >= (0, 6):
        kwargs['init_weights'] = False

    return torchvision.models.inception_v3(*args, **kwargs)


def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    inception = _inception_v3(num_classes=1008,
                              aux_logits=False,
                              pretrained=False)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception


class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation"""
    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation"""
    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
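As a quick orientation for how this module is consumed, here is a hypothetical sketch (not part of the commit) that extracts 2048-dimensional pool3 features with the FID-patched network defined above; the random batch is a placeholder for real image tensors scaled to (0, 1), and the asp.util module path is assumed importable from the repository root.

import torch
from asp.util.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # pool3 features live in block 3
model = InceptionV3([block_idx]).eval()            # downloads the FID weights on first use
images = torch.rand(8, 3, 299, 299)                # forward() expects values in (0, 1)
with torch.no_grad():
    feats = model(images)[0]                       # shape (8, 2048, 1, 1)
feats = feats.squeeze(-1).squeeze(-1)              # (8, 2048), ready for FID/KID statistics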
asp/util/kid_score.py
ADDED
@@ -0,0 +1,450 @@
#!/usr/bin/env python3
"""Calculates the Kernel Inception Distance (KID) to evaluate GANs
"""
import os
import pathlib
import sys
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

import numpy as np
import torch
from sklearn.metrics.pairwise import polynomial_kernel
from scipy import linalg
from PIL import Image
from torch.nn.functional import adaptive_avg_pool2d

try:
    from tqdm import tqdm
except ImportError:
    # If tqdm is not available, provide a mock version of it
    def tqdm(x): return x

# from models.inception import InceptionV3
# from models.lenet import LeNet5

import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,    # First max pooling features
        192: 1,   # Second max pooling features
        768: 2,   # Pre-aux classifier features
        2048: 3   # Final average pooling features
    }

    def __init__(self,
                 output_blocks=[DEFAULT_BLOCK_INDEX],
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        inception = models.inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0.0, 1.0)

        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x,
                              size=(299, 299),
                              mode='bilinear',
                              align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp


def get_activations(files, model, batch_size=50, dims=2048,
                    cuda=False, verbose=False):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- files       : List of image files paths
    -- model       : Instance of inception model
    -- batch_size  : Batch size of images for the model to process at once.
                     Make sure that the number of samples is a multiple of
                     the batch size, otherwise some samples are ignored. This
                     behavior is retained to match the original FID score
                     implementation.
    -- dims        : Dimensionality of features returned by Inception
    -- cuda        : If set to True, use GPU
    -- verbose     : If set to True and parameter out_step is given, the number
                     of calculated batches is reported.
    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    is_numpy = True if type(files[0]) == np.ndarray else False

    if len(files) % batch_size != 0:
        print(('Warning: number of images is not a multiple of the '
               'batch size. Some samples are going to be ignored.'))
    if batch_size > len(files):
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
        batch_size = len(files)

    n_batches = len(files) // batch_size
    n_used_imgs = n_batches * batch_size

    pred_arr = np.empty((n_used_imgs, dims))

    for i in tqdm(range(n_batches)):
        if verbose:
            print('\rPropagating batch %d/%d' % (i + 1, n_batches), end='', flush=True)
        start = i * batch_size
        end = start + batch_size
        if is_numpy:
            images = np.copy(files[start:end]) + 1
            images /= 2.
        else:
            images = [np.array(Image.open(str(f))) for f in files[start:end]]
            images = np.stack(images).astype(np.float32) / 255.
            # Reshape to (n_images, 3, height, width)
            images = images.transpose((0, 3, 1, 2))

        batch = torch.from_numpy(images).type(torch.FloatTensor)
        if cuda:
            batch = batch.cuda()
        pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.shape[2] != 1 or pred.shape[3] != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)

        if verbose:
            print('done', np.min(images))

    return pred_arr


def extract_lenet_features(imgs, net):
    net.eval()
    feats = []
    imgs = imgs.reshape([-1, 100] + list(imgs.shape[1:]))
    if imgs[0].min() < -0.001:
        imgs = (imgs + 1) / 2.0
    print(imgs.shape, imgs.min(), imgs.max())
    imgs = torch.from_numpy(imgs)
    for i, images in enumerate(imgs):
        feats.append(net.extract_features(images).detach().cpu().numpy())
    feats = np.vstack(feats)
    return feats


def _compute_activations(path, model, batch_size, dims, cuda, model_type):
    if not type(path) == np.ndarray:
        import glob
        jpg = os.path.join(path, '*.jpg')
        png = os.path.join(path, '*.png')
        path = glob.glob(jpg) + glob.glob(png)
        if len(path) > 50000:
            import random
            random.shuffle(path)
            path = path[:50000]
    if model_type == 'inception':
        act = get_activations(path, model, batch_size, dims, cuda)
    elif model_type == 'lenet':
        act = extract_lenet_features(path, model)
    return act


def calculate_kid_given_paths(paths, batch_size, cuda, dims, model_type='inception'):
    """Calculates the KID of two paths"""
    pths = []
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)
        if os.path.isdir(p):
            pths.append(p)
        elif p.endswith('.npy'):
            np_imgs = np.load(p)
            if np_imgs.shape[0] > 50000: np_imgs = np_imgs[np.random.permutation(np.arange(np_imgs.shape[0]))][:50000]
            pths.append(np_imgs)

    if model_type == 'inception':
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
        model = InceptionV3([block_idx])
    elif model_type == 'lenet':
        model = LeNet5()
        model.load_state_dict(torch.load('./models/lenet.pth'))
    if cuda:
        model.cuda()

    act_true = _compute_activations(pths[0], model, batch_size, dims, cuda, model_type)
    pths = pths[1:]
    results = []
    for j, pth in enumerate(pths):
        print(paths[j + 1])
        actj = _compute_activations(pth, model, batch_size, dims, cuda, model_type)
        kid_values = polynomial_mmd_averages(act_true, actj, n_subsets=100, subset_size=min(act_true.shape[0], 100))
        results.append((paths[j + 1], kid_values[0].mean(), kid_values[0].std()))
    return results


def _sqn(arr):
    flat = np.ravel(arr)
    return flat.dot(flat)


def polynomial_mmd_averages(codes_g, codes_r, n_subsets=50, subset_size=1000,
                            ret_var=True, output=sys.stdout, **kernel_args):
    m = min(codes_g.shape[0], codes_r.shape[0])
    mmds = np.zeros(n_subsets)
    if ret_var:
        vars = np.zeros(n_subsets)
    choice = np.random.choice

    with tqdm(range(n_subsets), desc='MMD', file=output) as bar:
        for i in bar:
            g = codes_g[choice(len(codes_g), subset_size, replace=False)]
            r = codes_r[choice(len(codes_r), subset_size, replace=False)]
            o = polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var)
            if ret_var:
                mmds[i], vars[i] = o
            else:
                mmds[i] = o
            bar.set_postfix({'mean': mmds[:i + 1].mean()})
    return (mmds, vars) if ret_var else mmds


def polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1,
                   var_at_m=None, ret_var=True):
    # use k(x, y) = (gamma <x, y> + coef0)^degree
    # default gamma is 1 / dim
    X = codes_g
    Y = codes_r

    K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0)
    K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0)
    K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0)

    return _mmd2_and_variance(K_XX, K_XY, K_YY,
                              var_at_m=var_at_m, ret_var=ret_var)


def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False,
                       mmd_est='unbiased', block_size=1024,
                       var_at_m=None, ret_var=True):
    # based on
    # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py
    # but changed to not compute the full kernel matrix at once
    m = K_XX.shape[0]
    assert K_XX.shape == (m, m)
    assert K_XY.shape == (m, m)
    assert K_YY.shape == (m, m)
    if var_at_m is None:
        var_at_m = m

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if unit_diagonal:
        diag_X = diag_Y = 1
        sum_diag_X = sum_diag_Y = m
        sum_diag2_X = sum_diag2_Y = m
    else:
        diag_X = np.diagonal(K_XX)
        diag_Y = np.diagonal(K_YY)

        sum_diag_X = diag_X.sum()
        sum_diag_Y = diag_Y.sum()

        sum_diag2_X = _sqn(diag_X)
        sum_diag2_Y = _sqn(diag_Y)

    Kt_XX_sums = K_XX.sum(axis=1) - diag_X
    Kt_YY_sums = K_YY.sum(axis=1) - diag_Y
    K_XY_sums_0 = K_XY.sum(axis=0)
    K_XY_sums_1 = K_XY.sum(axis=1)

    Kt_XX_sum = Kt_XX_sums.sum()
    Kt_YY_sum = Kt_YY_sums.sum()
    K_XY_sum = K_XY_sums_0.sum()

    if mmd_est == 'biased':
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
                + (Kt_YY_sum + sum_diag_Y) / (m * m)
                - 2 * K_XY_sum / (m * m))
    else:
        assert mmd_est in {'unbiased', 'u-statistic'}
        mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1))
        if mmd_est == 'unbiased':
            mmd2 -= 2 * K_XY_sum / (m * m)
        else:
            mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m - 1))

    if not ret_var:
        return mmd2

    Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X
    Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y
    K_XY_2_sum = _sqn(K_XY)

    dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1)
    dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0)

    m1 = m - 1
    m2 = m - 2
    zeta1_est = (
        1 / (m * m1 * m2) * (
            _sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum)
        - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
        + 1 / (m * m * m1) * (
            _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum)
        - 2 / m**4 * K_XY_sum**2
        - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
        + 2 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
    )
    zeta2_est = (
        1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum)
        - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
        + 2 / (m * m) * K_XY_2_sum
        - 2 / m**4 * K_XY_sum**2
        - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
        + 4 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
    )
    var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est
               + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est)

    return mmd2, var_est


if __name__ == '__main__':
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--true', type=str, required=True,
                        help=('Path to the true images'))
    parser.add_argument('--fake', type=str, nargs='+', required=True,
                        help=('Path to the generated images'))
    parser.add_argument('--batch-size', type=int, default=50,
                        help='Batch size to use')
    parser.add_argument('--dims', type=int, default=2048,
                        choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
                        help=('Dimensionality of Inception features to use. '
                              'By default, uses pool3 features'))
    parser.add_argument('-c', '--gpu', default='', type=str,
                        help='GPU to use (leave blank for CPU only)')
    parser.add_argument('--model', default='inception', type=str,
                        help='inception or lenet')
    args = parser.parse_args()
    print(args)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    paths = [args.true] + args.fake

    results = calculate_kid_given_paths(paths, args.batch_size, args.gpu != '', args.dims, model_type=args.model)
    for p, m, s in results:
        print('KID mean std (%s): %.4f %.4f' % (p, m, s))
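For reference, a hypothetical call of the KID helper on two image folders might look as follows; the folder names are placeholders, and the same computation is exposed on the command line by the __main__ block above via the --true and --fake arguments.

from asp.util.kid_score import calculate_kid_given_paths

results = calculate_kid_given_paths(
    paths=['data/real_ihc', 'results/fake_ihc'],   # placeholder folders of .jpg/.png tiles
    batch_size=50, cuda=False, dims=2048, model_type='inception')
for path, kid_mean, kid_std in results:
    print('KID for %s: %.4f +/- %.4f' % (path, kid_mean, kid_std))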
asp/util/perceptual.py
ADDED
@@ -0,0 +1,347 @@
# Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
# Copyright (C) 2020 NVIDIA Corporation. All rights reserved
import torch
import torch.nn.functional as F
import torchvision
from torch import nn


def apply_imagenet_normalization(input):
    r"""Normalize using ImageNet mean and std.

    Args:
        input (4D tensor NxCxHxW): The input images, assumed to be in [-1, 1].

    Returns:
        Normalized inputs using the ImageNet normalization.
    """
    # normalize the input back to [0, 1]
    normalized_input = (input + 1) / 2
    # normalize the input using the ImageNet mean and std
    mean = normalized_input.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = normalized_input.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    output = (normalized_input - mean) / std
    return output


class PerceptualHashValue(nn.Module):
    """Perceptual hash value between an image and its ground truth, computed
    from pretrained network features.

    Args:
        T (float): Threshold on the absolute difference of channel-averaged
            features; channels above it count as "different".
        network (str) : The name of the feature network: 'vgg16' | 'vgg19' |
            'alexnet' | 'inception_v3' | 'resnet50' | 'robust_resnet50' |
            'vgg_face_dag'.
        layers (str or list of str) : The layers used to compute the distance.
        resize (bool) : If ``True``, resize the input images to 224x224.
        resize_mode (str): Algorithm used for resizing.
        instance_normalized (bool): If ``True``, applies instance normalization
            to the feature maps before computing the distance.
    """

    def __init__(self, T=0.005, network='vgg19', layers='relu_4_1', resize=False, resize_mode='bilinear',
                 instance_normalized=False):
        super().__init__()
        if isinstance(layers, str):
            layers = [layers]

        if network == 'vgg19':
            self.model = _vgg19(layers)
        elif network == 'vgg16':
            self.model = _vgg16(layers)
        elif network == 'alexnet':
            self.model = _alexnet(layers)
        elif network == 'inception_v3':
            self.model = _inception_v3(layers)
        elif network == 'resnet50':
            self.model = _resnet50(layers)
        elif network == 'robust_resnet50':
            self.model = _robust_resnet50(layers)
        elif network == 'vgg_face_dag':
            self.model = _vgg_face_dag(layers)
        else:
            raise ValueError('Network %s is not recognized' % network)

        self.T = T
        self.layers = layers
        self.resize = resize
        self.resize_mode = resize_mode
        self.instance_normalized = instance_normalized
        print('Perceptual Hash Value:')
        print('\tMode: {}'.format(network))

    def forward(self, inp, target):
        r"""Perceptual hash value forward.

        Args:
           inp (4D tensor) : Input tensor.
           target (4D tensor) : Ground truth tensor, same shape as the input.

        Returns:
           (list of float) : One value per selected layer; the fraction of
           feature channels whose average absolute difference exceeds T.
        """
        # The feature extractor should operate in eval mode by default.
        self.model.eval()
        inp, target = \
            apply_imagenet_normalization(inp), \
            apply_imagenet_normalization(target)
        if self.resize:
            inp = F.interpolate(
                inp, mode=self.resize_mode, size=(224, 224),
                align_corners=False)
            target = F.interpolate(
                target, mode=self.resize_mode, size=(224, 224),
                align_corners=False)

        # Evaluate the features of both images once.
        input_features, target_features = \
            self.model(inp), self.model(target)

        hpv_list = []
        for layer in self.layers:
            # Example per-layer VGG19 loss values after applying
            # [0.03125, 0.0625, 0.125, 0.25, 1.0] weighting.
            # relu_1_1, 0.014698, 0.47
            # relu_2_1, 0.085817, 1.37
            # relu_3_1, 0.349977, 2.8
            # relu_4_1, 0.544188, 2.176
            # relu_5_1, 0.906261, 0.906
            input_feature = input_features[layer]
            target_feature = target_features[layer].detach()
            if self.instance_normalized:
                input_feature = F.instance_norm(input_feature)
                target_feature = F.instance_norm(target_feature)

            # We are ignoring the spatial dimensions
            B, C = input_feature.shape[:2]
            inp_avg = torch.mean(input_feature.view(B, C, -1), -1)
            tgt_avg = torch.mean(target_feature.view(B, C, -1), -1)
            abs_dif = torch.abs(inp_avg - tgt_avg)
            hpv = torch.sum(abs_dif > self.T).item() / (B * C)
            hpv_list.append(hpv)

        return hpv_list


class _PerceptualNetwork(nn.Module):
    r"""The network that extracts features to compute the perceptual loss.

    Args:
        network (nn.Sequential) : The network that extracts features.
        layer_name_mapping (dict) : The dictionary that
            maps a layer's index to its name.
        layers (list of str): The list of layer names that we are using.
    """

    def __init__(self, network, layer_name_mapping, layers):
        super().__init__()
        assert isinstance(network, nn.Sequential), \
            'The network needs to be of type "nn.Sequential".'
        self.network = network
        self.layer_name_mapping = layer_name_mapping
        self.layers = layers
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        r"""Extract perceptual features."""
        output = {}
        for i, layer in enumerate(self.network):
            x = layer(x)
            layer_name = self.layer_name_mapping.get(i, None)
            if layer_name in self.layers:
                # If the current layer is used by the perceptual loss.
                output[layer_name] = x
        return output


def _vgg19(layers):
    r"""Get vgg19 layers"""
    network = torchvision.models.vgg19(pretrained=True).features
    layer_name_mapping = {1: 'relu_1_1',
                          3: 'relu_1_2',
                          6: 'relu_2_1',
                          8: 'relu_2_2',
                          11: 'relu_3_1',
                          13: 'relu_3_2',
                          15: 'relu_3_3',
                          17: 'relu_3_4',
                          20: 'relu_4_1',
                          22: 'relu_4_2',
                          24: 'relu_4_3',
                          26: 'relu_4_4',
                          29: 'relu_5_1'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _vgg16(layers):
    r"""Get vgg16 layers"""
    network = torchvision.models.vgg16(pretrained=True).features
    layer_name_mapping = {1: 'relu_1_1',
                          3: 'relu_1_2',
                          6: 'relu_2_1',
                          8: 'relu_2_2',
                          11: 'relu_3_1',
                          13: 'relu_3_2',
                          15: 'relu_3_3',
                          18: 'relu_4_1',
                          20: 'relu_4_2',
                          22: 'relu_4_3',
                          25: 'relu_5_1'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _alexnet(layers):
    r"""Get alexnet layers"""
    network = torchvision.models.alexnet(pretrained=True).features
    layer_name_mapping = {0: 'conv_1',
                          1: 'relu_1',
                          3: 'conv_2',
                          4: 'relu_2',
                          6: 'conv_3',
                          7: 'relu_3',
                          8: 'conv_4',
                          9: 'relu_4',
                          10: 'conv_5',
                          11: 'relu_5'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _inception_v3(layers):
    r"""Get inception v3 layers"""
    inception = torchvision.models.inception_v3(pretrained=True)
    network = nn.Sequential(inception.Conv2d_1a_3x3,
                            inception.Conv2d_2a_3x3,
                            inception.Conv2d_2b_3x3,
                            nn.MaxPool2d(kernel_size=3, stride=2),
                            inception.Conv2d_3b_1x1,
                            inception.Conv2d_4a_3x3,
                            nn.MaxPool2d(kernel_size=3, stride=2),
                            inception.Mixed_5b,
                            inception.Mixed_5c,
                            inception.Mixed_5d,
                            inception.Mixed_6a,
                            inception.Mixed_6b,
                            inception.Mixed_6c,
                            inception.Mixed_6d,
                            inception.Mixed_6e,
                            inception.Mixed_7a,
                            inception.Mixed_7b,
                            inception.Mixed_7c,
                            nn.AdaptiveAvgPool2d(output_size=(1, 1)))
    layer_name_mapping = {3: 'pool_1',
                          6: 'pool_2',
                          14: 'mixed_6e',
                          18: 'pool_3'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _resnet50(layers):
    r"""Get resnet50 layers"""
    resnet50 = torchvision.models.resnet50(pretrained=True)
    network = nn.Sequential(resnet50.conv1,
                            resnet50.bn1,
                            resnet50.relu,
                            resnet50.maxpool,
                            resnet50.layer1,
                            resnet50.layer2,
                            resnet50.layer3,
                            resnet50.layer4,
                            resnet50.avgpool)
    layer_name_mapping = {4: 'layer_1',
                          5: 'layer_2',
                          6: 'layer_3',
                          7: 'layer_4'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _robust_resnet50(layers):
    r"""Get robust resnet50 layers"""
    resnet50 = torchvision.models.resnet50(pretrained=False)
    state_dict = torch.utils.model_zoo.load_url(
        'http://andrewilyas.com/ImageNet.pt')
    new_state_dict = {}
    for k, v in state_dict['model'].items():
        if k.startswith('module.model.'):
            new_state_dict[k[13:]] = v
    resnet50.load_state_dict(new_state_dict)
    network = nn.Sequential(resnet50.conv1,
                            resnet50.bn1,
                            resnet50.relu,
                            resnet50.maxpool,
                            resnet50.layer1,
                            resnet50.layer2,
                            resnet50.layer3,
                            resnet50.layer4,
                            resnet50.avgpool)
    layer_name_mapping = {4: 'layer_1',
                          5: 'layer_2',
                          6: 'layer_3',
                          7: 'layer_4'}
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _vgg_face_dag(layers):
    r"""Get vgg face layers"""
    network = torchvision.models.vgg16(num_classes=2622)
    state_dict = torch.utils.model_zoo.load_url(
        'http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/'
        'vgg_face_dag.pth')
    feature_layer_name_mapping = {
        0: 'conv1_1',
        2: 'conv1_2',
        5: 'conv2_1',
        7: 'conv2_2',
        10: 'conv3_1',
        12: 'conv3_2',
        14: 'conv3_3',
        17: 'conv4_1',
        19: 'conv4_2',
        21: 'conv4_3',
        24: 'conv5_1',
        26: 'conv5_2',
        28: 'conv5_3'}
    new_state_dict = {}
    for k, v in feature_layer_name_mapping.items():
        new_state_dict['features.' + str(k) + '.weight'] = \
            state_dict[v + '.weight']
        new_state_dict['features.' + str(k) + '.bias'] = \
            state_dict[v + '.bias']

    classifier_layer_name_mapping = {
        0: 'fc6',
        3: 'fc7',
        6: 'fc8'}
    for k, v in classifier_layer_name_mapping.items():
        new_state_dict['classifier.' + str(k) + '.weight'] = \
            state_dict[v + '.weight']
        new_state_dict['classifier.' + str(k) + '.bias'] = \
            state_dict[v + '.bias']

    network.load_state_dict(new_state_dict)

    class Flatten(nn.Module):
        r"""Flatten the tensor"""

        def forward(self, x):
            r"""Flatten it"""
            return x.view(x.shape[0], -1)

    layer_name_mapping = {
        1: 'avgpool',
        3: 'fc6',
        4: 'relu_6',
        6: 'fc7',
        7: 'relu_7',
        9: 'fc8'}
    seq_layers = [network.features, network.avgpool, Flatten()]
    for i in range(7):
        seq_layers += [network.classifier[i]]
    network = nn.Sequential(*seq_layers)
    return _PerceptualNetwork(network, layer_name_mapping, layers)
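A hypothetical way to apply the metric defined above to a single generated/ground-truth pair is sketched below; the tensors are random placeholders in [-1, 1], which is the range apply_imagenet_normalization expects, and the pretrained VGG19 weights are downloaded by torchvision on first use.

import torch
from asp.util.perceptual import PerceptualHashValue

phv = PerceptualHashValue(T=0.005, network='vgg19',
                          layers=['relu_3_1', 'relu_4_1'])
fake = torch.rand(1, 3, 256, 256) * 2 - 1   # placeholder generated image
real = torch.rand(1, 3, 256, 256) * 2 - 1   # placeholder ground-truth image
scores = phv(fake, real)                    # one fraction per requested layer; higher means more different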
asp/util/util.py
ADDED
@@ -0,0 +1,220 @@
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
import importlib
import argparse
from argparse import Namespace
import torchvision
import cv2 as cv


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def copyconf(default_opt, **kwargs):
    conf = Namespace(**vars(default_opt))
    for key in kwargs:
        setattr(conf, key, kwargs[key])
    return conf


def find_class_in_module(target_cls_name, module):
    target_cls_name = target_cls_name.replace('_', '').lower()
    clslib = importlib.import_module(module)
    cls = None
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            cls = clsobj

    assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)

    return cls


def tensor2im(input_image, imtype=np.uint8):
    """Converts a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].clamp(-1.0, 1.0).cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)


def diagnose_network(net, name='network'):
    """Calculate and print the mean of average absolute(gradients)

    Parameters:
        net (torch network) -- Torch network
        name (str)          -- the name of the network
    """
    mean = 0.0
    count = 0
    for param in net.parameters():
        if param.grad is not None:
            mean += torch.mean(torch.abs(param.grad.data))
            count += 1
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)


def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- the aspect ratio used to resize the saved image
    """

    image_pil = Image.fromarray(image_numpy)
    h, w, _ = image_numpy.shape

    if aspect_ratio is None:
        pass
    elif aspect_ratio > 1.0:
        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
    elif aspect_ratio < 1.0:
        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
    image_pil.save(image_path)


def print_numpy(x, val=True, shp=False):
    """Print the mean, min, max, median, std, and size of a numpy array

    Parameters:
        val (bool) -- if True, print statistics of the array values
        shp (bool) -- if True, print the shape of the numpy array
    """
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        x = x.flatten()
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))


def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths
    """
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """create a single empty directory if it didn't exist

    Parameters:
        path (str) -- a single directory path
    """
    if not os.path.exists(path):
        os.makedirs(path)


def correct_resize_label(t, size):
    device = t.device
    t = t.detach().cpu()
    resized = []
    for i in range(t.size(0)):
        one_t = t[i, :1]
        one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
        one_np = one_np[:, :, 0]
        one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
        resized_t = torch.from_numpy(np.array(one_image)).long()
        resized.append(resized_t)
    return torch.stack(resized, dim=0).to(device)


def correct_resize(t, size, mode=Image.BICUBIC):
    device = t.device
    t = t.detach().cpu()
    resized = []
    for i in range(t.size(0)):
        one_t = t[i:i + 1]
        one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
        resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
        resized.append(resized_t)
    return torch.stack(resized, dim=0).to(device)


def expand_as_one_hot(input, C, ignore_index=None):
    """
    Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
    It is assumed that the batch dimension is present.
    Args:
        input (torch.Tensor): 3D/4D input image
        C (int): number of channels/labels
        ignore_index (int): ignore index to be kept during the expansion
    Returns:
        4D/5D output torch.Tensor (NxCxSPATIAL)
    """
    # expand the input tensor to Nx1xSPATIAL before scattering
    input = input.unsqueeze(1)
    # create output tensor shape (NxCxSPATIAL)
    shape = list(input.size())
    shape[1] = C

    if ignore_index is not None:
        # create ignore_index mask for the result
        mask = input.expand(shape) == ignore_index
        # clone the src tensor and zero out ignore_index in the input
        input = input.clone()
        input[input == ignore_index] = 0
        # scatter to get the one-hot tensor
        result = torch.zeros(shape).to(input.device).scatter_(1, input, 1)
        # bring back the ignore_index in the result
        result[mask] = ignore_index
        return result
    else:
        # scatter to get the one-hot tensor
        return torch.zeros(shape).to(input.device).scatter_(1, input, 1)


def standardize(ref, I, threshold=50):
    """
    Transform image I to a standard brightness.
    Shifts the LAB luminosity channel of I towards the mean luminosity of the
    reference image when the two means differ by more than the threshold.

    :param ref: Reference image uint8 RGB whose mean luminosity is matched.
    :param I: Image uint8 RGB.
    :param threshold: Minimum difference between the mean luminosities of I and ref that triggers the correction.
    :return: Image uint8 RGB with standardized brightness.
    """
    ref_m = cv.cvtColor(ref, cv.COLOR_RGB2LAB)[:, :, 0].astype(float).mean()

    I_LAB = cv.cvtColor(I, cv.COLOR_RGB2LAB)
    L_float = I_LAB[:, :, 0].astype(float)
    tgt_m = L_float.mean()
    if np.abs(tgt_m - ref_m) > threshold:
        L_float = L_float - tgt_m + ref_m
        I_LAB[:, :, 0] = np.clip(L_float, 0, 255).astype(np.uint8)
        I = cv.cvtColor(I_LAB, cv.COLOR_LAB2RGB)
    return I
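The helpers used most often elsewhere in the repository are tensor2im and save_image; a hypothetical round trip from a network output in [-1, 1] to a PNG on disk (with a placeholder output path) could look like this.

import torch
from asp.util.util import tensor2im, save_image, mkdir

fake_B = torch.rand(1, 3, 256, 256) * 2 - 1   # placeholder network output in [-1, 1]
img = tensor2im(fake_B)                       # uint8 HxWx3 numpy array
mkdir('results')                              # placeholder output directory
save_image(img, 'results/fake_B.png')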
asp/util/visualizer.py
ADDED
@@ -0,0 +1,242 @@
import numpy as np
import os
import sys
import ntpath
import time
from . import util, html
from subprocess import Popen, PIPE

if sys.version_info[0] == 2:
    VisdomExceptionBase = Exception
else:
    VisdomExceptionBase = ConnectionError


def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save images to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
        visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
        image_path (str) -- the string is used to create image paths
        aspect_ratio (float) -- the aspect ratio of saved images
        width (int) -- the images will be resized to width x width

    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
    """
    image_dir = webpage.get_image_dir()
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    webpage.add_header(name)
    ims, txts, links = [], [], []

    for label, im_data in visuals.items():
        im = util.tensor2im(im_data)
        image_name = '%s/%s.png' % (label, name)
        os.makedirs(os.path.join(image_dir, label), exist_ok=True)
        save_path = os.path.join(image_dir, image_name)
        util.save_image(im, save_path, aspect_ratio=aspect_ratio)
        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=width)


class Visualizer():
    """This class includes several functions that can display/save images and print/save logging information.

    It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
    """

    def __init__(self, opt):
        """Initialize the Visualizer class

        Parameters:
            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
        Step 1: Cache the training/test options
        Step 2: connect to a visdom server
        Step 3: create an HTML object for saving HTML files
        Step 4: create a logging file to store training losses
        """
        self.opt = opt  # cache the option
        if opt.display_id is None:
            self.display_id = np.random.randint(100000) * 10  # just a random display id
        else:
            self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.port = opt.display_port
        self.saved = False
        if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
            import visdom
            self.plot_data = {}
            self.ncols = opt.display_ncols
            if "tensorboard_base_url" not in os.environ:
                self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
            else:
                self.vis = visdom.Visdom(port=2004,
                                         base_url=os.environ['tensorboard_base_url'] + '/visdom')
            if not self.vis.check_connection():
                self.create_visdom_connections()

        if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        # create a logging file to store training losses
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def reset(self):
        """Reset the self.saved status"""
        self.saved = False

    def create_visdom_connections(self):
        """If the program could not connect to Visdom server, this function will start a new server at port <self.port>"""
        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
        print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
        print('Command: %s' % cmd)
        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)

    def display_current_results(self, visuals, epoch, save_result):
        """Display current results on visdom; save current results to an HTML file.

        Parameters:
            visuals (OrderedDict) -- dictionary of images to display or save
            epoch (int) -- the current epoch
            save_result (bool) -- if save the current results to an HTML file
        """
        if self.display_id > 0:  # show images in the browser using visdom
            ncols = self.ncols
            if ncols > 0:  # show all the images in one visdom panel
                ncols = min(ncols, len(visuals))
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
                        table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)  # create a table css
                # create a table of images.
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                try:
                    self.vis.images(images, ncols, 2, self.display_id + 1,
                                    None, dict(title=title + ' images'))
                    label_html = '<table>%s</table>' % label_html
                    self.vis.text(table_css + label_html, win=self.display_id + 2,
                                  opts=dict(title=title + ' labels'))
                except VisdomExceptionBase:
                    self.create_visdom_connections()

            else:  # show each image in a separate visdom panel
                idx = 1
                try:
                    for label, image in visuals.items():
                        image_numpy = util.tensor2im(image)
                        self.vis.image(
                            image_numpy.transpose([2, 0, 1]),
                            self.display_id + idx,
                            None,
                            dict(title=label)
                        )
                        idx += 1
                except VisdomExceptionBase:
                    self.create_visdom_connections()

        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
            self.saved = True
            # save images to the disk
            for label, image in visuals.items():
                image_numpy = util.tensor2im(image)
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                util.save_image(image_numpy, img_path)

            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims, txts, links = [], [], []

                for label, image_numpy in visuals.items():
                    image_numpy = util.tensor2im(image_numpy)
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    def plot_current_losses(self, epoch, counter_ratio, losses):
        """display the current losses on visdom display: dictionary of error labels and values

        Parameters:
            epoch (int) -- current epoch
            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
        """
        if len(losses) == 0:
            return

        plot_name = '_'.join(list(losses.keys()))

        if plot_name not in self.plot_data:
            self.plot_data[plot_name] = {'X': [], 'Y': [], 'legend': list(losses.keys())}

        plot_data = self.plot_data[plot_name]
        plot_id = list(self.plot_data.keys()).index(plot_name)

        plot_data['X'].append(epoch + counter_ratio)
        plot_data['Y'].append([losses[k] for k in plot_data['legend']])
        try:
            self.vis.line(
                X=np.stack([np.array(plot_data['X'])] * len(plot_data['legend']), 1),
                Y=np.array(plot_data['Y']),
                opts={
                    'title': self.name,
                    'legend': plot_data['legend'],
                    'xlabel': 'epoch',
                    'ylabel': 'loss'},
                win=self.display_id - plot_id)
        except VisdomExceptionBase:
            self.create_visdom_connections()

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
        """print current losses on console; also save the losses to the disk

        Parameters:
            epoch (int) -- current epoch
            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
            t_comp (float) -- computational time per data point (normalized by batch_size)
            t_data (float) -- data loading time per data point (normalized by batch_size)
        """
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)

        print(message)  # print the message
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)  # save the message
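For context, a minimal sketch of instantiating Visualizer outside a full training run (illustrative only, not part of this commit; the option fields simply mirror those read in __init__, and the loss values are placeholders):

import os
from collections import OrderedDict
from types import SimpleNamespace
from asp.util.visualizer import Visualizer

# Hypothetical options: display_id <= 0 skips the visdom connection and
# no_html=True skips the <checkpoints_dir>/<name>/web directory.
opt = SimpleNamespace(
    display_id=0,
    isTrain=True,
    no_html=True,
    display_winsize=256,
    name="debug_run",
    display_port=8097,
    checkpoints_dir="./checkpoints",
)
os.makedirs(os.path.join(opt.checkpoints_dir, opt.name), exist_ok=True)  # loss_log.txt is opened here

vis = Visualizer(opt)
losses = OrderedDict([("G_GAN", 0.123), ("NCE", 0.456)])
vis.print_current_losses(epoch=1, iters=100, losses=losses, t_comp=0.05, t_data=0.01)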
main.py
ADDED
@@ -0,0 +1,90 @@
import gradio as gr
import numpy as np

import random
import torch
from torchvision.transforms.functional import to_pil_image

from types import SimpleNamespace
from PIL import Image

from asp.models.cpt_model import CPTModel
from asp.data.base_dataset import get_transform
from asp.util.general_utils import parse_args

def transform_with_seed(input_img, transform, seed=123456):
    random.seed(seed)
    torch.manual_seed(seed)
    return transform(input_img)

def convert_he2ihc(input_he_image_path):
    input_img = Image.open(input_he_image_path).convert('RGB')

    opt = SimpleNamespace(
        gpu_ids=[0],
        isTrain=False,
        checkpoints_dir="../../checkpoints",
        # name="ASP_pretrained/BCI_her2_lambda_linear",
        name="ASP_pretrained/BCI_her2_zero_uniform",
        preprocess="crop",
        nce_layers="0,4,8,12,16",
        nce_idt=False,
        input_nc=3,
        output_nc=3,
        ngf=64,
        netG="resnet_6blocks",
        normG="instance",
        no_dropout=True,
        init_type="xavier",
        init_gain=0.02,
        no_antialias=False,
        no_antialias_up=False,
        weight_norm="spectral",
        netF="mlp_sample",
        netF_nc=256,
        no_flip=True,
        load_size=1024,
        crop_size=1024,
        direction="AtoB",
        flip_equivariance=False,
        epoch="latest",
        verbose=True
    )
    model = CPTModel(opt)

    transform = get_transform(opt)

    model.setup(opt)
    model.parallelize()
    model.eval()

    A = transform_with_seed(input_img, transform)
    model.set_input({
        "A": A.unsqueeze(0),
        "A_paths": input_he_image_path,
        "B": A.unsqueeze(0),
        "B_paths": input_he_image_path,
    })
    model.test()
    visuals = model.get_current_visuals()

    output_img = to_pil_image(visuals['fake_B'].detach().cpu().squeeze(0))
    print("np.shape(output_img)", np.shape(output_img))

    return output_img

def main():
    demo = gr.Interface(
        fn=convert_he2ihc,
        inputs=gr.Image(type="filepath"),
        outputs=gr.Image(),
        title="H&E to IHC, BCI HER2"
    )

    demo.launch()

if __name__ == "__main__":
    args = parse_args(main)
    main(**vars(args))

# python main.py -i ../../data/BCI_dataset/BCI_dataset/HE/test/00003_test_3+.png
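For context, a minimal sketch of calling the conversion function outside the Gradio interface (illustrative only, not part of this commit; the input path is hypothetical and the pretrained weights must already be available under the configured checkpoints_dir):

import os
from main import convert_he2ihc

# Hypothetical H&E patch; convert_he2ihc expects a file path and returns a PIL.Image.
he_path = "example_HE_patch.png"
if os.path.exists(he_path):
    ihc_img = convert_he2ihc(he_path)
    ihc_img.save("example_IHC_prediction.png")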
requirements.txt
ADDED
@@ -0,0 +1,6 @@
dominate
packaging
opencv-python
GPUtil

gradio