antoinedelplace committed
Commit 207ef6f · 0 Parent(s)

First commit

.gitignore ADDED
@@ -0,0 +1,5 @@
1
+ # Byte-compiled
2
+ __pycache__/
3
+
4
+ # Environment
5
+ /venv/
README.md ADDED
@@ -0,0 +1,6 @@
1
+ # H&E to IHC translation
2
+ Based on Adaptive Supervised PatchNCE Loss for Learning H&E-to-IHC Stain Translation with Inconsistent Groundtruth Image Pairs (MICCAI 2023)
3
+
4
+ Original repository: [lifangda01/AdaptiveSupervisedPatchNCE](https://github.com/lifangda01/AdaptiveSupervisedPatchNCE)
5
+
6
+ Original paper: [![arXiv](https://img.shields.io/badge/arXiv-2303.06193-00ff00.svg)](https://arxiv.org/pdf/2303.06193)
asp/data/__init__.py ADDED
@@ -0,0 +1,98 @@
1
+ """This package includes all the modules related to data loading and preprocessing
2
+
3
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
+ You need to implement four functions:
5
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
+ -- <__len__>: return the size of dataset.
7
+ -- <__getitem__>: get a data point from data loader.
8
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
+
10
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
+ See our template dataset class 'template_dataset.py' for more details.
12
+ """
13
+ import importlib
14
+ import torch.utils.data
15
+ from asp.data.base_dataset import BaseDataset
16
+
17
+
18
+ def find_dataset_using_name(dataset_name):
19
+ """Import the module "data/[dataset_name]_dataset.py".
20
+
21
+ In the file, the class called DatasetNameDataset() will
22
+ be instantiated. It has to be a subclass of BaseDataset,
23
+ and it is case-insensitive.
24
+ """
25
+ dataset_filename = "data." + dataset_name + "_dataset"
26
+ datasetlib = importlib.import_module(dataset_filename)
27
+
28
+ dataset = None
29
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30
+ for name, cls in datasetlib.__dict__.items():
31
+ if name.lower() == target_dataset_name.lower() \
32
+ and issubclass(cls, BaseDataset):
33
+ dataset = cls
34
+
35
+ if dataset is None:
36
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37
+
38
+ return dataset
39
+
40
+
41
+ def get_option_setter(dataset_name):
42
+ """Return the static method <modify_commandline_options> of the dataset class."""
43
+ dataset_class = find_dataset_using_name(dataset_name)
44
+ return dataset_class.modify_commandline_options
45
+
46
+
47
+ def create_dataset(opt):
48
+ """Create a dataset given the option.
49
+
50
+ This function wraps the class CustomDatasetDataLoader.
51
+ This is the main interface between this package and 'train.py'/'test.py'
52
+
53
+ Example:
54
+ >>> from data import create_dataset
55
+ >>> dataset = create_dataset(opt)
56
+ """
57
+ data_loader = CustomDatasetDataLoader(opt)
58
+ dataset = data_loader.load_data()
59
+ return dataset
60
+
61
+
62
+ class CustomDatasetDataLoader():
63
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
64
+
65
+ def __init__(self, opt):
66
+ """Initialize this class
67
+
68
+ Step 1: create a dataset instance given the name [dataset_mode]
69
+ Step 2: create a multi-threaded data loader.
70
+ """
71
+ self.opt = opt
72
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
73
+ self.dataset = dataset_class(opt)
74
+ print("dataset [%s] was created" % type(self.dataset).__name__)
75
+ self.dataloader = torch.utils.data.DataLoader(
76
+ self.dataset,
77
+ batch_size=opt.batch_size,
78
+ shuffle=not opt.serial_batches,
79
+ num_workers=int(opt.num_threads),
80
+ drop_last=True if opt.isTrain else False,
81
+ )
82
+
83
+ def set_epoch(self, epoch):
84
+ self.dataset.current_epoch = epoch
85
+
86
+ def load_data(self):
87
+ return self
88
+
89
+ def __len__(self):
90
+ """Return the number of data in the dataset"""
91
+ return min(len(self.dataset), self.opt.max_dataset_size)
92
+
93
+ def __iter__(self):
94
+ """Return a batch of data"""
95
+ for i, data in enumerate(self.dataloader):
96
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
97
+ break
98
+ yield data
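
The docstring at the top of this file describes the plugin convention for datasets. As a minimal sketch (a hypothetical `data/dummy_dataset.py`, not part of this commit), a conforming subclass could look like:

```python
# Hypothetical data/dummy_dataset.py -- illustrates the naming convention only.
from PIL import Image

from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset


class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        # optionally add dataset-specific flags and defaults here
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        self.transform = get_transform(opt)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert('RGB')
        return {'A': self.transform(img), 'A_paths': self.paths[index]}
```

With such a file in place, `--dataset_mode dummy` makes `find_dataset_using_name('dummy')` import `data.dummy_dataset` and match the lowercased class name `dummydataset`.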
asp/data/aligned_dataset.py ADDED
@@ -0,0 +1,96 @@
1
+ import os.path
2
+ import numpy as np
3
+ import torch
4
+ import json
5
+
6
+ from data.base_dataset import BaseDataset, get_transform
7
+ from data.image_folder import make_dataset
8
+ from PIL import Image
9
+ import random
10
+ import util.util as util
11
+
12
+
13
+ class AlignedDataset(BaseDataset):
14
+ """
15
+ This dataset class can load aligned/paired datasets.
16
+
17
+ It requires two directories to host training images from domain A '/path/to/data/trainA'
18
+ and from domain B '/path/to/data/trainB' respectively.
19
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
20
+ Similarly, you need to prepare two directories:
21
+ '/path/to/data/testA' and '/path/to/data/testB' during test time.
22
+ """
23
+
24
+ def __init__(self, opt):
25
+ """Initialize this dataset class.
26
+
27
+ Parameters:
28
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
29
+ """
30
+ BaseDataset.__init__(self, opt)
31
+ self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
32
+ self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
33
+
34
+ if opt.phase == "test" and not os.path.exists(self.dir_A) \
35
+ and os.path.exists(os.path.join(opt.dataroot, "valA")):
36
+ self.dir_A = os.path.join(opt.dataroot, "valA")
37
+ self.dir_B = os.path.join(opt.dataroot, "valB")
38
+
39
+ self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
40
+ self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
41
+
42
+ self.A_size = len(self.A_paths) # get the size of dataset A
43
+ self.B_size = len(self.B_paths) # get the size of dataset B
44
+ assert self.A_size == self.B_size
45
+
46
+ def __getitem__(self, index):
47
+ """Return a data point and its metadata information.
48
+
49
+ Parameters:
50
+ index (int) -- a random integer for data indexing
51
+
52
+ Returns a dictionary that contains A, B, A_paths and B_paths
53
+ A (tensor) -- an image in the input domain
54
+ B (tensor) -- its corresponding image in the target domain
55
+ A_paths (str) -- image paths
56
+ B_paths (str) -- image paths
57
+ """
58
+ if self.opt.serial_batches: # make sure index is within the range
59
+ index_B = index % self.B_size
60
+ else: # randomize the index for domain B to avoid fixed pairs.
61
+ index = random.randint(0, self.A_size - 1)
62
+ index_B = index % self.B_size
63
+
64
+ A_path = self.A_paths[index] # make sure index is within the range
65
+ B_path = self.B_paths[index_B]
66
+
67
+ assert A_path == B_path.replace('trainB', 'trainA').replace('valB', 'valA').replace('testB', 'testA')
68
+
69
+ A_img = Image.open(A_path).convert('RGB')
70
+ B_img = Image.open(B_path).convert('RGB')
71
+
72
+ # Apply image transformation
73
+ # For CUT/FastCUT mode, if in finetuning phase (learning rate is decaying),
74
+ # do not perform resize-crop data augmentation of CycleGAN.
75
+ is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
76
+ modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
77
+ transform = get_transform(modified_opt)
78
+
79
+ # FDL: synchronize transforms
80
+ seed = np.random.randint(2147483647) # make a seed with numpy generator
81
+ random.seed(seed) # apply this seed to image transforms
82
+ torch.manual_seed(seed) # needed for torchvision 0.7
83
+ A = transform(A_img)
84
+ random.seed(seed) # apply this seed to target transforms
85
+ torch.manual_seed(seed) # needed for torchvision 0.7
86
+ B = transform(B_img)
87
+
88
+ return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
89
+
90
+ def __len__(self):
91
+ """Return the total number of images in the dataset.
92
+
93
+ As we have two datasets with potentially different numbers of images,
94
+ we take the maximum of the two.
95
+ """
96
+ return max(self.A_size, self.B_size)
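
The paired crops stay aligned because `__getitem__` re-seeds Python's and PyTorch's random number generators with the same value before applying the (random) transform to each image. A standalone sketch of that trick, outside the dataset class:

```python
# Minimal illustration of the seed-synchronized transforms used in __getitem__ above.
import random
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

transform = T.Compose([T.RandomCrop(64), T.RandomHorizontalFlip()])
img_a = Image.new('RGB', (128, 128))
img_b = Image.new('RGB', (128, 128))

seed = np.random.randint(2147483647)
random.seed(seed)
torch.manual_seed(seed)
a = transform(img_a)   # random crop position / flip decision drawn here ...
random.seed(seed)
torch.manual_seed(seed)
b = transform(img_b)   # ... and drawn again identically here
```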
asp/data/base_dataset.py ADDED
@@ -0,0 +1,230 @@
1
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
+
3
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
+ """
5
+ import random
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ import torchvision.transforms as transforms
10
+ from abc import ABC, abstractmethod
11
+
12
+
13
+ class BaseDataset(data.Dataset, ABC):
14
+ """This class is an abstract base class (ABC) for datasets.
15
+
16
+ To create a subclass, you need to implement the following four functions:
17
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
+ -- <__len__>: return the size of dataset.
19
+ -- <__getitem__>: get a data point.
20
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ """Initialize the class; save the options in the class
25
+
26
+ Parameters:
27
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
+ """
29
+ self.opt = opt
30
+ self.root = opt.dataroot
31
+ self.current_epoch = 0
32
+
33
+ @staticmethod
34
+ def modify_commandline_options(parser, is_train):
35
+ """Add new dataset-specific options, and rewrite default values for existing options.
36
+
37
+ Parameters:
38
+ parser -- original option parser
39
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40
+
41
+ Returns:
42
+ the modified parser.
43
+ """
44
+ return parser
45
+
46
+ @abstractmethod
47
+ def __len__(self):
48
+ """Return the total number of images in the dataset."""
49
+ return 0
50
+
51
+ @abstractmethod
52
+ def __getitem__(self, index):
53
+ """Return a data point and its metadata information.
54
+
55
+ Parameters:
56
+ index (int) -- a random integer for data indexing
57
+
58
+ Returns:
59
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
60
+ """
61
+ pass
62
+
63
+
64
+ def get_params(opt, size):
65
+ w, h = size
66
+ new_h = h
67
+ new_w = w
68
+ if opt.preprocess == 'resize_and_crop':
69
+ new_h = new_w = opt.load_size
70
+ elif opt.preprocess == 'scale_width_and_crop':
71
+ new_w = opt.load_size
72
+ new_h = opt.load_size * h // w
73
+
74
+ x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
75
+ y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
76
+
77
+ flip = random.random() > 0.5
78
+
79
+ return {'crop_pos': (x, y), 'flip': flip}
80
+
81
+
82
+ def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
83
+ transform_list = []
84
+ if grayscale:
85
+ transform_list.append(transforms.Grayscale(1))
86
+ if 'fixsize' in opt.preprocess:
87
+ transform_list.append(transforms.Resize(params["size"], method))
88
+ if 'resize' in opt.preprocess:
89
+ osize = [opt.load_size, opt.load_size]
90
+ if "gta2cityscapes" in opt.dataroot:
91
+ osize[0] = opt.load_size // 2
92
+ transform_list.append(transforms.Resize(osize, method))
93
+ elif 'scale_width' in opt.preprocess:
94
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
95
+ elif 'scale_shortside' in opt.preprocess:
96
+ transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))
97
+
98
+ if 'zoom' in opt.preprocess:
99
+ if params is None:
100
+ transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))
101
+ else:
102
+ transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"])))
103
+
104
+ if 'crop' in opt.preprocess:
105
+ if params is None or 'crop_pos' not in params:
106
+ transform_list.append(transforms.RandomCrop(opt.crop_size))
107
+ else:
108
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
109
+
110
+ if 'patch' in opt.preprocess:
111
+ transform_list.append(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size)))
112
+
113
+ if 'trim' in opt.preprocess:
114
+ transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))
115
+
116
+ # if opt.preprocess == 'none':
117
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
118
+
119
+ if not opt.no_flip:
120
+ if params is None or 'flip' not in params:
121
+ transform_list.append(transforms.RandomHorizontalFlip())
122
+ elif 'flip' in params:
123
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
124
+
125
+ if convert:
126
+ transform_list += [transforms.ToTensor()]
127
+ if grayscale:
128
+ transform_list += [transforms.Normalize((0.5,), (0.5,))]
129
+ else:
130
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
131
+ return transforms.Compose(transform_list)
132
+
133
+
134
+ def __make_power_2(img, base, method=Image.BICUBIC):
135
+ ow, oh = img.size
136
+ h = int(round(oh / base) * base)
137
+ w = int(round(ow / base) * base)
138
+ if h == oh and w == ow:
139
+ return img
140
+
141
+ return img.resize((w, h), method)
142
+
143
+
144
+ def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
145
+ if factor is None:
146
+ zoom_level = np.random.uniform(0.8, 1.0, size=[2])
147
+ else:
148
+ zoom_level = (factor[0], factor[1])
149
+ iw, ih = img.size
150
+ zoomw = max(crop_width, iw * zoom_level[0])
151
+ zoomh = max(crop_width, ih * zoom_level[1])
152
+ img = img.resize((int(round(zoomw)), int(round(zoomh))), method)
153
+ return img
154
+
155
+
156
+ def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
157
+ ow, oh = img.size
158
+ shortside = min(ow, oh)
159
+ if shortside >= target_width:
160
+ return img
161
+ else:
162
+ scale = target_width / shortside
163
+ return img.resize((round(ow * scale), round(oh * scale)), method)
164
+
165
+
166
+ def __trim(img, trim_width):
167
+ ow, oh = img.size
168
+ if ow > trim_width:
169
+ xstart = np.random.randint(ow - trim_width)
170
+ xend = xstart + trim_width
171
+ else:
172
+ xstart = 0
173
+ xend = ow
174
+ if oh > trim_width:
175
+ ystart = np.random.randint(oh - trim_width)
176
+ yend = ystart + trim_width
177
+ else:
178
+ ystart = 0
179
+ yend = oh
180
+ return img.crop((xstart, ystart, xend, yend))
181
+
182
+
183
+ def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
184
+ ow, oh = img.size
185
+ if ow == target_width and oh >= crop_width:
186
+ return img
187
+ w = target_width
188
+ h = int(max(target_width * oh / ow, crop_width))
189
+ return img.resize((w, h), method)
190
+
191
+
192
+ def __crop(img, pos, size):
193
+ ow, oh = img.size
194
+ x1, y1 = pos
195
+ tw = th = size
196
+ if (ow > tw or oh > th):
197
+ return img.crop((x1, y1, x1 + tw, y1 + th))
198
+ return img
199
+
200
+
201
+ def __patch(img, index, size):
202
+ ow, oh = img.size
203
+ nw, nh = ow // size, oh // size
204
+ roomx = ow - nw * size
205
+ roomy = oh - nh * size
206
+ startx = np.random.randint(int(roomx) + 1)
207
+ starty = np.random.randint(int(roomy) + 1)
208
+
209
+ index = index % (nw * nh)
210
+ ix = index // nh
211
+ iy = index % nh
212
+ gridx = startx + ix * size
213
+ gridy = starty + iy * size
214
+ return img.crop((gridx, gridy, gridx + size, gridy + size))
215
+
216
+
217
+ def __flip(img, flip):
218
+ if flip:
219
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
220
+ return img
221
+
222
+
223
+ def __print_size_warning(ow, oh, w, h):
224
+ """Print warning information about image size(only print once)"""
225
+ if not hasattr(__print_size_warning, 'has_printed'):
226
+ print("The image size needs to be a multiple of 4. "
227
+ "The loaded image size was (%d, %d), so it was adjusted to "
228
+ "(%d, %d). This adjustment will be done to all images "
229
+ "whose sizes are not multiples of 4" % (ow, oh, w, h))
230
+ __print_size_warning.has_printed = True
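
A small sketch of how `get_transform` can be driven; the option values below are assumptions, and any object exposing the attributes read by the function works:

```python
# Hypothetical options for the get_transform function defined above.
from argparse import Namespace
from PIL import Image

from data.base_dataset import get_transform

opt = Namespace(preprocess='resize_and_crop', load_size=1024, crop_size=512,
                dataroot='', no_flip=False)
transform = get_transform(opt)  # Resize -> RandomCrop -> RandomHorizontalFlip -> ToTensor -> Normalize
tensor = transform(Image.new('RGB', (2048, 1536)))  # tensor of shape [3, 512, 512], values in [-1, 1]
```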
asp/data/image_folder.py ADDED
@@ -0,0 +1,66 @@
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both current directory and its subdirectories.
5
+ """
6
+
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+ import os.path
12
+
13
+ IMG_EXTENSIONS = [
14
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
15
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
+ '.tif', '.TIF', '.tiff', '.TIFF',
17
+ ]
18
+
19
+
20
+ def is_image_file(filename):
21
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22
+
23
+
24
+ def make_dataset(dir, max_dataset_size=float("inf")):
25
+ images = []
26
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27
+
28
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29
+ for fname in fnames:
30
+ if is_image_file(fname):
31
+ path = os.path.join(root, fname)
32
+ images.append(path)
33
+ return images[:min(max_dataset_size, len(images))]
34
+
35
+
36
+ def default_loader(path):
37
+ return Image.open(path).convert('RGB')
38
+
39
+
40
+ class ImageFolder(data.Dataset):
41
+
42
+ def __init__(self, root, transform=None, return_paths=False,
43
+ loader=default_loader):
44
+ imgs = make_dataset(root)
45
+ if len(imgs) == 0:
46
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
47
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48
+
49
+ self.root = root
50
+ self.imgs = imgs
51
+ self.transform = transform
52
+ self.return_paths = return_paths
53
+ self.loader = loader
54
+
55
+ def __getitem__(self, index):
56
+ path = self.imgs[index]
57
+ img = self.loader(path)
58
+ if self.transform is not None:
59
+ img = self.transform(img)
60
+ if self.return_paths:
61
+ return img, path
62
+ else:
63
+ return img
64
+
65
+ def __len__(self):
66
+ return len(self.imgs)
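
For reference, a short sketch of how these helpers compose (the directory path is a placeholder):

```python
# Collect image paths recursively, then wrap them in a Dataset (illustrative only).
import torchvision.transforms as T

from data.image_folder import make_dataset, ImageFolder

paths = make_dataset('/path/to/data/trainA')                    # list of image file paths
folder = ImageFolder('/path/to/data/trainA',
                     transform=T.ToTensor(), return_paths=True)
img, path = folder[0]                                           # tensor image plus its file path
```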
asp/experiments/__init__.py ADDED
@@ -0,0 +1,54 @@
1
+ import os
2
+ import importlib
3
+
4
+
5
+ def find_launcher_using_name(launcher_name):
6
+ # cur_dir = os.path.dirname(os.path.abspath(__file__))
7
+ # pythonfiles = glob.glob(cur_dir + '/**/*.py')
8
+ launcher_filename = "experiments.{}_launcher".format(launcher_name)
9
+ launcherlib = importlib.import_module(launcher_filename)
10
+
11
+ # In the file, the class called LauncherNameLauncher() will
12
+ # be instantiated. It has to be a subclass of BaseLauncher,
13
+ # and it is case-insensitive.
14
+ launcher = None
15
+ target_launcher_name = launcher_name.replace('_', '') + 'launcher'
16
+ for name, cls in launcherlib.__dict__.items():
17
+ if name.lower() == target_launcher_name.lower():
18
+ launcher = cls
19
+
20
+ if launcher is None:
21
+ raise ValueError("In %s.py, there should be a subclass of BaseLauncher "
22
+ "with class name that matches %s in lowercase." %
23
+ (launcher_filename, target_launcher_name))
24
+
25
+ return launcher
26
+
27
+
28
+ if __name__ == "__main__":
29
+ import sys
30
+ import pickle
31
+
32
+ assert len(sys.argv) >= 3
33
+
34
+ name = sys.argv[1]
35
+ Launcher = find_launcher_using_name(name)
36
+
37
+ cache = "/tmp/tmux_launcher/{}".format(name)
38
+ if os.path.isfile(cache):
39
+ instance = pickle.load(open(cache, 'rb'))  # pickle files must be opened in binary mode
40
+ else:
41
+ instance = Launcher()
42
+
43
+ cmd = sys.argv[2]
44
+ if cmd == "launch":
45
+ instance.launch()
46
+ elif cmd == "stop":
47
+ instance.stop()
48
+ elif cmd == "send":
49
+ expid = int(sys.argv[3])
50
+ cmd = int(sys.argv[4])
51
+ instance.send_command(expid, cmd)
52
+
53
+ os.makedirs("/tmp/tmux_launcher/", exist_ok=True)
54
+ pickle.dump(instance, open(cache, 'wb'))  # binary mode for pickle
asp/experiments/__main__.py ADDED
@@ -0,0 +1,87 @@
1
+ import os
2
+ import importlib
3
+
4
+
5
+ def find_launcher_using_name(launcher_name):
6
+ # cur_dir = os.path.dirname(os.path.abspath(__file__))
7
+ # pythonfiles = glob.glob(cur_dir + '/**/*.py')
8
+ launcher_filename = "experiments.{}_launcher".format(launcher_name)
9
+ launcherlib = importlib.import_module(launcher_filename)
10
+
11
+ # In the file, the class called LauncherNameLauncher() will
12
+ # be instantiated. It has to be a subclass of BaseLauncher,
13
+ # and it is case-insensitive.
14
+ launcher = None
15
+ # target_launcher_name = launcher_name.replace('_', '') + 'launcher'
16
+ for name, cls in launcherlib.__dict__.items():
17
+ if name.lower() == "launcher":
18
+ launcher = cls
19
+
20
+ if launcher is None:
21
+ raise ValueError("In %s.py, there should be a class named Launcher")
22
+
23
+ return launcher
24
+
25
+
26
+ if __name__ == "__main__":
27
+ import argparse
28
+
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument('name')
31
+ parser.add_argument('cmd')
32
+ parser.add_argument('id', nargs='+', type=str)
33
+ parser.add_argument('--mode', default=None)
34
+ parser.add_argument('--which_epoch', default=None)
35
+ parser.add_argument('--continue_train', action='store_true')
36
+ parser.add_argument('--subdir', default='')
37
+ parser.add_argument('--title', default='')
38
+ parser.add_argument('--gpu_id', default=None, type=int)
39
+ parser.add_argument('--phase', default='test')
40
+
41
+ opt = parser.parse_args()
42
+
43
+ name = opt.name
44
+ Launcher = find_launcher_using_name(name)
45
+
46
+ instance = Launcher()
47
+
48
+ cmd = opt.cmd
49
+ ids = 'all' if 'all' in opt.id else [int(i) for i in opt.id]
50
+ if cmd == "launch":
51
+ instance.launch(ids, continue_train=opt.continue_train)
52
+ elif cmd == "stop":
53
+ instance.stop()
54
+ elif cmd == "send":
55
+ assert False
56
+ elif cmd == "close":
57
+ instance.close()
58
+ elif cmd == "dry":
59
+ instance.dry()
60
+ elif cmd == "relaunch":
61
+ instance.close()
62
+ instance.launch(ids, continue_train=opt.continue_train)
63
+ elif cmd == "run" or cmd == "train":
64
+ assert len(ids) == 1, '%s is invalid for run command' % (' '.join(opt.id))
65
+ expid = ids[0]
66
+ instance.run_command(instance.commands(), expid,
67
+ continue_train=opt.continue_train,
68
+ gpu_id=opt.gpu_id)
69
+ elif cmd == 'launch_test':
70
+ instance.launch(ids, test=True)
71
+ elif cmd == "run_test" or cmd == "test":
72
+ test_commands = instance.test_commands()
73
+ if ids == "all":
74
+ ids = list(range(len(test_commands)))
75
+ for expid in ids:
76
+ instance.run_command(test_commands, expid, opt.which_epoch,
77
+ gpu_id=opt.gpu_id)
78
+ if expid < len(ids) - 1:
79
+ os.system("sleep 5s")
80
+ elif cmd == "print_names":
81
+ instance.print_names(ids, test=False)
82
+ elif cmd == "print_test_names":
83
+ instance.print_names(ids, test=True)
84
+ elif cmd == "create_comparison_html":
85
+ instance.create_comparison_html(name, ids, opt.subdir, opt.title, opt.phase)
86
+ else:
87
+ raise ValueError("Command not recognized")
asp/experiments/mist_launcher.py ADDED
@@ -0,0 +1,66 @@
1
+ from .tmux_launcher import Options, TmuxLauncher
2
+
3
+
4
+ class Launcher(TmuxLauncher):
5
+ def common_options(self):
6
+ return [
7
+ # Command 0
8
+ Options(
9
+ dataroot="../data/BCI_dataset",
10
+ name="mist_her2_lambda_linear",
11
+ checkpoints_dir='../checkpoints',
12
+ model='cpt',
13
+ CUT_mode="FastCUT",
14
+
15
+ n_epochs=30, # number of epochs with the initial learning rate
16
+ n_epochs_decay=10, # number of epochs to linearly decay learning rate to zero
17
+
18
+ netD='n_layers', # ['basic' | 'n_layers' | 'pixel' | 'patch']: discriminator architecture. The basic model is a 70x70 PatchGAN; n_layers lets you set the number of discriminator layers
19
+ ndf=32,
20
+ netG='resnet_6blocks', # ['resnet_9blocks' | 'resnet_6blocks' | 'unet_256' | 'unet_128' | 'stylegan2' | 'smallstylegan2' | 'resnet_cat']: generator architecture
21
+ n_layers_D=5, # only used if netD == 'n_layers'
22
+ normG='instance', # ['instance' | 'batch' | 'none']: instance or batch normalization for G
23
+ normD='instance', # ['instance' | 'batch' | 'none']: instance or batch normalization for D
24
+ weight_norm='spectral',
25
+
26
+ lambda_GAN=1.0, # weight for GAN loss:GAN(G(X))
27
+ lambda_NCE=10.0, # weight for NCE loss: NCE(G(X), X)
28
+ nce_layers='0,4,8,12,16',
29
+ nce_T=0.07,
30
+ num_patches=256,
31
+
32
+ # FDL:
33
+ lambda_gp=10.0,
34
+ gp_weights='[0.015625,0.03125,0.0625,0.125,0.25,1.0]',
35
+ lambda_asp=10.0, # weight for the ASP loss
36
+ asp_loss_mode='lambda_linear',
37
+
38
+ dataset_mode='aligned', # chooses how datasets are loaded [unaligned | aligned | single | colorization]
39
+ direction='AtoB',
40
+ # serial_batches='', # if true, takes images in order to make batches, otherwise takes them randomly
41
+ num_threads=15, # number of threads for loading data
42
+ batch_size=1, # input batch size
43
+ load_size=1024, # scale images to this size
44
+ crop_size=512, # then crop to this size
45
+ preprocess='crop', # scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]
46
+ # no_flip='',
47
+ flip_equivariance=False,
48
+ display_winsize=512, # display window size for both visdom and HTML
49
+ # display_id=0,
50
+ update_html_freq=100,
51
+ save_epoch_freq=5,
52
+ # print_freq=10,
53
+ ),
54
+ ]
55
+
56
+ def commands(self):
57
+ return ["python train.py " + str(opt) for opt in self.common_options()]
58
+
59
+ def test_commands(self):
60
+ opts = self.common_options()
61
+ phase = 'val'
62
+ for opt in opts:
63
+ opt.set(crop_size=1024, num_test=1000)
64
+ opt.remove('n_epochs', 'n_epochs_decay', 'update_html_freq',
65
+ 'save_epoch_freq', 'continue_train', 'epoch_count')
66
+ return ["python test.py " + str(opt.set(phase=phase)) for opt in opts]
asp/experiments/pretrained_launcher.py ADDED
@@ -0,0 +1,61 @@
1
+ from .tmux_launcher import Options, TmuxLauncher
2
+
3
+
4
+ class Launcher(TmuxLauncher):
5
+ def common_options(self):
6
+ return [
7
+ # Command 0
8
+ Options(
9
+ # NOTE: download the resized (and compressed) val set from
10
+ # http://efrosgans.eecs.berkeley.edu/CUT/datasets/cityscapes_val_for_CUT.tar
11
+ dataroot="datasets/cityscapes/cityscapes_val/",
12
+ direction="BtoA",
13
+ phase="val",
14
+ name="cityscapes_cut_pretrained",
15
+ CUT_mode="CUT",
16
+ ),
17
+
18
+ # Command 1
19
+ Options(
20
+ dataroot="./datasets/cityscapes_unaligned/cityscapes/",
21
+ direction="BtoA",
22
+ name="cityscapes_fastcut_pretrained",
23
+ CUT_mode="FastCUT",
24
+ ),
25
+
26
+ # Command 2
27
+ Options(
28
+ dataroot="./datasets/horse2zebra/",
29
+ name="horse2zebra_cut_pretrained",
30
+ CUT_mode="CUT"
31
+ ),
32
+
33
+ # Command 3
34
+ Options(
35
+ dataroot="./datasets/horse2zebra/",
36
+ name="horse2zebra_fastcut_pretrained",
37
+ CUT_mode="FastCUT",
38
+ ),
39
+
40
+ # Command 4
41
+ Options(
42
+ dataroot="/mnt/cloudNAS3/fangda/CycleGANData/dog2cat",
43
+ name="cat2dog_cut_pretrained",
44
+ CUT_mode="CUT"
45
+ ),
46
+
47
+ # Command 5
48
+ Options(
49
+ dataroot="./datasets/afhq/cat2dog/",
50
+ name="cat2dog_fastcut_pretrained",
51
+ CUT_mode="FastCUT",
52
+ ),
53
+
54
+
55
+ ]
56
+
57
+ def commands(self):
58
+ return ["python train.py " + str(opt) for opt in self.common_options()]
59
+
60
+ def test_commands(self):
61
+ return ["python test.py " + str(opt.set(num_test=500)) for opt in self.common_options()]
asp/experiments/tmux_launcher.py ADDED
@@ -0,0 +1,215 @@
1
+ """
2
+ experiment launcher using tmux panes
3
+ """
4
+ import os
5
+ import math
6
+ import GPUtil
7
+ import re
8
+
9
+ available_gpu_devices = None
10
+
11
+
12
+ class Options():
13
+ def __init__(self, *args, **kwargs):
14
+ self.args = []
15
+ self.kvs = {"gpu_ids": "0"}
16
+ self.set(*args, **kwargs)
17
+
18
+ def set(self, *args, **kwargs):
19
+ for a in args:
20
+ self.args.append(a)
21
+ for k, v in kwargs.items():
22
+ self.kvs[k] = v
23
+
24
+ return self
25
+
26
+ def remove(self, *args):
27
+ for a in args:
28
+ if a in self.args:
29
+ self.args.remove(a)
30
+ if a in self.kvs:
31
+ del self.kvs[a]
32
+
33
+ return self
34
+
35
+ def update(self, opt):
36
+ self.args += opt.args
37
+ self.kvs.update(opt.kvs)
38
+ return self
39
+
40
+ def __str__(self):
41
+ final = " ".join(self.args)
42
+ for k, v in self.kvs.items():
43
+ final += " --{} {}".format(k, v)
44
+
45
+ return final
46
+
47
+ def clone(self):
48
+ opt = Options()
49
+ opt.args = self.args.copy()
50
+ opt.kvs = self.kvs.copy()
51
+ return opt
52
+
53
+
54
+ def grab_pattern(pattern, text):
55
+ found = re.search(pattern, text)
56
+ if found is not None:
57
+ return found[1]
58
+ else:
59
+ return None
60
+
61
+
62
+ # http://code.activestate.com/recipes/252177-find-the-common-beginning-in-a-list-of-strings/
63
+ def findcommonstart(strlist):
64
+ prefix_len = ([min([x[0] == elem for elem in x])
65
+ for x in zip(*strlist)] + [0]).index(0)
66
+ prefix_len = max(1, prefix_len - 4)
67
+ return strlist[0][:prefix_len]
68
+
69
+
70
+ class TmuxLauncher():
71
+ def __init__(self):
72
+ super().__init__()
73
+ self.tmux_prepared = False
74
+
75
+ def prepare_tmux_panes(self, num_experiments, dry=False):
76
+ self.pane_per_window = 1
77
+ self.n_windows = int(math.ceil(num_experiments / self.pane_per_window))
78
+ print('preparing {} tmux panes'.format(num_experiments))
79
+ for w in range(self.n_windows):
80
+ if dry:
81
+ continue
82
+ window_name = "experiments_{}".format(w)
83
+ os.system("tmux new-window -n {}".format(window_name))
84
+ self.tmux_prepared = True
85
+
86
+ def refine_command(self, command, which_epoch, continue_train, gpu_id=None):
87
+ command = str(command)
88
+ if "--gpu_ids" in command:
89
+ gpu_ids = re.search(r'--gpu_ids ([\d,?]+)', command)[1]
90
+ else:
91
+ gpu_ids = "0"
92
+
93
+ gpu_ids = gpu_ids.split(",")
94
+ num_gpus = len(gpu_ids)
95
+ global available_gpu_devices
96
+ if available_gpu_devices is None and gpu_id is None:
97
+ available_gpu_devices = [str(g) for g in GPUtil.getAvailable(limit=8, maxMemory=0.5)]
98
+ if gpu_id is not None:
99
+ available_gpu_devices = [i for i in str(gpu_id)]
100
+ if len(available_gpu_devices) < num_gpus:
101
+ raise ValueError("{} GPU(s) required for the command {} is not available".format(num_gpus, command))
102
+ active_devices = ",".join(available_gpu_devices[:num_gpus])
103
+ if which_epoch is not None:
104
+ which_epoch = " --epoch %s " % which_epoch
105
+ else:
106
+ which_epoch = ""
107
+ command = "CUDA_VISIBLE_DEVICES={} {} {}".format(active_devices, command, which_epoch)
108
+ if continue_train:
109
+ command += " --continue_train "
110
+
111
+ # available_gpu_devices = [str(g) for g in GPUtil.getAvailable(limit=8, maxMemory=0.8)]
112
+ available_gpu_devices = available_gpu_devices[num_gpus:]
113
+
114
+ return command
115
+
116
+ def send_command(self, exp_id, command, dry=False, continue_train=False):
117
+ command = self.refine_command(command, None, continue_train=continue_train)
118
+ pane_name = "experiments_{windowid}.{paneid}".format(windowid=exp_id // self.pane_per_window,
119
+ paneid=exp_id % self.pane_per_window)
120
+ if dry is False:
121
+ os.system("tmux send-keys -t {} \"{}\" Enter".format(pane_name, command))
122
+
123
+ print("{}: {}".format(pane_name, command))
124
+ return pane_name
125
+
126
+ def run_command(self, command, ids, which_epoch=None, continue_train=False, gpu_id=None):
127
+ if type(command) is not list:
128
+ command = [command]
129
+ if ids is None:
130
+ ids = list(range(len(command)))
131
+ if type(ids) is not list:
132
+ ids = [ids]
133
+
134
+ for id in ids:
135
+ this_command = command[id]
136
+ refined_command = self.refine_command(this_command, which_epoch, continue_train=continue_train, gpu_id=gpu_id)
137
+ print(refined_command)
138
+ os.system(refined_command)
139
+
140
+ def commands(self):
141
+ return []
142
+
143
+ def launch(self, ids, test=False, dry=False, continue_train=False):
144
+ commands = self.test_commands() if test else self.commands()
145
+ if type(ids) is list:
146
+ commands = [commands[i] for i in ids]
147
+ if not self.tmux_prepared:
148
+ self.prepare_tmux_panes(len(commands), dry)
149
+ assert self.tmux_prepared
150
+
151
+ for i, command in enumerate(commands):
152
+ self.send_command(i, command, dry, continue_train=continue_train)
153
+
154
+ def dry(self):
155
+ self.launch(dry=True)
156
+
157
+ def stop(self):
158
+ num_experiments = len(self.commands())
159
+ self.pane_per_window = 4
160
+ self.n_windows = int(math.ceil(num_experiments / self.pane_per_window))
161
+ for w in range(self.n_windows):
162
+ window_name = "experiments_{}".format(w)
163
+ for i in range(self.pane_per_window):
164
+ os.system("tmux send-keys -t {window}.{pane} C-c".format(window=window_name, pane=i))
165
+
166
+ def close(self):
167
+ num_experiments = len(self.commands())
168
+ self.pane_per_window = 1
169
+ self.n_windows = int(math.ceil(num_experiments / self.pane_per_window))
170
+ for w in range(self.n_windows):
171
+ window_name = "experiments_{}".format(w)
172
+ os.system("tmux kill-window -t {}".format(window_name))
173
+
174
+ def print_names(self, ids, test=False):
175
+ if test:
176
+ cmds = self.test_commands()
177
+ else:
178
+ cmds = self.commands()
179
+ if type(ids) is list:
180
+ cmds = [cmds[i] for i in ids]
181
+
182
+ for cmdid, cmd in enumerate(cmds):
183
+ name = grab_pattern(r'--name ([^ ]+)', cmd)
184
+ print(name)
185
+
186
+ def create_comparison_html(self, expr_name, ids, subdir, title, phase):
187
+ cmds = self.test_commands()
188
+ if type(ids) is list:
189
+ cmds = [cmds[i] for i in ids]
190
+
191
+ no_easy_label = True
192
+ dirs = []
193
+ labels = []
194
+ for cmdid, cmd in enumerate(cmds):
195
+ name = grab_pattern(r'--name ([^ ]+)', cmd)
196
+ which_epoch = grab_pattern(r'--epoch ([^ ]+)', cmd)
197
+ if which_epoch is None:
198
+ which_epoch = "latest"
199
+ label = grab_pattern(r'--easy_label "([^"]+)"', cmd)
200
+ if label is None:
201
+ label = name
202
+ else:
203
+ no_easy_label = False
204
+ labels.append(label)
205
+ dir = "results/%s/%s_%s/%s/" % (name, phase, which_epoch, subdir)
206
+ dirs.append(dir)
207
+
208
+ commonprefix = findcommonstart(labels) if no_easy_label else ""
209
+ labels = ['"' + label[len(commonprefix):] + '"' for label in labels]
210
+ dirstr = ' '.join(dirs)
211
+ labelstr = ' '.join(labels)
212
+
213
+ command = "python ~/tools/html.py --web_dir_prefix results/comparison_ --name %s --dirs %s --labels %s --image_width 256" % (expr_name + '_' + title, dirstr, labelstr)
214
+ print(command)
215
+ os.system(command)
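
The `Options` helper above simply accumulates positional flags and key/value pairs and serializes them into a command-line string; a quick sketch of its behaviour (values are illustrative):

```python
# Options keeps a default of gpu_ids="0" and renders everything as "--key value".
opt = Options(dataroot="../data/BCI_dataset", name="demo", batch_size=1)
opt.set(crop_size=512).remove("batch_size")
print("python train.py " + str(opt))
# prints the assembled command with --gpu_ids, --dataroot, --name and --crop_size flags
```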
asp/models/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from asp.models.base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "models." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ model = None
35
+ target_model_name = model_name.replace('_', '') + 'model'
36
+ for name, cls in modellib.__dict__.items():
37
+ if name.lower() == target_model_name.lower() \
38
+ and issubclass(cls, BaseModel):
39
+ model = cls
40
+
41
+ if model is None:
42
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
+ exit(1)  # exit with a non-zero status since the requested model was not found
44
+
45
+ return model
46
+
47
+
48
+ def get_option_setter(model_name):
49
+ """Return the static method <modify_commandline_options> of the model class."""
50
+ model_class = find_model_using_name(model_name)
51
+ return model_class.modify_commandline_options
52
+
53
+
54
+ def create_model(opt):
55
+ """Create a model given the option.
56
+
57
+ This function instantiates the model class specified by opt.model.
58
+ This is the main interface between this package and 'train.py'/'test.py'
59
+
60
+ Example:
61
+ >>> from models import create_model
62
+ >>> model = create_model(opt)
63
+ """
64
+ model = find_model_using_name(opt.model)
65
+ instance = model(opt)
66
+ print("model [%s] was created" % type(instance).__name__)
67
+ return instance
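
The same name-to-class convention as in the data package applies here. For the `cpt` model added later in this commit, the lookup proceeds roughly as sketched below (assuming `models.cpt_model` is importable):

```python
# '--model cpt' -> module models/cpt_model.py -> class CPTModel,
# because "cpt".replace('_', '') + "model" == "cptmodel" == "CPTModel".lower()
model_class = find_model_using_name("cpt")   # returns CPTModel, a BaseModel subclass
```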
asp/models/asp_loss.py ADDED
@@ -0,0 +1,97 @@
1
+ import time
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class AdaptiveSupervisedPatchNCELoss(nn.Module):
9
+
10
+ def __init__(self, opt):
11
+ super().__init__()
12
+ self.opt = opt
13
+ self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
14
+ self.mask_dtype = torch.bool
15
+ self.total_epochs = opt.n_epochs + opt.n_epochs_decay
16
+
17
+ def forward(self, feat_q, feat_k, current_epoch=-1):
18
+ num_patches = feat_q.shape[0]
19
+ dim = feat_q.shape[1]
20
+ feat_k = feat_k.detach()
21
+
22
+ # pos logit
23
+ l_pos = torch.bmm(
24
+ feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
25
+ l_pos = l_pos.view(num_patches, 1)
26
+
27
+ # neg logit
28
+
29
+ # Should the negatives from the other samples of a minibatch be utilized?
30
+ # In CUT and FastCUT, we found that it's best to only include negatives
31
+ # from the same image. Therefore, we set
32
+ # --nce_includes_all_negatives_from_minibatch as False
33
+ # However, for single-image translation, the minibatch consists of
34
+ # crops from the "same" high-resolution image.
35
+ # Therefore, we will include the negatives from the entire minibatch.
36
+ if self.opt.nce_includes_all_negatives_from_minibatch:
37
+ # reshape features as if they are all negatives of minibatch of size 1.
38
+ batch_dim_for_bmm = 1
39
+ else:
40
+ batch_dim_for_bmm = self.opt.batch_size
41
+
42
+ # reshape features to batch size
43
+ feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
44
+ feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
45
+ npatches = feat_q.size(1)
46
+ l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
47
+
48
+ # diagonal entries are similarity between same features, and hence meaningless.
49
+ # just fill the diagonal with very small number, which is exp(-10) and almost zero
50
+ diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
51
+ l_neg_curbatch.masked_fill_(diagonal, -10.0)
52
+ l_neg = l_neg_curbatch.view(-1, npatches)
53
+
54
+ out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
55
+
56
+ loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
57
+ device=feat_q.device))
58
+
59
+ if self.opt.asp_loss_mode == 'none':
60
+ return loss
61
+
62
+ scheduler, lookup = self.opt.asp_loss_mode.split('_')[:2]
63
+ # Compute scheduling
64
+ t = (current_epoch - 1) / self.total_epochs
65
+ if scheduler == 'sigmoid':
66
+ p = 1 / (1 + np.exp((t - 0.5) * 10))
67
+ elif scheduler == 'linear':
68
+ p = 1 - t
69
+ elif scheduler == 'lambda':
70
+ k = 1 - self.opt.n_epochs_decay / self.total_epochs
71
+ m = 1 / (1 - k)
72
+ p = m - m * t if t >= k else 1.0
73
+ elif scheduler == 'zero':
74
+ p = 1.0
75
+ else:
76
+ raise ValueError(f"Unrecognized scheduler: {scheduler}")
77
+ # Weight lookups
78
+ w0 = 1.0
79
+ x = l_pos.squeeze().detach()
80
+ if lookup == 'top':
81
+ x = torch.where(x > 0.0, x, torch.zeros_like(x))
82
+ w1 = torch.sqrt(1 - (x - 1) ** 2)
83
+ elif lookup == 'linear':
84
+ w1 = torch.relu(x)
85
+ elif lookup == 'bell':
86
+ sigma, mu, sc = 1, 0, 4
87
+ w1 = 1 / (sigma * np.sqrt(2 * torch.pi)) * torch.exp(-((x - 0.5) * sc - mu) ** 2 / (2 * sigma ** 2))
88
+ elif lookup == 'uniform':
89
+ w1 = torch.ones_like(x)
90
+ else:
91
+ raise ValueError(f"Unrecognized lookup: {lookup}")
92
+ # Apply weights with schedule
93
+ w = p * w0 + (1 - p) * w1
94
+ # Normalize
95
+ w = w / w.sum() * len(w)
96
+ loss = loss * w
97
+ return loss
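
To make the epoch scheduling concrete, the sketch below evaluates the `lambda` scheduler (as used by `asp_loss_mode='lambda_linear'`) with the epoch budget from the MIST launcher in this commit (`n_epochs=30`, `n_epochs_decay=10`); the printed values only illustrate the shape of the schedule:

```python
# p = 1.0 keeps the uniform weight w0; as p decays to 0 during the last
# n_epochs_decay epochs, the lookup weight w1 takes over.
n_epochs, n_epochs_decay = 30, 10
total = n_epochs + n_epochs_decay
k = 1 - n_epochs_decay / total        # 0.75
m = 1 / (1 - k)                       # 4.0
for epoch in (1, 30, 31, 36, 40):
    t = (epoch - 1) / total
    p = m - m * t if t >= k else 1.0
    print(epoch, round(p, 3))         # 1.0, 1.0, 1.0, 0.5, 0.1
```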
asp/models/base_model.py ADDED
@@ -0,0 +1,258 @@
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from abc import ABC, abstractmethod
5
+ from . import networks
6
+
7
+
8
+ class BaseModel(ABC):
9
+ """This class is an abstract base class (ABC) for models.
10
+ To create a subclass, you need to implement the following five functions:
11
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12
+ -- <set_input>: unpack data from dataset and apply preprocessing.
13
+ -- <forward>: produce intermediate results.
14
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16
+ """
17
+
18
+ def __init__(self, opt):
19
+ """Initialize the BaseModel class.
20
+
21
+ Parameters:
22
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23
+
24
+ When creating your custom class, you need to implement your own initialization.
25
+ In this function, you should first call <BaseModel.__init__(self, opt)>
26
+ Then, you need to define four lists:
27
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
28
+ -- self.model_names (str list): specify the images that you want to display and save.
29
+ -- self.visual_names (str list): define networks used in our training.
30
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31
+ """
32
+ self.opt = opt
33
+ self.gpu_ids = opt.gpu_ids
34
+ self.isTrain = opt.isTrain
35
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
36
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
37
+ if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38
+ torch.backends.cudnn.benchmark = True
39
+ self.loss_names = []
40
+ self.model_names = []
41
+ self.visual_names = []
42
+ self.optimizers = []
43
+ self.image_paths = []
44
+ self.metric = 0 # used for learning rate policy 'plateau'
45
+
46
+ @staticmethod
47
+ def dict_grad_hook_factory(add_func=lambda x: x):
48
+ saved_dict = dict()
49
+
50
+ def hook_gen(name):
51
+ def grad_hook(grad):
52
+ saved_vals = add_func(grad)
53
+ saved_dict[name] = saved_vals
54
+ return grad_hook
55
+ return hook_gen, saved_dict
56
+
57
+ @staticmethod
58
+ def modify_commandline_options(parser, is_train):
59
+ """Add new model-specific options, and rewrite default values for existing options.
60
+
61
+ Parameters:
62
+ parser -- original option parser
63
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
64
+
65
+ Returns:
66
+ the modified parser.
67
+ """
68
+ return parser
69
+
70
+ @abstractmethod
71
+ def set_input(self, input):
72
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
73
+
74
+ Parameters:
75
+ input (dict): includes the data itself and its metadata information.
76
+ """
77
+ pass
78
+
79
+ @abstractmethod
80
+ def forward(self):
81
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
82
+ pass
83
+
84
+ @abstractmethod
85
+ def optimize_parameters(self):
86
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
87
+ pass
88
+
89
+ def setup(self, opt):
90
+ """Load and print networks; create schedulers
91
+
92
+ Parameters:
93
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
94
+ """
95
+ if self.isTrain:
96
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
97
+ if not self.isTrain or opt.continue_train:
98
+ load_suffix = opt.epoch
99
+ self.load_networks(load_suffix)
100
+
101
+ self.print_networks(opt.verbose)
102
+
103
+ def parallelize(self):
104
+ for name in self.model_names:
105
+ if isinstance(name, str):
106
+ net = getattr(self, 'net' + name)
107
+ setattr(self, 'net' + name, torch.nn.DataParallel(net, self.opt.gpu_ids))
108
+
109
+ def data_dependent_initialize(self, data):
110
+ pass
111
+
112
+ def eval(self):
113
+ """Make models eval mode during test time"""
114
+ for name in self.model_names:
115
+ if isinstance(name, str):
116
+ net = getattr(self, 'net' + name)
117
+ net.eval()
118
+
119
+ def test(self):
120
+ """Forward function used in test time.
121
+
122
+ This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
123
+ It also calls <compute_visuals> to produce additional visualization results
124
+ """
125
+ with torch.no_grad():
126
+ self.forward()
127
+ self.compute_visuals()
128
+
129
+ def compute_visuals(self):
130
+ """Calculate additional output images for visdom and HTML visualization"""
131
+ pass
132
+
133
+ def get_image_paths(self):
134
+ """ Return image paths that are used to load current data"""
135
+ return self.image_paths
136
+
137
+ def update_learning_rate(self):
138
+ """Update learning rates for all the networks; called at the end of every epoch"""
139
+ for scheduler in self.schedulers:
140
+ if self.opt.lr_policy == 'plateau':
141
+ scheduler.step(self.metric)
142
+ else:
143
+ scheduler.step()
144
+
145
+ lr = self.optimizers[0].param_groups[0]['lr']
146
+ print('learning rate = %.7f' % lr)
147
+
148
+ def get_current_visuals(self):
149
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
150
+ visual_ret = OrderedDict()
151
+ for name in self.visual_names:
152
+ if isinstance(name, str):
153
+ visual_ret[name] = getattr(self, name)
154
+ return visual_ret
155
+
156
+ def get_current_losses(self):
157
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
158
+ errors_ret = OrderedDict()
159
+ for name in self.loss_names:
160
+ if isinstance(name, str):
161
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
162
+ return errors_ret
163
+
164
+ def save_networks(self, epoch):
165
+ """Save all the networks to the disk.
166
+
167
+ Parameters:
168
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
169
+ """
170
+ for name in self.model_names:
171
+ if isinstance(name, str):
172
+ save_filename = '%s_net_%s.pth' % (epoch, name)
173
+ save_path = os.path.join(self.save_dir, save_filename)
174
+ net = getattr(self, 'net' + name)
175
+
176
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
177
+ torch.save(net.module.cpu().state_dict(), save_path)
178
+ net.cuda(self.gpu_ids[0])
179
+ else:
180
+ torch.save(net.cpu().state_dict(), save_path)
181
+
182
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
183
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
184
+ key = keys[i]
185
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
186
+ if module.__class__.__name__.startswith('InstanceNorm') and \
187
+ (key == 'running_mean' or key == 'running_var'):
188
+ if getattr(module, key) is None:
189
+ state_dict.pop('.'.join(keys))
190
+ if module.__class__.__name__.startswith('InstanceNorm') and \
191
+ (key == 'num_batches_tracked'):
192
+ state_dict.pop('.'.join(keys))
193
+ else:
194
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
195
+
196
+ def load_networks(self, epoch):
197
+ """Load all the networks from the disk.
198
+
199
+ Parameters:
200
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
201
+ """
202
+ for name in self.model_names:
203
+ if isinstance(name, str):
204
+ load_filename = '%s_net_%s.pth' % (epoch, name)
205
+ if self.opt.isTrain and self.opt.pretrained_name is not None:
206
+ load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
207
+ else:
208
+ load_dir = self.save_dir
209
+
210
+ load_path = os.path.join(load_dir, load_filename)
211
+ net = getattr(self, 'net' + name)
212
+ if isinstance(net, torch.nn.DataParallel):
213
+ net = net.module
214
+ print('loading the model from %s' % load_path)
215
+ # if you are using PyTorch newer than 0.4 (e.g., built from
216
+ # GitHub source), you can remove str() on self.device
217
+ state_dict = torch.load(load_path, map_location=str(self.device))
218
+ if hasattr(state_dict, '_metadata'):
219
+ del state_dict._metadata
220
+
221
+ # patch InstanceNorm checkpoints prior to 0.4
222
+ # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
223
+ # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
224
+ net.load_state_dict(state_dict)
225
+
226
+ def print_networks(self, verbose):
227
+ """Print the total number of parameters in the network and (if verbose) network architecture
228
+
229
+ Parameters:
230
+ verbose (bool) -- if verbose: print the network architecture
231
+ """
232
+ print('---------- Networks initialized -------------')
233
+ for name in self.model_names:
234
+ if isinstance(name, str):
235
+ net = getattr(self, 'net' + name)
236
+ num_params = 0
237
+ for param in net.parameters():
238
+ num_params += param.numel()
239
+ if verbose:
240
+ print(net)
241
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
242
+ print('-----------------------------------------------')
243
+
244
+ def set_requires_grad(self, nets, requires_grad=False):
245
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
246
+ Parameters:
247
+ nets (network list) -- a list of networks
248
+ requires_grad (bool) -- whether the networks require gradients or not
249
+ """
250
+ if not isinstance(nets, list):
251
+ nets = [nets]
252
+ for net in nets:
253
+ if net is not None:
254
+ for param in net.parameters():
255
+ param.requires_grad = requires_grad
256
+
257
+ def generate_visuals_for_evaluation(self, data, mode):
258
+ return {}
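
As a typical usage pattern, `set_requires_grad` is what allows the discriminator to be frozen while the generator step is taken. The snippet below is a generic alternating GAN update, not the exact training code of this repository, and assumes a model exposing `netD`, `optimizer_D`/`optimizer_G` and computed `loss_D`/`loss_G`:

```python
# Generic alternating update using the BaseModel helpers (illustrative only).
model.set_requires_grad(model.netD, True)    # D step: allow gradients through D
model.optimizer_D.zero_grad()
model.loss_D.backward()
model.optimizer_D.step()

model.set_requires_grad(model.netD, False)   # G step: freeze D to save computation
model.optimizer_G.zero_grad()
model.loss_G.backward()
model.optimizer_G.step()
```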
asp/models/cpt_model.py ADDED
@@ -0,0 +1,261 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from asp.models.asp_loss import AdaptiveSupervisedPatchNCELoss
5
+ from .base_model import BaseModel
6
+ from . import networks
7
+ from .patchnce import PatchNCELoss
8
+ from .gauss_pyramid import Gauss_Pyramid_Conv
9
+ import asp.util.util as util
10
+
11
+
12
+ class CPTModel(BaseModel):
13
+ """ Contrastive Paired Translation (CPT).
14
+ """
15
+ @staticmethod
16
+ def modify_commandline_options(parser, is_train=True):
17
+ """ Configures options specific for CUT model
18
+ """
19
+ parser.add_argument('--CUT_mode', type=str, default="CUT", choices='(CUT, cut, FastCUT, fastcut)')
20
+
21
+ parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
22
+ parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
23
+ parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
24
+ parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
25
+ parser.add_argument('--nce_includes_all_negatives_from_minibatch',
26
+ type=util.str2bool, nargs='?', const=True, default=False,
27
+ help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
28
+ parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
29
+ parser.add_argument('--netF_nc', type=int, default=256)
30
+ parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
31
+ parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
32
+ parser.add_argument('--flip_equivariance',
33
+ type=util.str2bool, nargs='?', const=True, default=False,
34
+ help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
35
+ parser.set_defaults(pool_size=0) # no image pooling
36
+
37
+ # FDL:
38
+ parser.add_argument('--lambda_gp', type=float, default=1.0, help='weight for Gaussian Pyramid reconstruction loss')
39
+ parser.add_argument('--gp_weights', type=str, default='uniform', help='weights for reconstruction pyramids.')
40
+ parser.add_argument('--lambda_asp', type=float, default=0.0, help='weight for ASP loss')
41
+ parser.add_argument('--asp_loss_mode', type=str, default='none', help='"scheduler_lookup" options for the ASP loss; options for both parts are listed in Fig. 3 of the paper.')
42
+ parser.add_argument('--n_downsampling', type=int, default=2, help='# of downsample in G')
43
+
44
+ opt, _ = parser.parse_known_args()
45
+
46
+ # Set default parameters for CUT and FastCUT
47
+ if opt.CUT_mode.lower() == "cut":
48
+ parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
49
+ elif opt.CUT_mode.lower() == "fastcut":
50
+ parser.set_defaults(
51
+ nce_idt=False, lambda_NCE=10.0, flip_equivariance=False,
52
+ n_epochs=20, n_epochs_decay=10
53
+ )
54
+ else:
55
+ raise ValueError(opt.CUT_mode)
56
+
57
+ return parser
58
+
59
+ def __init__(self, opt):
60
+ BaseModel.__init__(self, opt)
61
+
62
+ # specify the training losses you want to print out.
63
+ # The training/test scripts will call <BaseModel.get_current_losses>
64
+ self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE']
65
+ self.visual_names = ['real_A', 'fake_B', 'real_B']
66
+ self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
67
+
68
+ if opt.nce_idt and self.isTrain:
69
+ self.loss_names += ['NCE_Y']
70
+ self.visual_names += ['idt_B']
71
+
72
+ if self.isTrain:
73
+ self.model_names = ['G', 'F', 'D']
74
+ else: # during test time, only load G
75
+ self.model_names = ['G']
76
+
77
+ # define networks (both generator and discriminator)
78
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
79
+ self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
80
+
81
+ if self.isTrain:
82
+ self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
83
+
84
+ # define loss functions
85
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
86
+ self.criterionNCE = PatchNCELoss(opt).to(self.device)
87
+ self.criterionIdt = torch.nn.L1Loss().to(self.device)
88
+
89
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
90
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
91
+ self.optimizers.append(self.optimizer_G)
92
+ self.optimizers.append(self.optimizer_D)
93
+
94
+ if self.opt.lambda_gp > 0:
95
+ self.P = Gauss_Pyramid_Conv(num_high=5)
96
+ self.criterionGP = torch.nn.L1Loss().to(self.device)
97
+ if self.opt.gp_weights == 'uniform':
98
+ self.gp_weights = [1.0] * 6
99
+ else:
100
+ self.gp_weights = eval(self.opt.gp_weights)
101
+ self.loss_names += ['GP']
102
+
103
+ if self.opt.lambda_asp > 0:
104
+ self.criterionASP = AdaptiveSupervisedPatchNCELoss(self.opt).to(self.device)
105
+ self.loss_names += ['ASP']
106
+
107
+
108
+ def data_dependent_initialize(self, data):
109
+ """
110
+ The feature network netF is defined in terms of the shape of the intermediate, extracted
111
+ features of the encoder portion of netG. Because of this, the weights of netF are
112
+ initialized at the first feedforward pass with some input images.
113
+ Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
114
+ """
115
+ bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
116
+ self.set_input(data)
117
+ self.real_A = self.real_A[:bs_per_gpu]
118
+ self.real_B = self.real_B[:bs_per_gpu]
119
+ self.forward() # compute fake images: G(A)
120
+ if self.opt.isTrain:
121
+ self.compute_D_loss().backward() # calculate gradients for D
122
+ self.compute_G_loss().backward() # calculate gradients for G
123
+ if self.opt.lambda_NCE > 0.0 or self.opt.lambda_asp > 0.0:
124
+ self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
125
+ self.optimizers.append(self.optimizer_F)
126
+
127
+ def optimize_parameters(self):
128
+ # forward
129
+ self.forward()
130
+
131
+ # update D
132
+ self.set_requires_grad(self.netD, True)
133
+ self.optimizer_D.zero_grad()
134
+ self.loss_D = self.compute_D_loss()
135
+ self.loss_D.backward()
136
+ self.optimizer_D.step()
137
+ # update G
138
+ self.set_requires_grad(self.netD, False)
139
+ self.optimizer_G.zero_grad()
140
+ if self.opt.netF == 'mlp_sample':
141
+ self.optimizer_F.zero_grad()
142
+ self.loss_G = self.compute_G_loss()
143
+ self.loss_G.backward()
144
+ self.optimizer_G.step()
145
+ if self.opt.netF == 'mlp_sample':
146
+ self.optimizer_F.step()
147
+
148
+ def set_input(self, input):
149
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
150
+ Parameters:
151
+ input (dict): include the data itself and its metadata information.
152
+ The option 'direction' can be used to swap domain A and domain B.
153
+ """
154
+ AtoB = self.opt.direction == 'AtoB'
155
+ self.real_A = input['A' if AtoB else 'B'].to(self.device)
156
+ self.real_B = input['B' if AtoB else 'A'].to(self.device)
157
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
158
+
159
+ if 'current_epoch' in input:
160
+ self.current_epoch = input['current_epoch']
161
+ if 'current_iter' in input:
162
+ self.current_iter = input['current_iter']
163
+
164
+ def forward(self):
165
+ # self.netG.print()
166
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
167
+ self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
168
+ if self.opt.flip_equivariance:
169
+ self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
170
+ if self.flipped_for_equivariance:
171
+ self.real = torch.flip(self.real, [3])
172
+
173
+ self.fake = self.netG(self.real, layers=[])
174
+ self.fake_B = self.fake[:self.real_A.size(0)]
175
+ if self.opt.nce_idt:
176
+ self.idt_B = self.fake[self.real_A.size(0):]
177
+
178
+ def compute_D_loss(self):
179
+ """Calculate GAN loss for the discriminator"""
180
+ fake = self.fake_B.detach()
181
+ # Fake; stop backprop to the generator by detaching fake_B
182
+ pred_fake = self.netD(fake)
183
+ self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
184
+ # Real
185
+ self.pred_real = self.netD(self.real_B)
186
+ loss_D_real = self.criterionGAN(self.pred_real, True)
187
+ self.loss_D_real = loss_D_real.mean()
188
+
189
+ # combine loss and calculate gradients
190
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
191
+ return self.loss_D
192
+
193
+ def compute_G_loss(self):
194
+ """Calculate GAN and NCE loss for the generator"""
195
+ fake = self.fake_B
196
+
197
+ feat_real_A = self.netG(self.real_A, self.nce_layers, encode_only=True)
198
+ feat_fake_B = self.netG(self.fake_B, self.nce_layers, encode_only=True)
199
+ feat_real_B = self.netG(self.real_B, self.nce_layers, encode_only=True)
200
+ if self.opt.nce_idt:
201
+ feat_idt_B = self.netG(self.idt_B, self.nce_layers, encode_only=True)
202
+
203
+ # First, G(A) should fake the discriminator
204
+ if self.opt.lambda_GAN > 0.0:
205
+ pred_fake = self.netD(fake)
206
+ self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
207
+ else:
208
+ self.loss_G_GAN = 0.0
209
+
210
+ if self.opt.lambda_NCE > 0.0:
211
+ self.loss_NCE = self.calculate_NCE_loss(feat_real_A, feat_fake_B, self.netF, self.nce_layers)
212
+ else:
213
+ self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0
214
+ loss_NCE_all = self.loss_NCE
215
+
216
+ if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
217
+ self.loss_NCE_Y = self.calculate_NCE_loss(feat_real_B, feat_idt_B, self.netF, self.nce_layers)
218
+ else:
219
+ self.loss_NCE_Y = 0.0
220
+ loss_NCE_all += self.loss_NCE_Y
221
+
222
+ # FDL: NCE between the noisy pairs (fake_B and real_B)
223
+ if self.opt.lambda_asp > 0:
224
+ self.loss_ASP = self.calculate_NCE_loss(feat_real_B, feat_fake_B, self.netF, self.nce_layers, paired=True)
225
+ else:
226
+ self.loss_ASP = 0.0
227
+ loss_NCE_all += self.loss_ASP
228
+
229
+ # FDL: compute loss on Gaussian pyramids
230
+ if self.opt.lambda_gp > 0:
231
+ p_fake_B = self.P(self.fake_B)
232
+ p_real_B = self.P(self.real_B)
233
+ loss_pyramid = [self.criterionGP(pf, pr) for pf, pr in zip(p_fake_B, p_real_B)]
234
+ weights = self.gp_weights
235
+ loss_pyramid = [l * w for l, w in zip(loss_pyramid, weights)]
236
+ self.loss_GP = torch.mean(torch.stack(loss_pyramid)) * self.opt.lambda_gp
237
+ else:
238
+ self.loss_GP = 0
239
+
240
+ self.loss_G = self.loss_G_GAN + loss_NCE_all + self.loss_GP
241
+ return self.loss_G
242
+
243
+ def calculate_NCE_loss(self, feat_src, feat_tgt, netF, nce_layers, paired=False):
244
+ n_layers = len(feat_src)
245
+ feat_q = feat_tgt
246
+
247
+ if self.opt.flip_equivariance and self.flipped_for_equivariance:
248
+ feat_q = [torch.flip(fq, [3]) for fq in feat_q]
249
+ feat_k = feat_src
250
+ feat_k_pool, sample_ids = netF(feat_k, self.opt.num_patches, None)
251
+ feat_q_pool, _ = netF(feat_q, self.opt.num_patches, sample_ids)
252
+
253
+ total_nce_loss = 0.0
254
+ for f_q, f_k in zip(feat_q_pool, feat_k_pool):
255
+ if paired:
256
+ loss = self.criterionASP(f_q, f_k, self.current_epoch) * self.opt.lambda_asp
257
+ else:
258
+ loss = self.criterionNCE(f_q, f_k) * self.opt.lambda_NCE
259
+ total_nce_loss += loss.mean()
260
+
261
+ return total_nce_loss / n_layers
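For readability, a short sketch (hypothetical, assuming a CPTModel instance `model` on which set_input() and forward() were already called) of how compute_G_loss() above distinguishes the unpaired NCE term from the paired ASP term; both reuse the same sampled patch locations via netF:

```python
# Hedged sketch, not part of the commit: the two contrastive terms of CPTModel.
feat_real_A = model.netG(model.real_A, model.nce_layers, encode_only=True)
feat_fake_B = model.netG(model.fake_B, model.nce_layers, encode_only=True)
feat_real_B = model.netG(model.real_B, model.nce_layers, encode_only=True)

# unpaired PatchNCE: the generated IHC patch should stay close to the H&E patch it came from
loss_nce = model.calculate_NCE_loss(feat_real_A, feat_fake_B, model.netF, model.nce_layers)

# adaptive supervised PatchNCE: compare against the (inconsistently aligned) ground-truth IHC
loss_asp = model.calculate_NCE_loss(feat_real_B, feat_fake_B, model.netF, model.nce_layers, paired=True)
```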
asp/models/cut_model.py ADDED
@@ -0,0 +1,214 @@
1
+ import numpy as np
2
+ import torch
3
+ from .base_model import BaseModel
4
+ from . import networks
5
+ from .patchnce import PatchNCELoss
6
+ import asp.util.util as util
7
+
8
+
9
+ class CUTModel(BaseModel):
10
+ """ This class implements CUT and FastCUT model, described in the paper
11
+ Contrastive Learning for Unpaired Image-to-Image Translation
12
+ Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
13
+ ECCV, 2020
14
+
15
+ The code borrows heavily from the PyTorch implementation of CycleGAN
16
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
17
+ """
18
+ @staticmethod
19
+ def modify_commandline_options(parser, is_train=True):
20
+ """ Configures options specific for CUT model
21
+ """
22
+ parser.add_argument('--CUT_mode', type=str, default="CUT", choices='(CUT, cut, FastCUT, fastcut)')
23
+
24
+ parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))')
25
+ parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
26
+ parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
27
+ parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
28
+ parser.add_argument('--nce_includes_all_negatives_from_minibatch',
29
+ type=util.str2bool, nargs='?', const=True, default=False,
30
+ help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
31
+ parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
32
+ parser.add_argument('--netF_nc', type=int, default=256)
33
+ parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
34
+ parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
35
+ parser.add_argument('--flip_equivariance',
36
+ type=util.str2bool, nargs='?', const=True, default=False,
37
+ help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
38
+
39
+ parser.set_defaults(pool_size=0) # no image pooling
40
+
41
+ opt, _ = parser.parse_known_args()
42
+
43
+ # Set default parameters for CUT and FastCUT
44
+ if opt.CUT_mode.lower() == "cut":
45
+ parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
46
+ elif opt.CUT_mode.lower() == "fastcut":
47
+ parser.set_defaults(
48
+ nce_idt=False, lambda_NCE=10.0, flip_equivariance=True,
49
+ n_epochs=150, n_epochs_decay=50
50
+ )
51
+ else:
52
+ raise ValueError(opt.CUT_mode)
53
+
54
+ return parser
55
+
56
+ def __init__(self, opt):
57
+ BaseModel.__init__(self, opt)
58
+
59
+ # specify the training losses you want to print out.
60
+ # The training/test scripts will call <BaseModel.get_current_losses>
61
+ self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE']
62
+ self.visual_names = ['real_A', 'fake_B', 'real_B']
63
+ self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
64
+
65
+ if opt.nce_idt and self.isTrain:
66
+ self.loss_names += ['NCE_Y']
67
+ self.visual_names += ['idt_B']
68
+
69
+ if self.isTrain:
70
+ self.model_names = ['G', 'F', 'D']
71
+ else: # during test time, only load G
72
+ self.model_names = ['G']
73
+
74
+ # define networks (both generator and discriminator)
75
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
76
+ self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
77
+
78
+ if self.isTrain:
79
+ self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
80
+
81
+ # define loss functions
82
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
83
+ self.criterionNCE = []
84
+
85
+ for nce_layer in self.nce_layers:
86
+ self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
87
+
88
+ self.criterionIdt = torch.nn.L1Loss().to(self.device)
89
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
90
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
91
+ self.optimizers.append(self.optimizer_G)
92
+ self.optimizers.append(self.optimizer_D)
93
+
94
+ def data_dependent_initialize(self, data):
95
+ """
96
+ The feature network netF is defined in terms of the shape of the intermediate, extracted
97
+ features of the encoder portion of netG. Because of this, the weights of netF are
98
+ initialized at the first feedforward pass with some input images.
99
+ Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
100
+ """
101
+ bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
102
+ self.set_input(data)
103
+ self.real_A = self.real_A[:bs_per_gpu]
104
+ self.real_B = self.real_B[:bs_per_gpu]
105
+ self.forward() # compute fake images: G(A)
106
+ if self.opt.isTrain:
107
+ self.compute_D_loss().backward() # calculate gradients for D
108
+ self.compute_G_loss().backward() # calculate gradients for G
109
+ if self.opt.lambda_NCE > 0.0:
110
+ self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
111
+ self.optimizers.append(self.optimizer_F)
112
+
113
+ def optimize_parameters(self):
114
+ # forward
115
+ self.forward()
116
+
117
+ # update D
118
+ self.set_requires_grad(self.netD, True)
119
+ self.optimizer_D.zero_grad()
120
+ self.loss_D = self.compute_D_loss()
121
+ self.loss_D.backward()
122
+ self.optimizer_D.step()
123
+
124
+ # update G
125
+ self.set_requires_grad(self.netD, False)
126
+ self.optimizer_G.zero_grad()
127
+ if self.opt.netF == 'mlp_sample':
128
+ self.optimizer_F.zero_grad()
129
+ self.loss_G = self.compute_G_loss()
130
+ self.loss_G.backward()
131
+ self.optimizer_G.step()
132
+ if self.opt.netF == 'mlp_sample':
133
+ self.optimizer_F.step()
134
+
135
+ def set_input(self, input):
136
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
137
+ Parameters:
138
+ input (dict): include the data itself and its metadata information.
139
+ The option 'direction' can be used to swap domain A and domain B.
140
+ """
141
+ AtoB = self.opt.direction == 'AtoB'
142
+ self.real_A = input['A' if AtoB else 'B'].to(self.device)
143
+ self.real_B = input['B' if AtoB else 'A'].to(self.device)
144
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
145
+
146
+ def forward(self):
147
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
148
+ self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
149
+ if self.opt.flip_equivariance:
150
+ self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
151
+ if self.flipped_for_equivariance:
152
+ self.real = torch.flip(self.real, [3])
153
+
154
+ self.fake = self.netG(self.real)
155
+ self.fake_B = self.fake[:self.real_A.size(0)]
156
+ if self.opt.nce_idt:
157
+ self.idt_B = self.fake[self.real_A.size(0):]
158
+
159
+ def compute_D_loss(self):
160
+ """Calculate GAN loss for the discriminator"""
161
+ fake = self.fake_B.detach()
162
+ # Fake; stop backprop to the generator by detaching fake_B
163
+ pred_fake = self.netD(fake)
164
+ self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
165
+ # Real
166
+ self.pred_real = self.netD(self.real_B)
167
+ loss_D_real = self.criterionGAN(self.pred_real, True)
168
+ self.loss_D_real = loss_D_real.mean()
169
+
170
+ # combine loss and calculate gradients
171
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
172
+ return self.loss_D
173
+
174
+ def compute_G_loss(self):
175
+ """Calculate GAN and NCE loss for the generator"""
176
+ fake = self.fake_B
177
+ # First, G(A) should fake the discriminator
178
+ if self.opt.lambda_GAN > 0.0:
179
+ pred_fake = self.netD(fake)
180
+ self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
181
+ else:
182
+ self.loss_G_GAN = 0.0
183
+
184
+ if self.opt.lambda_NCE > 0.0:
185
+ self.loss_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B)
186
+ else:
187
+ self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0
188
+
189
+ if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
190
+ self.loss_NCE_Y = self.calculate_NCE_loss(self.real_B, self.idt_B)
191
+ loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y) * 0.5
192
+ else:
193
+ loss_NCE_both = self.loss_NCE
194
+
195
+ self.loss_G = self.loss_G_GAN + loss_NCE_both
196
+ return self.loss_G
197
+
198
+ def calculate_NCE_loss(self, src, tgt):
199
+ n_layers = len(self.nce_layers)
200
+ feat_q = self.netG(tgt, self.nce_layers, encode_only=True)
201
+
202
+ if self.opt.flip_equivariance and self.flipped_for_equivariance:
203
+ feat_q = [torch.flip(fq, [3]) for fq in feat_q]
204
+
205
+ feat_k = self.netG(src, self.nce_layers, encode_only=True)
206
+ feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
207
+ feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)
208
+
209
+ total_nce_loss = 0.0
210
+ for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
211
+ loss = crit(f_q, f_k) * self.opt.lambda_NCE
212
+ total_nce_loss += loss.mean()
213
+
214
+ return total_nce_loss / n_layers
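As a reference for readers, a simplified sketch of the per-layer InfoNCE objective that PatchNCELoss (patchnce.py, not shown in this excerpt) computes on the pooled features; the real implementation additionally handles batching and the nce_includes_all_negatives_from_minibatch option:

```python
import torch
import torch.nn.functional as F

def patch_infonce(f_q, f_k, temperature=0.07):
    """Simplified PatchNCE sketch: f_q and f_k are (num_patches, dim) features,
    already L2-normalized by PatchSampleF; row i of f_k is the positive for row i of f_q."""
    logits = f_q @ f_k.t() / temperature                      # query-key similarities
    targets = torch.arange(f_q.size(0), device=f_q.device)    # positives on the diagonal
    return F.cross_entropy(logits, targets)
```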
asp/models/gauss_pyramid.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ class Gauss_Pyramid_Conv(nn.Module):
5
+ """
6
+ Code borrowed from: https://github.com/csjliang/LPTN
7
+ """
8
+ def __init__(self, num_high=3):
9
+ super(Gauss_Pyramid_Conv, self).__init__()
10
+
11
+ self.num_high = num_high
12
+ self.kernel = self.gauss_kernel()
13
+
14
+ def gauss_kernel(self, device=torch.device('cuda'), channels=3):
15
+ kernel = torch.tensor([[1., 4., 6., 4., 1],
16
+ [4., 16., 24., 16., 4.],
17
+ [6., 24., 36., 24., 6.],
18
+ [4., 16., 24., 16., 4.],
19
+ [1., 4., 6., 4., 1.]])
20
+ kernel /= 256.
21
+ kernel = kernel.repeat(channels, 1, 1, 1)
22
+ kernel = kernel.to(device)
23
+ return kernel
24
+
25
+ def downsample(self, x):
26
+ return x[:, :, ::2, ::2]
27
+
28
+ def conv_gauss(self, img, kernel):
29
+ img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
30
+ out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
31
+ return out
32
+
33
+ def forward(self, img):
34
+ current = img
35
+ pyr = []
36
+ for _ in range(self.num_high):
37
+ filtered = self.conv_gauss(current, self.kernel)
38
+ pyr.append(filtered)
39
+ down = self.downsample(filtered)
40
+ current = down
41
+ pyr.append(current)
42
+ return pyr
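A minimal usage sketch (assuming a CUDA device, since gauss_kernel() builds its kernel on 'cuda' by default), mirroring how CPTModel combines the pyramid with an L1 loss:

```python
import torch

# Sketch only: random tensors stand in for generated and ground-truth IHC patches.
P = Gauss_Pyramid_Conv(num_high=5)
l1 = torch.nn.L1Loss()

fake_B = torch.rand(1, 3, 256, 256, device='cuda')
real_B = torch.rand(1, 3, 256, 256, device='cuda')

p_fake = P(fake_B)   # 6 tensors: 5 blurred levels plus the final downsampled image
p_real = P(real_B)
loss_gp = torch.mean(torch.stack([l1(f, r) for f, r in zip(p_fake, p_real)]))
```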
asp/models/networks.py ADDED
@@ -0,0 +1,1422 @@
1
+ from copy import copy
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn import init
6
+ import functools
7
+ from torch.optim import lr_scheduler
8
+ import numpy as np
9
+
10
+ ###############################################################################
11
+ # Helper Functions
12
+ ###############################################################################
13
+
14
+
15
+ def get_filter(filt_size=3):
16
+ if(filt_size == 1):
17
+ a = np.array([1., ])
18
+ elif(filt_size == 2):
19
+ a = np.array([1., 1.])
20
+ elif(filt_size == 3):
21
+ a = np.array([1., 2., 1.])
22
+ elif(filt_size == 4):
23
+ a = np.array([1., 3., 3., 1.])
24
+ elif(filt_size == 5):
25
+ a = np.array([1., 4., 6., 4., 1.])
26
+ elif(filt_size == 6):
27
+ a = np.array([1., 5., 10., 10., 5., 1.])
28
+ elif(filt_size == 7):
29
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
30
+
31
+ filt = torch.Tensor(a[:, None] * a[None, :])
32
+ filt = filt / torch.sum(filt)
33
+
34
+ return filt
35
+
36
+
37
+ class Downsample(nn.Module):
38
+ def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
39
+ super(Downsample, self).__init__()
40
+ self.filt_size = filt_size
41
+ self.pad_off = pad_off
42
+ self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)), int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
43
+ self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
44
+ self.stride = stride
45
+ self.off = int((self.stride - 1) / 2.)
46
+ self.channels = channels
47
+
48
+ filt = get_filter(filt_size=self.filt_size)
49
+ self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
50
+
51
+ self.pad = get_pad_layer(pad_type)(self.pad_sizes)
52
+
53
+ def forward(self, inp):
54
+ if(self.filt_size == 1):
55
+ if(self.pad_off == 0):
56
+ return inp[:, :, ::self.stride, ::self.stride]
57
+ else:
58
+ return self.pad(inp)[:, :, ::self.stride, ::self.stride]
59
+ else:
60
+ return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
61
+
62
+
63
+ class Upsample2(nn.Module):
64
+ def __init__(self, scale_factor, mode='nearest'):
65
+ super().__init__()
66
+ self.factor = scale_factor
67
+ self.mode = mode
68
+
69
+ def forward(self, x):
70
+ return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)
71
+
72
+
73
+ class Upsample(nn.Module):
74
+ def __init__(self, channels, pad_type='repl', filt_size=4, stride=2):
75
+ super(Upsample, self).__init__()
76
+ self.filt_size = filt_size
77
+ self.filt_odd = np.mod(filt_size, 2) == 1
78
+ self.pad_size = int((filt_size - 1) / 2)
79
+ self.stride = stride
80
+ self.off = int((self.stride - 1) / 2.)
81
+ self.channels = channels
82
+
83
+ filt = get_filter(filt_size=self.filt_size) * (stride**2)
84
+ self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
85
+
86
+ self.pad = get_pad_layer(pad_type)([1, 1, 1, 1])
87
+
88
+ def forward(self, inp):
89
+ ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride, padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:]
90
+ if(self.filt_odd):
91
+ return ret_val
92
+ else:
93
+ return ret_val[:, :, :-1, :-1]
94
+
95
+
96
+ def get_pad_layer(pad_type):
97
+ if(pad_type in ['refl', 'reflect']):
98
+ PadLayer = nn.ReflectionPad2d
99
+ elif(pad_type in ['repl', 'replicate']):
100
+ PadLayer = nn.ReplicationPad2d
101
+ elif(pad_type == 'zero'):
102
+ PadLayer = nn.ZeroPad2d
103
+ else:
104
+ print('Pad type [%s] not recognized' % pad_type)
105
+ return PadLayer
106
+
107
+
108
+ class Identity(nn.Module):
109
+ def forward(self, x):
110
+ return x
111
+
112
+
113
+ def get_norm_layer(norm_type='instance'):
114
+ """Return a normalization layer
115
+
116
+ Parameters:
117
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
118
+
119
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
120
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
121
+ """
122
+ if norm_type == 'batch':
123
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
124
+ elif norm_type == 'instance':
125
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
126
+ elif norm_type == 'none':
127
+ def norm_layer(x):
128
+ return Identity()
129
+ else:
130
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
131
+ return norm_layer
132
+
133
+
134
+ def get_scheduler(optimizer, opt):
135
+ """Return a learning rate scheduler
136
+
137
+ Parameters:
138
+ optimizer -- the optimizer of the network
139
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
140
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
141
+
142
+ For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
143
+ and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
144
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
145
+ See https://pytorch.org/docs/stable/optim.html for more details.
146
+ """
147
+ if opt.lr_policy == 'linear':
148
+ def lambda_rule(epoch):
149
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
150
+ return lr_l
151
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
152
+ elif opt.lr_policy == 'step':
153
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
154
+ elif opt.lr_policy == 'plateau':
155
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
156
+ elif opt.lr_policy == 'cosine':
157
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
158
+ else:
159
+ raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
160
+ return scheduler
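To illustrate the 'linear' policy with hypothetical settings (n_epochs=100, n_epochs_decay=100, epoch_count=1, none of which come from this commit), the multiplier stays at 1.0 for the first 100 epochs and then decays linearly towards zero:

```python
# Hypothetical numbers for illustration only.
n_epochs, n_epochs_decay, epoch_count = 100, 100, 1

def lambda_rule(epoch):
    return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)

print(lambda_rule(0))    # 1.0    (flat phase)
print(lambda_rule(150))  # ~0.495 (halfway through the decay)
print(lambda_rule(199))  # ~0.010 (almost fully decayed)
```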
161
+
162
+
163
+ def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
164
+ """Initialize network weights.
165
+
166
+ Parameters:
167
+ net (network) -- network to be initialized
168
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
169
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
170
+
171
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
172
+ work better for some applications. Feel free to try yourself.
173
+ """
174
+ def init_func(m): # define the initialization function
175
+ classname = m.__class__.__name__
176
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
177
+ if debug:
178
+ print(classname)
179
+ if init_type == 'normal':
180
+ init.normal_(m.weight.data, 0.0, init_gain)
181
+ elif init_type == 'xavier':
182
+ init.xavier_normal_(m.weight.data, gain=init_gain)
183
+ elif init_type == 'kaiming':
184
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
185
+ elif init_type == 'orthogonal':
186
+ init.orthogonal_(m.weight.data, gain=init_gain)
187
+ else:
188
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
189
+ if hasattr(m, 'bias') and m.bias is not None:
190
+ init.constant_(m.bias.data, 0.0)
191
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
192
+ init.normal_(m.weight.data, 1.0, init_gain)
193
+ init.constant_(m.bias.data, 0.0)
194
+
195
+ net.apply(init_func) # apply the initialization function <init_func>
196
+
197
+
198
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
199
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
200
+ Parameters:
201
+ net (network) -- the network to be initialized
202
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
203
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
204
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
205
+
206
+ Return an initialized network.
207
+ """
208
+ if len(gpu_ids) > 0:
209
+ assert(torch.cuda.is_available())
210
+ net.to(gpu_ids[0])
211
+ # if not amp:
212
+ # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training
213
+ if initialize_weights:
214
+ init_weights(net, init_type, init_gain=init_gain, debug=debug)
215
+ return net
216
+
217
+
218
+ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal',
219
+ init_gain=0.02, no_antialias=False, no_antialias_up=False, gpu_ids=[], opt=None):
220
+ """Create a generator
221
+
222
+ Parameters:
223
+ input_nc (int) -- the number of channels in input images
224
+ output_nc (int) -- the number of channels in output images
225
+ ngf (int) -- the number of filters in the last conv layer
226
+ netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
227
+ norm (str) -- the name of normalization layers used in the network: batch | instance | none
228
+ use_dropout (bool) -- if use dropout layers.
229
+ init_type (str) -- the name of our initialization method.
230
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
231
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
232
+
233
+ Returns a generator
234
+
235
+ Our current implementation provides two types of generators:
236
+ U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
237
+ The original U-Net paper: https://arxiv.org/abs/1505.04597
238
+
239
+ Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
240
+ Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
241
+ We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
242
+
243
+
244
+ The generator has been initialized by <init_net>. It uses RELU for non-linearity.
245
+ """
246
+ net = None
247
+ norm_layer = get_norm_layer(norm_type=norm)
248
+
249
+ if netG == 'resnet_9blocks':
250
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt)
251
+ elif netG == 'resnet_6blocks':
252
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=6, opt=opt)
253
+ elif netG == 'resnet_4blocks':
254
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=4, opt=opt)
255
+ elif netG == 'unet_128':
256
+ net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
257
+ elif netG == 'unet_256':
258
+ net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
259
+ elif netG == 'resnet_cat':
260
+ n_blocks = 8
261
+ net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
262
+ else:
263
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
264
+ return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG))
265
+
266
+
267
+ def define_F(input_nc, netF, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
268
+ if netF == 'global_pool':
269
+ net = PoolingF()
270
+ elif netF == 'reshape':
271
+ net = ReshapeF()
272
+ elif netF == 'sample':
273
+ net = PatchSampleF(use_mlp=False, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc, opt=opt)
274
+ elif netF == 'mlp_sample':
275
+ net = PatchSampleF(use_mlp=True, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc, opt=opt)
276
+ elif netF == 'strided_conv':
277
+ net = StridedConvF(init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
278
+ else:
279
+ raise NotImplementedError('projection model name [%s] is not recognized' % netF)
280
+ return init_net(net, init_type, init_gain, gpu_ids)
281
+
282
+
283
+ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
284
+ """Create a discriminator
285
+
286
+ Parameters:
287
+ input_nc (int) -- the number of channels in input images
288
+ ndf (int) -- the number of filters in the first conv layer
289
+ netD (str) -- the architecture's name: basic | n_layers | pixel
290
+ n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
291
+ norm (str) -- the type of normalization layers used in the network.
292
+ init_type (str) -- the name of the initialization method.
293
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
294
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
295
+
296
+ Returns a discriminator
297
+
298
+ Our current implementation provides three types of discriminators:
299
+ [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
300
+ It can classify whether 70x70 overlapping patches are real or fake.
301
+ Such a patch-level discriminator architecture has fewer parameters
302
+ than a full-image discriminator and can work on arbitrarily-sized images
303
+ in a fully convolutional fashion.
304
+
305
+ [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
306
+ with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
307
+
308
+ [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
309
+ It encourages greater color diversity but has no effect on spatial statistics.
310
+
311
+ The discriminator has been initialized by <init_net>. It uses Leaky RELU for non-linearity.
312
+ """
313
+ net = None
314
+ norm_layer = get_norm_layer(norm_type=norm)
315
+
316
+ if netD == 'basic': # default PatchGAN classifier
317
+ net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, no_antialias=no_antialias, opt=opt)
318
+ elif netD == 'n_layers': # more options
319
+ net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias, opt=opt)
320
+ elif netD == 'pixel': # classify if each pixel is real or fake
321
+ net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
322
+ else:
323
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
324
+ return init_net(net, init_type, init_gain, gpu_ids,
325
+ initialize_weights=('stylegan2' not in netD))
326
+
327
+
328
+ ##############################################################################
329
+ # Classes
330
+ ##############################################################################
331
+ class GANLoss(nn.Module):
332
+ """Define different GAN objectives.
333
+
334
+ The GANLoss class abstracts away the need to create the target label tensor
335
+ that has the same size as the input.
336
+ """
337
+
338
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
339
+ """ Initialize the GANLoss class.
340
+
341
+ Parameters:
342
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
343
+ target_real_label (bool) - - label for a real image
344
+ target_fake_label (bool) - - label of a fake image
345
+
346
+ Note: Do not use sigmoid as the last layer of Discriminator.
347
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
348
+ """
349
+ super(GANLoss, self).__init__()
350
+ self.register_buffer('real_label', torch.tensor(target_real_label))
351
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
352
+ self.gan_mode = gan_mode
353
+ if gan_mode == 'lsgan':
354
+ self.loss = nn.MSELoss()
355
+ elif gan_mode == 'vanilla':
356
+ self.loss = nn.BCEWithLogitsLoss()
357
+ elif gan_mode in ['wgangp', 'nonsaturating']:
358
+ self.loss = None
359
+ else:
360
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
361
+
362
+ def get_target_tensor(self, prediction, target_is_real):
363
+ """Create label tensors with the same size as the input.
364
+
365
+ Parameters:
366
+ prediction (tensor) - - typically the prediction from a discriminator
367
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
368
+
369
+ Returns:
370
+ A label tensor filled with ground truth label, and with the size of the input
371
+ """
372
+
373
+ if target_is_real:
374
+ target_tensor = self.real_label
375
+ else:
376
+ target_tensor = self.fake_label
377
+ return target_tensor.expand_as(prediction)
378
+
379
+ def __call__(self, prediction, target_is_real):
380
+ """Calculate loss given Discriminator's output and grount truth labels.
381
+
382
+ Parameters:
383
+ prediction (tensor) - - typically the prediction output from a discriminator
384
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
385
+
386
+ Returns:
387
+ the calculated loss.
388
+ """
389
+ bs = prediction.size(0)
390
+ if self.gan_mode in ['lsgan', 'vanilla']:
391
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
392
+ loss = self.loss(prediction, target_tensor)
393
+ elif self.gan_mode == 'wgangp':
394
+ if target_is_real:
395
+ loss = -prediction.mean()
396
+ else:
397
+ loss = prediction.mean()
398
+ elif self.gan_mode == 'nonsaturating':
399
+ if target_is_real:
400
+ loss = F.softplus(-prediction).view(bs, -1).mean(dim=1)
401
+ else:
402
+ loss = F.softplus(prediction).view(bs, -1).mean(dim=1)
403
+ return loss
404
+
405
+
406
+ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
407
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
408
+
409
+ Arguments:
410
+ netD (network) -- discriminator network
411
+ real_data (tensor array) -- real images
412
+ fake_data (tensor array) -- generated images from the generator
413
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
414
+ type (str) -- if we mix real and fake data or not [real | fake | mixed].
415
+ constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2
416
+ lambda_gp (float) -- weight for this loss
417
+
418
+ Returns the gradient penalty loss
419
+ """
420
+ if lambda_gp > 0.0:
421
+ if type == 'real': # either use real images, fake images, or a linear interpolation of two.
422
+ interpolatesv = real_data
423
+ elif type == 'fake':
424
+ interpolatesv = fake_data
425
+ elif type == 'mixed':
426
+ alpha = torch.rand(real_data.shape[0], 1, device=device)
427
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
428
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
429
+ else:
430
+ raise NotImplementedError('{} not implemented'.format(type))
431
+ interpolatesv.requires_grad_(True)
432
+ disc_interpolates = netD(interpolatesv)
433
+ gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
434
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
435
+ create_graph=True, retain_graph=True, only_inputs=True)
436
+ gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
437
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
438
+ return gradient_penalty, gradients
439
+ else:
440
+ return 0.0, None
441
+
442
+
443
+ class Normalize(nn.Module):
444
+
445
+ def __init__(self, power=2):
446
+ super(Normalize, self).__init__()
447
+ self.power = power
448
+
449
+ def forward(self, x, dim=1):
450
+ # norm = x.pow(self.power).sum(dim, keepdim=True).pow(1. / self.power)
451
+ # out = x.div(norm + 1e-7)
452
+ # FDL: To avoid sqrting 0s, which causes nans in grad
453
+ norm = (x + 1e-7).pow(self.power).sum(dim, keepdim=True).pow(1. / self.power)
454
+ out = x.div(norm)
455
+ return out
456
+
457
+
458
+ class PoolingF(nn.Module):
459
+ def __init__(self):
460
+ super(PoolingF, self).__init__()
461
+ model = [nn.AdaptiveMaxPool2d(1)]
462
+ self.model = nn.Sequential(*model)
463
+ self.l2norm = Normalize(2)
464
+
465
+ def forward(self, x):
466
+ return self.l2norm(self.model(x))
467
+
468
+
469
+ class ReshapeF(nn.Module):
470
+ def __init__(self):
471
+ super(ReshapeF, self).__init__()
472
+ model = [nn.AdaptiveAvgPool2d(4)]
473
+ self.model = nn.Sequential(*model)
474
+ self.l2norm = Normalize(2)
475
+
476
+ def forward(self, x):
477
+ x = self.model(x)
478
+ x_reshape = x.permute(0, 2, 3, 1).flatten(0, 2)
479
+ return self.l2norm(x_reshape)
480
+
481
+
482
+ class StridedConvF(nn.Module):
483
+ def __init__(self, init_type='normal', init_gain=0.02, gpu_ids=[]):
484
+ super().__init__()
485
+ # self.conv1 = nn.Conv2d(256, 128, 3, stride=2)
486
+ # self.conv2 = nn.Conv2d(128, 64, 3, stride=1)
487
+ self.l2_norm = Normalize(2)
488
+ self.mlps = {}
489
+ self.moving_averages = {}
490
+ self.init_type = init_type
491
+ self.init_gain = init_gain
492
+ self.gpu_ids = gpu_ids
493
+
494
+ def create_mlp(self, x):
495
+ C, H = x.shape[1], x.shape[2]
496
+ n_down = int(np.rint(np.log2(H / 32)))
497
+ mlp = []
498
+ for i in range(n_down):
499
+ mlp.append(nn.Conv2d(C, max(C // 2, 64), 3, stride=2))
500
+ mlp.append(nn.ReLU())
501
+ C = max(C // 2, 64)
502
+ mlp.append(nn.Conv2d(C, 64, 3))
503
+ mlp = nn.Sequential(*mlp)
504
+ init_net(mlp, self.init_type, self.init_gain, self.gpu_ids)
505
+ return mlp
506
+
507
+ def update_moving_average(self, key, x):
508
+ if key not in self.moving_averages:
509
+ self.moving_averages[key] = x.detach()
510
+
511
+ self.moving_averages[key] = self.moving_averages[key] * 0.999 + x.detach() * 0.001
512
+
513
+ def forward(self, x, use_instance_norm=False):
514
+ C, H = x.shape[1], x.shape[2]
515
+ key = '%d_%d' % (C, H)
516
+ if key not in self.mlps:
517
+ self.mlps[key] = self.create_mlp(x)
518
+ self.add_module("child_%s" % key, self.mlps[key])
519
+ mlp = self.mlps[key]
520
+ x = mlp(x)
521
+ self.update_moving_average(key, x)
522
+ x = x - self.moving_averages[key]
523
+ if use_instance_norm:
524
+ x = F.instance_norm(x)
525
+ return self.l2_norm(x)
526
+
527
+
528
+ class PatchSampleF(nn.Module):
529
+ def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[], opt=None):
530
+ # potential issues: currently, we use the same patch_ids for multiple images in the batch
531
+ super(PatchSampleF, self).__init__()
532
+ self.l2norm = Normalize(2)
533
+ self.use_mlp = use_mlp
534
+ self.nc = nc # hard-coded
535
+ self.mlp_init = False
536
+ self.init_type = init_type
537
+ self.init_gain = init_gain
538
+ self.gpu_ids = gpu_ids
539
+ self.opt = opt
540
+
541
+ def create_mlp(self, feats):
542
+ for mlp_id, feat in enumerate(feats):
543
+ input_nc = feat.shape[1]
544
+ mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
545
+ if len(self.gpu_ids) > 0:
546
+ mlp.cuda()
547
+ setattr(self, 'mlp_%d' % mlp_id, mlp)
548
+ init_net(self, self.init_type, self.init_gain, self.gpu_ids)
549
+ self.mlp_init = True
550
+
551
+ def forward(self, feats, num_patches=64, patch_ids=None):
552
+ return_ids = []
553
+ return_feats = []
554
+ if self.use_mlp and not self.mlp_init:
555
+ self.create_mlp(feats)
556
+ for feat_id, feat in enumerate(feats):
557
+ B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
558
+ feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
559
+ if num_patches > 0:
560
+ if patch_ids is not None:
561
+ patch_id = patch_ids[feat_id]
562
+ else:
563
+ # torch.randperm produces cudaErrorIllegalAddress for newer versions of PyTorch. https://github.com/taesungp/contrastive-unpaired-translation/issues/83
564
+ #patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
565
+ patch_id = np.random.permutation(feat_reshape.shape[1])
566
+ patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
567
+ patch_id = torch.tensor(patch_id, dtype=torch.long, device=feat.device)
568
+ x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)
569
+ else:
570
+ x_sample = feat_reshape.flatten(0, 1)
571
+ patch_id = []
572
+ if self.use_mlp:
573
+ mlp = getattr(self, 'mlp_%d' % feat_id)
574
+ x_sample = mlp(x_sample)
575
+ return_ids.append(patch_id)
576
+ x_sample = self.l2norm(x_sample)
577
+ if num_patches == 0:
578
+ x_sample = x_sample.reshape([B, H, W, x_sample.shape[-1]]).permute(0, 3, 1, 2)
579
+ return_feats.append(x_sample)
580
+ return return_feats, return_ids
581
+
582
+
583
+ class G_Resnet(nn.Module):
584
+ def __init__(self, input_nc, output_nc, nz, num_downs, n_res, ngf=64,
585
+ norm=None, nl_layer=None):
586
+ super(G_Resnet, self).__init__()
587
+ n_downsample = num_downs
588
+ pad_type = 'reflect'
589
+ self.enc_content = ContentEncoder(n_downsample, n_res, input_nc, ngf, norm, nl_layer, pad_type=pad_type)
590
+ if nz == 0:
591
+ self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
592
+ else:
593
+ self.dec = Decoder_all(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
594
+
595
+ def decode(self, content, style=None):
596
+ return self.dec(content, style)
597
+
598
+ def forward(self, image, style=None, nce_layers=[], encode_only=False):
599
+ content, feats = self.enc_content(image, nce_layers=nce_layers, encode_only=encode_only)
600
+ if encode_only:
601
+ return feats
602
+ else:
603
+ images_recon = self.decode(content, style)
604
+ if len(nce_layers) > 0:
605
+ return images_recon, feats
606
+ else:
607
+ return images_recon
608
+
609
+ ##################################################################################
610
+ # Encoder and Decoders
611
+ ##################################################################################
612
+
613
+
614
+ class E_adaIN(nn.Module):
615
+ def __init__(self, input_nc, output_nc=1, nef=64, n_layers=4,
616
+ norm=None, nl_layer=None, vae=False):
617
+ # style encoder
618
+ super(E_adaIN, self).__init__()
619
+ self.enc_style = StyleEncoder(n_layers, input_nc, nef, output_nc, norm='none', activ='relu', vae=vae)
620
+
621
+ def forward(self, image):
622
+ style = self.enc_style(image)
623
+ return style
624
+
625
+
626
+ class StyleEncoder(nn.Module):
627
+ def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, vae=False):
628
+ super(StyleEncoder, self).__init__()
629
+ self.vae = vae
630
+ self.model = []
631
+ self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
632
+ for i in range(2):
633
+ self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
634
+ dim *= 2
635
+ for i in range(n_downsample - 2):
636
+ self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
637
+ self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
638
+ if self.vae:
639
+ self.fc_mean = nn.Linear(dim, style_dim) # , 1, 1, 0)
640
+ self.fc_var = nn.Linear(dim, style_dim) # , 1, 1, 0)
641
+ else:
642
+ self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
643
+
644
+ self.model = nn.Sequential(*self.model)
645
+ self.output_dim = dim
646
+
647
+ def forward(self, x):
648
+ if self.vae:
649
+ output = self.model(x)
650
+ output = output.view(x.size(0), -1)
651
+ output_mean = self.fc_mean(output)
652
+ output_var = self.fc_var(output)
653
+ return output_mean, output_var
654
+ else:
655
+ return self.model(x).view(x.size(0), -1)
656
+
657
+
658
+ class ContentEncoder(nn.Module):
659
+ def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type='zero'):
660
+ super(ContentEncoder, self).__init__()
661
+ self.model = []
662
+ self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
663
+ # downsampling blocks
664
+ for i in range(n_downsample):
665
+ self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
666
+ dim *= 2
667
+ # residual blocks
668
+ self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
669
+ self.model = nn.Sequential(*self.model)
670
+ self.output_dim = dim
671
+
672
+ def forward(self, x, nce_layers=[], encode_only=False):
673
+ if len(nce_layers) > 0:
674
+ feat = x
675
+ feats = []
676
+ for layer_id, layer in enumerate(self.model):
677
+ feat = layer(feat)
678
+ if layer_id in nce_layers:
679
+ feats.append(feat)
680
+ if layer_id == nce_layers[-1] and encode_only:
681
+ return None, feats
682
+ return feat, feats
683
+ else:
684
+ return self.model(x), None
685
+
686
+
687
+
688
+ class Decoder_all(nn.Module):
689
+ def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
690
+ super(Decoder_all, self).__init__()
691
+ # AdaIN residual blocks
692
+ self.resnet_block = ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)
693
+ self.n_blocks = 0
694
+ # upsampling blocks
695
+ for i in range(n_upsample):
696
+ block = [Upsample2(scale_factor=2), Conv2dBlock(dim + nz, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
697
+ setattr(self, 'block_{:d}'.format(self.n_blocks), nn.Sequential(*block))
698
+ self.n_blocks += 1
699
+ dim //= 2
700
+ # use reflection padding in the last conv layer
701
+ setattr(self, 'block_{:d}'.format(self.n_blocks), Conv2dBlock(dim + nz, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect'))
702
+ self.n_blocks += 1
703
+
704
+ def forward(self, x, y=None):
705
+ if y is not None:
706
+ output = self.resnet_block(cat_feature(x, y))
707
+ for n in range(self.n_blocks):
708
+ block = getattr(self, 'block_{:d}'.format(n))
709
+ if n > 0:
710
+ output = block(cat_feature(output, y))
711
+ else:
712
+ output = block(output)
713
+ return output
714
+
715
+
716
+ class Decoder(nn.Module):
717
+ def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
718
+ super(Decoder, self).__init__()
719
+
720
+ self.model = []
721
+ # AdaIN residual blocks
722
+ self.model += [ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)]
723
+ # upsampling blocks
724
+ for i in range(n_upsample):
725
+ if i == 0:
726
+ input_dim = dim + nz
727
+ else:
728
+ input_dim = dim
729
+ self.model += [Upsample2(scale_factor=2), Conv2dBlock(input_dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
730
+ dim //= 2
731
+ # use reflection padding in the last conv layer
732
+ self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect')]
733
+ self.model = nn.Sequential(*self.model)
734
+
735
+ def forward(self, x, y=None):
736
+ if y is not None:
737
+ return self.model(cat_feature(x, y))
738
+ else:
739
+ return self.model(x)
740
+
741
+ ##################################################################################
742
+ # Sequential Models
743
+ ##################################################################################
744
+
745
+
746
+ class ResBlocks(nn.Module):
747
+ def __init__(self, num_blocks, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
748
+ super(ResBlocks, self).__init__()
749
+ self.model = []
750
+ for i in range(num_blocks):
751
+ self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type, nz=nz)]
752
+ self.model = nn.Sequential(*self.model)
753
+
754
+ def forward(self, x):
755
+ return self.model(x)
756
+
757
+
758
+ ##################################################################################
759
+ # Basic Blocks
760
+ ##################################################################################
761
+ def cat_feature(x, y):
762
+ y_expand = y.view(y.size(0), y.size(1), 1, 1).expand(
763
+ y.size(0), y.size(1), x.size(2), x.size(3))
764
+ x_cat = torch.cat([x, y_expand], 1)
765
+ return x_cat
766
+
767
+
768
+ class ResBlock(nn.Module):
769
+ def __init__(self, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
770
+ super(ResBlock, self).__init__()
771
+
772
+ model = []
773
+ model += [Conv2dBlock(dim + nz, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
774
+ model += [Conv2dBlock(dim, dim + nz, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
775
+ self.model = nn.Sequential(*model)
776
+
777
+ def forward(self, x):
778
+ residual = x
779
+ out = self.model(x)
780
+ out += residual
781
+ return out
782
+
783
+
784
+ class Conv2dBlock(nn.Module):
785
+ def __init__(self, input_dim, output_dim, kernel_size, stride,
786
+ padding=0, norm='none', activation='relu', pad_type='zero'):
787
+ super(Conv2dBlock, self).__init__()
788
+ self.use_bias = True
789
+ # initialize padding
790
+ if pad_type == 'reflect':
791
+ self.pad = nn.ReflectionPad2d(padding)
792
+ elif pad_type == 'zero':
793
+ self.pad = nn.ZeroPad2d(padding)
794
+ else:
795
+ assert 0, "Unsupported padding type: {}".format(pad_type)
796
+
797
+ # initialize normalization
798
+ norm_dim = output_dim
799
+ if norm == 'batch':
800
+ self.norm = nn.BatchNorm2d(norm_dim)
801
+ elif norm == 'inst':
802
+ self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=False)
803
+ elif norm == 'ln':
804
+ self.norm = LayerNorm(norm_dim)
805
+ elif norm == 'none':
806
+ self.norm = None
807
+ else:
808
+ assert 0, "Unsupported normalization: {}".format(norm)
809
+
810
+ # initialize activation
811
+ if activation == 'relu':
812
+ self.activation = nn.ReLU(inplace=True)
813
+ elif activation == 'lrelu':
814
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
815
+ elif activation == 'prelu':
816
+ self.activation = nn.PReLU()
817
+ elif activation == 'selu':
818
+ self.activation = nn.SELU(inplace=True)
819
+ elif activation == 'tanh':
820
+ self.activation = nn.Tanh()
821
+ elif activation == 'none':
822
+ self.activation = None
823
+ else:
824
+ assert 0, "Unsupported activation: {}".format(activation)
825
+
826
+ # initialize convolution
827
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
828
+
829
+ def forward(self, x):
830
+ x = self.conv(self.pad(x))
831
+ if self.norm:
832
+ x = self.norm(x)
833
+ if self.activation:
834
+ x = self.activation(x)
835
+ return x
836
+
837
+
838
+ class LinearBlock(nn.Module):
839
+ def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
840
+ super(LinearBlock, self).__init__()
841
+ use_bias = True
842
+ # initialize fully connected layer
843
+ self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
844
+
845
+ # initialize normalization
846
+ norm_dim = output_dim
847
+ if norm == 'batch':
848
+ self.norm = nn.BatchNorm1d(norm_dim)
849
+ elif norm == 'inst':
850
+ self.norm = nn.InstanceNorm1d(norm_dim)
851
+ elif norm == 'ln':
852
+ self.norm = LayerNorm(norm_dim)
853
+ elif norm == 'none':
854
+ self.norm = None
855
+ else:
856
+ assert 0, "Unsupported normalization: {}".format(norm)
857
+
858
+ # initialize activation
859
+ if activation == 'relu':
860
+ self.activation = nn.ReLU(inplace=True)
861
+ elif activation == 'lrelu':
862
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
863
+ elif activation == 'prelu':
864
+ self.activation = nn.PReLU()
865
+ elif activation == 'selu':
866
+ self.activation = nn.SELU(inplace=True)
867
+ elif activation == 'tanh':
868
+ self.activation = nn.Tanh()
869
+ elif activation == 'none':
870
+ self.activation = None
871
+ else:
872
+ assert 0, "Unsupported activation: {}".format(activation)
873
+
874
+ def forward(self, x):
875
+ out = self.fc(x)
876
+ if self.norm:
877
+ out = self.norm(out)
878
+ if self.activation:
879
+ out = self.activation(out)
880
+ return out
881
+
882
+ ##################################################################################
883
+ # Normalization layers
884
+ ##################################################################################
885
+
886
+
887
+ class LayerNorm(nn.Module):
888
+ def __init__(self, num_features, eps=1e-5, affine=True):
889
+ super(LayerNorm, self).__init__()
890
+ self.num_features = num_features
891
+ self.affine = affine
892
+ self.eps = eps
893
+
894
+ if self.affine:
895
+ self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
896
+ self.beta = nn.Parameter(torch.zeros(num_features))
897
+
898
+ def forward(self, x):
899
+ shape = [-1] + [1] * (x.dim() - 1)
900
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
901
+ std = x.view(x.size(0), -1).std(1).view(*shape)
902
+ x = (x - mean) / (std + self.eps)
903
+
904
+ if self.affine:
905
+ shape = [1, -1] + [1] * (x.dim() - 2)
906
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
907
+ return x
908
+
909
+
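A note on the LayerNorm above: unlike nn.LayerNorm, it standardizes each sample over all of its channels and spatial positions (one mean and std per sample) and then applies a per-channel affine transform. A minimal sketch of that behavior, assuming the class as defined above:

# Sketch only: per-sample normalization over (C, H, W), per-channel affine parameters.
import torch

ln = LayerNorm(num_features=8)
x = torch.randn(2, 8, 16, 16) * 3.0 + 5.0
mean = x.view(2, -1).mean(1).view(2, 1, 1, 1)
std = x.view(2, -1).std(1).view(2, 1, 1, 1)
expected = (x - mean) / (std + 1e-5) * ln.gamma.view(1, -1, 1, 1) + ln.beta.view(1, -1, 1, 1)
print(torch.allclose(ln(x), expected))  # True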
910
+ class ResnetGenerator(nn.Module):
911
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
912
+
913
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
914
+ """
915
+
916
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
917
+ """Construct a Resnet-based generator
918
+
919
+ Parameters:
920
+ input_nc (int) -- the number of channels in input images
921
+ output_nc (int) -- the number of channels in output images
922
+ ngf (int) -- the number of filters in the last conv layer
923
+ norm_layer -- normalization layer
924
+ use_dropout (bool) -- if use dropout layers
925
+ n_blocks (int) -- the number of ResNet blocks
926
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
927
+ """
928
+ assert(n_blocks >= 0)
929
+ super(ResnetGenerator, self).__init__()
930
+ self.opt = opt
931
+ if type(norm_layer) == functools.partial:
932
+ use_bias = norm_layer.func == nn.InstanceNorm2d
933
+ else:
934
+ use_bias = norm_layer == nn.InstanceNorm2d
935
+
936
+ if opt.weight_norm == 'spectral':
937
+ weight_norm = nn.utils.spectral_norm
938
+ else:
939
+ def weight_norm(x): return x
940
+
941
+ model = [nn.ReflectionPad2d(3),
942
+ weight_norm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias)),
943
+ norm_layer(ngf),
944
+ nn.ReLU(True)]
945
+
946
+ n_downsampling = getattr(opt, 'n_downsampling', 2)
947
+ for i in range(n_downsampling): # add downsampling layers
948
+ mult = 2 ** i
949
+ if(no_antialias):
950
+ model += [weight_norm(nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias)),
951
+ norm_layer(ngf * mult * 2),
952
+ nn.ReLU(True)]
953
+ else:
954
+ model += [weight_norm(nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)),
955
+ norm_layer(ngf * mult * 2),
956
+ nn.ReLU(True),
957
+ Downsample(ngf * mult * 2)]
958
+
959
+ mult = 2 ** n_downsampling
960
+ for i in range(n_blocks): # add ResNet blocks
961
+ extra = None
962
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, opt=opt)]
963
+
964
+ for i in range(n_downsampling): # add upsampling layers
965
+ mult = 2 ** (n_downsampling - i)
966
+ if no_antialias_up:
967
+ model += [weight_norm(nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
968
+ kernel_size=3, stride=2,
969
+ padding=1, output_padding=1,
970
+ bias=use_bias)),
971
+ norm_layer(int(ngf * mult / 2)),
972
+ nn.ReLU(True)]
973
+ else:
974
+ model += [Upsample(ngf * mult),
975
+ weight_norm(nn.Conv2d(ngf * mult, int(ngf * mult / 2),
976
+ kernel_size=3, stride=1,
977
+ padding=1, # output_padding=1,
978
+ bias=use_bias)),
979
+ norm_layer(int(ngf * mult / 2)),
980
+ nn.ReLU(True)]
981
+ model += [nn.ReflectionPad2d(3)]
982
+ model += [weight_norm(nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0))]
983
+ model += [nn.Tanh()]
984
+
985
+ self.model = nn.Sequential(*model)
986
+
987
+ def forward(self, input, layers=[], encode_only=False):
988
+ if -1 in layers:
989
+ layers.append(len(self.model))
990
+ if len(layers) > 0:
991
+ feat = input
992
+ feats = []
993
+ for layer_id, layer in enumerate(self.model):
994
+ # print(layer_id, layer)
995
+ feat = layer(feat)
996
+ if layer_id in layers:
997
+ # print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
998
+ feats.append(feat)
999
+ else:
1000
+ # print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
1001
+ pass
1002
+ if layer_id == layers[-1] and encode_only:
1003
+ # print('encoder only return features')
1004
+ return feats # return intermediate features alone; stop in the last layers
1005
+
1006
+ return feat, feats # return both output and intermediate features
1007
+ else:
1008
+ """Standard forward"""
1009
+ fake = self.model(input)
1010
+ return fake
1011
+
1012
+
1013
+ class ResnetDecoder(nn.Module):
1014
+ """Resnet-based decoder that consists of a few Resnet blocks + a few upsampling operations.
1015
+ """
1016
+
1017
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
1018
+ """Construct a Resnet-based decoder
1019
+
1020
+ Parameters:
1021
+ input_nc (int) -- the number of channels in input images
1022
+ output_nc (int) -- the number of channels in output images
1023
+ ngf (int) -- the number of filters in the last conv layer
1024
+ norm_layer -- normalization layer
1025
+ use_dropout (bool) -- if use dropout layers
1026
+ n_blocks (int) -- the number of ResNet blocks
1027
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
1028
+ """
1029
+ assert(n_blocks >= 0)
1030
+ super(ResnetDecoder, self).__init__()
1031
+ if type(norm_layer) == functools.partial:
1032
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1033
+ else:
1034
+ use_bias = norm_layer == nn.InstanceNorm2d
1035
+ model = []
1036
+ n_downsampling = 2
1037
+ mult = 2 ** n_downsampling
1038
+ for i in range(n_blocks): # add ResNet blocks
1039
+
1040
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1041
+
1042
+ for i in range(n_downsampling): # add upsampling layers
1043
+ mult = 2 ** (n_downsampling - i)
1044
+ if(no_antialias):
1045
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
1046
+ kernel_size=3, stride=2,
1047
+ padding=1, output_padding=1,
1048
+ bias=use_bias),
1049
+ norm_layer(int(ngf * mult / 2)),
1050
+ nn.ReLU(True)]
1051
+ else:
1052
+ model += [Upsample(ngf * mult),
1053
+ nn.Conv2d(ngf * mult, int(ngf * mult / 2),
1054
+ kernel_size=3, stride=1,
1055
+ padding=1,
1056
+ bias=use_bias),
1057
+ norm_layer(int(ngf * mult / 2)),
1058
+ nn.ReLU(True)]
1059
+ model += [nn.ReflectionPad2d(3)]
1060
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
1061
+ model += [nn.Tanh()]
1062
+
1063
+ self.model = nn.Sequential(*model)
1064
+
1065
+ def forward(self, input):
1066
+ """Standard forward"""
1067
+ return self.model(input)
1068
+
1069
+
1070
+ class ResnetEncoder(nn.Module):
1071
+ """Resnet-based encoder that consists of a few downsampling + several Resnet blocks
1072
+ """
1073
+
1074
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
1075
+ """Construct a Resnet-based encoder
1076
+
1077
+ Parameters:
1078
+ input_nc (int) -- the number of channels in input images
1079
+ output_nc (int) -- the number of channels in output images
1080
+ ngf (int) -- the number of filters in the last conv layer
1081
+ norm_layer -- normalization layer
1082
+ use_dropout (bool) -- if use dropout layers
1083
+ n_blocks (int) -- the number of ResNet blocks
1084
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
1085
+ """
1086
+ assert(n_blocks >= 0)
1087
+ super(ResnetEncoder, self).__init__()
1088
+ if type(norm_layer) == functools.partial:
1089
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1090
+ else:
1091
+ use_bias = norm_layer == nn.InstanceNorm2d
1092
+
1093
+ model = [nn.ReflectionPad2d(3),
1094
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
1095
+ norm_layer(ngf),
1096
+ nn.ReLU(True)]
1097
+
1098
+ n_downsampling = 2
1099
+ for i in range(n_downsampling): # add downsampling layers
1100
+ mult = 2 ** i
1101
+ if(no_antialias):
1102
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
1103
+ norm_layer(ngf * mult * 2),
1104
+ nn.ReLU(True)]
1105
+ else:
1106
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
1107
+ norm_layer(ngf * mult * 2),
1108
+ nn.ReLU(True),
1109
+ Downsample(ngf * mult * 2)]
1110
+
1111
+ mult = 2 ** n_downsampling
1112
+ for i in range(n_blocks): # add ResNet blocks
1113
+
1114
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1115
+
1116
+ self.model = nn.Sequential(*model)
1117
+
1118
+ def forward(self, input):
1119
+ """Standard forward"""
1120
+ return self.model(input)
1121
+
1122
+
1123
+ class ResnetBlock(nn.Module):
1124
+ """Define a Resnet block"""
1125
+
1126
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, opt=None):
1127
+ """Initialize the Resnet block
1128
+
1129
+ A resnet block is a conv block with skip connections
1130
+ We construct a conv block with build_conv_block function,
1131
+ and implement skip connections in <forward> function.
1132
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
1133
+ """
1134
+ super(ResnetBlock, self).__init__()
1135
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, opt)
1136
+
1137
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, opt=None):
1138
+ """Construct a convolutional block.
1139
+
1140
+ Parameters:
1141
+ dim (int) -- the number of channels in the conv layer.
1142
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
1143
+ norm_layer -- normalization layer
1144
+ use_dropout (bool) -- if use dropout layers.
1145
+ use_bias (bool) -- if the conv layer uses bias or not
1146
+
1147
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
1148
+ """
1149
+ conv_block = []
1150
+ p = 0
1151
+ if padding_type == 'reflect':
1152
+ conv_block += [nn.ReflectionPad2d(1)]
1153
+ elif padding_type == 'replicate':
1154
+ conv_block += [nn.ReplicationPad2d(1)]
1155
+ elif padding_type == 'zero':
1156
+ p = 1
1157
+ else:
1158
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
1159
+
1160
+ if opt.weight_norm == 'spectral':
1161
+ weight_norm = nn.utils.spectral_norm
1162
+ else:
1163
+ def weight_norm(x): return x
1164
+
1165
+ conv_block += [weight_norm(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)), norm_layer(dim), nn.ReLU(True)]
1166
+ if use_dropout:
1167
+ conv_block += [nn.Dropout(0.5)]
1168
+
1169
+ p = 0
1170
+ if padding_type == 'reflect':
1171
+ conv_block += [nn.ReflectionPad2d(1)]
1172
+ elif padding_type == 'replicate':
1173
+ conv_block += [nn.ReplicationPad2d(1)]
1174
+ elif padding_type == 'zero':
1175
+ p = 1
1176
+ else:
1177
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
1178
+ conv_block += [weight_norm(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)), norm_layer(dim)]
1179
+
1180
+ return nn.Sequential(*conv_block)
1181
+
1182
+ def forward(self, x):
1183
+ """Forward function (with skip connections)"""
1184
+ out = x + self.conv_block(x) # add skip connections
1185
+ return out
1186
+
1187
+
1188
+ class UnetGenerator(nn.Module):
1189
+ """Create a Unet-based generator"""
1190
+
1191
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
1192
+ """Construct a Unet generator
1193
+ Parameters:
1194
+ input_nc (int) -- the number of channels in input images
1195
+ output_nc (int) -- the number of channels in output images
1196
+ num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
1197
+ an image of size 128x128 will become of size 1x1 at the bottleneck
1198
+ ngf (int) -- the number of filters in the last conv layer
1199
+ norm_layer -- normalization layer
1200
+
1201
+ We construct the U-Net from the innermost layer to the outermost layer.
1202
+ It is a recursive process.
1203
+ """
1204
+ super(UnetGenerator, self).__init__()
1205
+ # construct unet structure
1206
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
1207
+ for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
1208
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
1209
+ # gradually reduce the number of filters from ngf * 8 to ngf
1210
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1211
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1212
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1213
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
1214
+
1215
+ def forward(self, input):
1216
+ """Standard forward"""
1217
+ return self.model(input)
1218
+
1219
+
1220
+ class UnetSkipConnectionBlock(nn.Module):
1221
+ """Defines the Unet submodule with skip connection.
1222
+ X -------------------identity----------------------
1223
+ |-- downsampling -- |submodule| -- upsampling --|
1224
+ """
1225
+
1226
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
1227
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
1228
+ """Construct a Unet submodule with skip connections.
1229
+
1230
+ Parameters:
1231
+ outer_nc (int) -- the number of filters in the outer conv layer
1232
+ inner_nc (int) -- the number of filters in the inner conv layer
1233
+ input_nc (int) -- the number of channels in input images/features
1234
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
1235
+ outermost (bool) -- if this module is the outermost module
1236
+ innermost (bool) -- if this module is the innermost module
1237
+ norm_layer -- normalization layer
1238
+ use_dropout (bool) -- if use dropout layers.
1239
+ """
1240
+ super(UnetSkipConnectionBlock, self).__init__()
1241
+ self.outermost = outermost
1242
+ if type(norm_layer) == functools.partial:
1243
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1244
+ else:
1245
+ use_bias = norm_layer == nn.InstanceNorm2d
1246
+ if input_nc is None:
1247
+ input_nc = outer_nc
1248
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
1249
+ stride=2, padding=1, bias=use_bias)
1250
+ downrelu = nn.LeakyReLU(0.2, True)
1251
+ downnorm = norm_layer(inner_nc)
1252
+ uprelu = nn.ReLU(True)
1253
+ upnorm = norm_layer(outer_nc)
1254
+
1255
+ if outermost:
1256
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1257
+ kernel_size=4, stride=2,
1258
+ padding=1)
1259
+ down = [downconv]
1260
+ up = [uprelu, upconv, nn.Tanh()]
1261
+ model = down + [submodule] + up
1262
+ elif innermost:
1263
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
1264
+ kernel_size=4, stride=2,
1265
+ padding=1, bias=use_bias)
1266
+ down = [downrelu, downconv]
1267
+ up = [uprelu, upconv, upnorm]
1268
+ model = down + up
1269
+ else:
1270
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1271
+ kernel_size=4, stride=2,
1272
+ padding=1, bias=use_bias)
1273
+ down = [downrelu, downconv, downnorm]
1274
+ up = [uprelu, upconv, upnorm]
1275
+
1276
+ if use_dropout:
1277
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
1278
+ else:
1279
+ model = down + [submodule] + up
1280
+
1281
+ self.model = nn.Sequential(*model)
1282
+
1283
+ def forward(self, x):
1284
+ if self.outermost:
1285
+ return self.model(x)
1286
+ else: # add skip connections
1287
+ return torch.cat([x, self.model(x)], 1)
1288
+
1289
+
1290
+ class NLayerDiscriminator(nn.Module):
1291
+ """Defines a PatchGAN discriminator"""
1292
+
1293
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False, opt=None):
1294
+ """Construct a PatchGAN discriminator
1295
+
1296
+ Parameters:
1297
+ input_nc (int) -- the number of channels in input images
1298
+ ndf (int) -- the number of filters in the last conv layer
1299
+ n_layers (int) -- the number of conv layers in the discriminator
1300
+ norm_layer -- normalization layer
1301
+ """
1302
+ super(NLayerDiscriminator, self).__init__()
1303
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
1304
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1305
+ else:
1306
+ use_bias = norm_layer == nn.InstanceNorm2d
1307
+
1308
+ if opt.weight_norm == 'spectral':
1309
+ weight_norm = nn.utils.spectral_norm
1310
+ else:
1311
+ def weight_norm(x): return x
1312
+
1313
+ kw = 4
1314
+ padw = 1
1315
+ if(no_antialias):
1316
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
1317
+ else:
1318
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True), Downsample(ndf)]
1319
+ nf_mult = 1
1320
+ nf_mult_prev = 1
1321
+ for n in range(1, n_layers): # gradually increase the number of filters
1322
+ nf_mult_prev = nf_mult
1323
+ nf_mult = min(2 ** n, 8)
1324
+ if(no_antialias):
1325
+ sequence += [
1326
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1327
+ norm_layer(ndf * nf_mult),
1328
+ nn.LeakyReLU(0.2, True)
1329
+ ]
1330
+ else:
1331
+ sequence += [
1332
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
1333
+ norm_layer(ndf * nf_mult),
1334
+ nn.LeakyReLU(0.2, True),
1335
+ Downsample(ndf * nf_mult)
1336
+ ]
1337
+
1338
+ nf_mult_prev = nf_mult
1339
+ nf_mult = min(2 ** n_layers, 8)
1340
+ sequence += [
1341
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
1342
+ norm_layer(ndf * nf_mult),
1343
+ nn.LeakyReLU(0.2, True)
1344
+ ]
1345
+
1346
+ for i, layer in enumerate(sequence):
1347
+ if isinstance(layer, nn.Conv2d):
1348
+ sequence[i] = weight_norm(layer)
1349
+
1350
+ self.enc = nn.Sequential(*sequence)
1351
+ # output 1 channel prediction map
1352
+ self.final_conv = weight_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw))
1353
+
1354
+
1355
+ def forward(self, input, labels=None):
1356
+ """Standard forward."""
1357
+ final_ft = self.enc(input)
1358
+ dout = self.final_conv(final_ft)
1359
+ return dout
1360
+
1361
+
1362
+ class PixelDiscriminator(nn.Module):
1363
+ """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
1364
+
1365
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
1366
+ """Construct a 1x1 PatchGAN discriminator
1367
+
1368
+ Parameters:
1369
+ input_nc (int) -- the number of channels in input images
1370
+ ndf (int) -- the number of filters in the last conv layer
1371
+ norm_layer -- normalization layer
1372
+ """
1373
+ super(PixelDiscriminator, self).__init__()
1374
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
1375
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1376
+ else:
1377
+ use_bias = norm_layer == nn.InstanceNorm2d
1378
+
1379
+ self.net = [
1380
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
1381
+ nn.LeakyReLU(0.2, True),
1382
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
1383
+ norm_layer(ndf * 2),
1384
+ nn.LeakyReLU(0.2, True),
1385
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
1386
+
1387
+ self.net = nn.Sequential(*self.net)
1388
+
1389
+ def forward(self, input):
1390
+ """Standard forward."""
1391
+ return self.net(input)
1392
+
1393
+
1394
+ class PatchDiscriminator(NLayerDiscriminator):
1395
+ """Defines a PatchGAN discriminator"""
1396
+
1397
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
1398
+ super().__init__(input_nc, ndf, 2, norm_layer, no_antialias)
1399
+
1400
+ def forward(self, input):
1401
+ B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
1402
+ size = 16
1403
+ Y = H // size
1404
+ X = W // size
1405
+ input = input.view(B, C, Y, size, X, size)
1406
+ input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
1407
+ return super().forward(input)
1408
+
1409
+
1410
+ class GroupedChannelNorm(nn.Module):
1411
+ def __init__(self, num_groups):
1412
+ super().__init__()
1413
+ self.num_groups = num_groups
1414
+
1415
+ def forward(self, x):
1416
+ shape = list(x.shape)
1417
+ new_shape = [shape[0], self.num_groups, shape[1] // self.num_groups] + shape[2:]
1418
+ x = x.view(*new_shape)
1419
+ mean = x.mean(dim=2, keepdim=True)
1420
+ std = x.std(dim=2, keepdim=True)
1421
+ x_norm = (x - mean) / (std + 1e-7)
1422
+ return x_norm.view(*shape)
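The encoders and the ResnetGenerator above all expose a feature-extraction path in forward (layers/encode_only), which is what the patch-based NCE losses in patchnce.py below consume. A hedged sketch of how that path is typically driven, assuming the helper layers defined earlier in this file (e.g. Downsample/Upsample) are available; the layer indices and option fields here are illustrative placeholders, not the values used in the released configs:

# Sketch only: dummy input, a minimal options namespace with the fields ResnetGenerator reads,
# and illustrative (not official) layer indices for feature extraction.
import argparse
import torch
import torch.nn as nn

opt = argparse.Namespace(weight_norm='none', n_downsampling=2)
netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, norm_layer=nn.InstanceNorm2d,
                       n_blocks=9, opt=opt)

x = torch.randn(1, 3, 256, 256)                          # dummy H&E patch
fake = netG(x)                                           # standard forward: translated image
feats = netG(x, layers=[0, 4, 8, 12], encode_only=True)  # intermediate feature maps only
print(fake.shape, [f.shape for f in feats])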
asp/models/patchnce.py ADDED
@@ -0,0 +1,55 @@
1
+ from packaging import version
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class PatchNCELoss(nn.Module):
7
+ def __init__(self, opt):
8
+ super().__init__()
9
+ self.opt = opt
10
+ self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
11
+ self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
12
+
13
+ def forward(self, feat_q, feat_k):
14
+ num_patches = feat_q.shape[0]
15
+ dim = feat_q.shape[1]
16
+ feat_k = feat_k.detach()
17
+
18
+ # pos logit
19
+ l_pos = torch.bmm(
20
+ feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
21
+ l_pos = l_pos.view(num_patches, 1)
22
+
23
+ # neg logit
24
+
25
+ # Should the negatives from the other samples of a minibatch be utilized?
26
+ # In CUT and FastCUT, we found that it's best to only include negatives
27
+ # from the same image. Therefore, we set
28
+ # --nce_includes_all_negatives_from_minibatch as False
29
+ # However, for single-image translation, the minibatch consists of
30
+ # crops from the "same" high-resolution image.
31
+ # Therefore, we will include the negatives from the entire minibatch.
32
+ if self.opt.nce_includes_all_negatives_from_minibatch:
33
+ # reshape features as if they are all negatives of minibatch of size 1.
34
+ batch_dim_for_bmm = 1
35
+ else:
36
+ batch_dim_for_bmm = self.opt.batch_size
37
+
38
+ # reshape features to batch size
39
+ feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
40
+ feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
41
+ npatches = feat_q.size(1)
42
+ l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
43
+
44
+ # diagonal entries are similarity between same features, and hence meaningless.
45
+ # fill the diagonal with a large negative logit (-10): after softmax, exp(-10) is effectively zero
46
+ diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
47
+ l_neg_curbatch.masked_fill_(diagonal, -10.0)
48
+ l_neg = l_neg_curbatch.view(-1, npatches)
49
+
50
+ out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
51
+
52
+ loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
53
+ device=feat_q.device))
54
+
55
+ return loss
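For orientation, a hedged usage sketch of PatchNCELoss: it takes flattened patch features for the query (generated image) and key (target image) and returns one loss value per patch. The option fields below are exactly the attributes read in forward; the feature sizes are dummy values.

# Sketch only: dummy patch features and a minimal options namespace
# with the attributes PatchNCELoss actually reads.
import argparse
import torch

opt = argparse.Namespace(nce_T=0.07, batch_size=1,
                         nce_includes_all_negatives_from_minibatch=False)
criterion = PatchNCELoss(opt)

num_patches, dim = 256, 256
feat_q = torch.randn(num_patches, dim)   # features of generated-image patches
feat_k = torch.randn(num_patches, dim)   # features of the corresponding target patches
loss = criterion(feat_q, feat_k).mean()  # per-patch losses, reduced to a scalar here
print(loss.item())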
asp/options/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """This package includes the option modules: training options, test options, and basic options (used in both training and test)."""
asp/options/base_options.py ADDED
@@ -0,0 +1,167 @@
1
+ import argparse
2
+ import os
3
+ from util import util
4
+ import torch
5
+ import models
6
+ import data
7
+
8
+
9
+ class BaseOptions():
10
+ """This class defines options used during both training and test time.
11
+
12
+ It also implements several helper functions such as parsing, printing, and saving the options.
13
+ It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
14
+ """
15
+
16
+ def __init__(self, cmd_line=None):
17
+ """Reset the class; indicates the class hasn't been initialized"""
18
+ self.initialized = False
19
+ self.cmd_line = None
20
+ if cmd_line is not None:
21
+ self.cmd_line = cmd_line.split()
22
+
23
+ def initialize(self, parser):
24
+ """Define the common options that are used in both training and test."""
25
+ # basic parameters
26
+ parser.add_argument('--dataroot', default='placeholder', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
27
+ parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
28
+ parser.add_argument('--easy_label', type=str, default='experiment_name', help='Interpretable name')
29
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
30
+ parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
31
+ # model parameters
32
+ parser.add_argument('--model', type=str, default='cut', help='chooses which model to use.')
33
+ parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
34
+ parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
35
+ parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
36
+ parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
37
+ parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2', 'multi_d'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
38
+ parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks', 'resnet_6blocks', 'resnet_4blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'fdlresnet', 'fdlunet'], help='specify generator architecture')
39
+ parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
40
+ parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
41
+ parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
42
+ parser.add_argument('--init_type', type=str, default='xavier', choices=['normal', 'xavier', 'kaiming', 'orthogonal'], help='network initialization')
43
+ parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
44
+ parser.add_argument('--no_dropout', type=util.str2bool, nargs='?', const=True, default=True,
45
+ help='no dropout for the generator')
46
+ parser.add_argument('--no_antialias', action='store_true', help='if specified, use stride=2 convs instead of antialiased-downsampling (sad)')
47
+ parser.add_argument('--no_antialias_up', action='store_true', help='if specified, use [upconv(learned filter)] instead of [upconv(hard-coded [1,3,3,1] filter), conv]')
48
+ # dataset parameters
49
+ parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
50
+ parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
51
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
52
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
53
+ parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
54
+ parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
55
+ parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
56
+ parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
57
+ parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
58
+ parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
59
+ parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
60
+ parser.add_argument('--random_scale_max', type=float, default=3.0,
61
+ help='(used for single image translation) Randomly scale the image by the specified factor as data augmentation.')
62
+ # additional parameters
63
+ parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
64
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
65
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
66
+
67
+ # parameters related to StyleGAN2-based networks
68
+ parser.add_argument('--stylegan2_G_num_downsampling',
69
+ default=1, type=int,
70
+ help='Number of downsampling layers used by StyleGAN2Generator')
71
+
72
+ # FDL:
73
+ parser.add_argument('--weight_norm', type=str, default='none', choices=['none', 'spectral'], help='chooses which weight norm layer to use.')
74
+
75
+ self.initialized = True
76
+ return parser
77
+
78
+ def gather_options(self):
79
+ """Initialize our parser with basic options(only once).
80
+ Add additional model-specific and dataset-specific options.
81
+ These options are defined in the <modify_commandline_options> function
82
+ in model and dataset classes.
83
+ """
84
+ if not self.initialized: # check if it has been initialized
85
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
86
+ parser = self.initialize(parser)
87
+
88
+ # get the basic options
89
+ if self.cmd_line is None:
90
+ opt, _ = parser.parse_known_args()
91
+ else:
92
+ opt, _ = parser.parse_known_args(self.cmd_line)
93
+
94
+ # modify model-related parser options
95
+ model_name = opt.model
96
+ model_option_setter = models.get_option_setter(model_name)
97
+ parser = model_option_setter(parser, self.isTrain)
98
+ if self.cmd_line is None:
99
+ opt, _ = parser.parse_known_args() # parse again with new defaults
100
+ else:
101
+ opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
102
+
103
+ # modify dataset-related parser options
104
+ dataset_name = opt.dataset_mode
105
+ dataset_option_setter = data.get_option_setter(dataset_name)
106
+ parser = dataset_option_setter(parser, self.isTrain)
107
+
108
+ # save and return the parser
109
+ self.parser = parser
110
+ if self.cmd_line is None:
111
+ return parser.parse_args()
112
+ else:
113
+ return parser.parse_args(self.cmd_line)
114
+
115
+ def print_options(self, opt):
116
+ """Print and save options
117
+
118
+ It will print both current options and default values(if different).
119
+ It will save options into a text file / [checkpoints_dir] / opt.txt
120
+ """
121
+ message = ''
122
+ message += '----------------- Options ---------------\n'
123
+ for k, v in sorted(vars(opt).items()):
124
+ comment = ''
125
+ default = self.parser.get_default(k)
126
+ if v != default:
127
+ comment = '\t[default: %s]' % str(default)
128
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
129
+ message += '----------------- End -------------------'
130
+ print(message)
131
+
132
+ # save to the disk
133
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
134
+ util.mkdirs(expr_dir)
135
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
136
+ try:
137
+ with open(file_name, 'wt') as opt_file:
138
+ opt_file.write(message)
139
+ opt_file.write('\n')
140
+ except PermissionError as error:
141
+ print("permission error {}".format(error))
142
+ pass
143
+
144
+ def parse(self):
145
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
146
+ opt = self.gather_options()
147
+ opt.isTrain = self.isTrain # train or test
148
+
149
+ # process opt.suffix
150
+ if opt.suffix:
151
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
152
+ opt.name = opt.name + suffix
153
+
154
+ self.print_options(opt)
155
+
156
+ # set gpu ids
157
+ str_ids = opt.gpu_ids.split(',')
158
+ opt.gpu_ids = []
159
+ for str_id in str_ids:
160
+ id = int(str_id)
161
+ if id >= 0:
162
+ opt.gpu_ids.append(id)
163
+ if len(opt.gpu_ids) > 0:
164
+ torch.cuda.set_device(opt.gpu_ids[0])
165
+
166
+ self.opt = opt
167
+ return self.opt
asp/options/test_options.py ADDED
@@ -0,0 +1,21 @@
1
+ from .base_options import BaseOptions
2
+
3
+
4
+ class TestOptions(BaseOptions):
5
+ """This class includes test options.
6
+
7
+ It also includes shared options defined in BaseOptions.
8
+ """
9
+
10
+ def initialize(self, parser):
11
+ parser = BaseOptions.initialize(self, parser) # define shared options
12
+ parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
13
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
14
+ # Dropout and BatchNorm have different behavior during training and test.
15
+ parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
16
+ parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
17
+
18
+ # To avoid cropping, the load_size should be the same as crop_size
19
+ parser.set_defaults(load_size=parser.get_default('crop_size'))
20
+ self.isTrain = False
21
+ return parser
asp/options/train_options.py ADDED
@@ -0,0 +1,44 @@
1
+ from .base_options import BaseOptions
2
+
3
+
4
+ class TrainOptions(BaseOptions):
5
+ """This class includes training options.
6
+
7
+ It also includes shared options defined in BaseOptions.
8
+ """
9
+
10
+ def initialize(self, parser):
11
+ parser = BaseOptions.initialize(self, parser)
12
+ # visdom and HTML visualization parameters
13
+ parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
14
+ parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
15
+ parser.add_argument('--display_id', type=int, default=None, help='window id of the web display. Default is random window id')
16
+ parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
17
+ parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
18
+ parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
19
+ parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
20
+ parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
21
+ parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
22
+ # network saving and loading parameters
23
+ parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
24
+ parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
25
+ parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
26
+ parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
27
+ parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
28
+ parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
29
+ parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
30
+ parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
31
+
32
+ # training parameters
33
+ parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs with the initial learning rate')
34
+ parser.add_argument('--n_epochs_decay', type=int, default=200, help='number of epochs to linearly decay learning rate to zero')
35
+ parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
36
+ parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
37
+ parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
38
+ parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
39
+ parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
40
+ parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
41
+ parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
42
+
43
+ self.isTrain = True
44
+ return parser
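These option classes follow the usual CUT/CycleGAN entry-point pattern: a script instantiates TrainOptions or TestOptions, calls parse(), and gets back a namespace enriched with the model- and dataset-specific flags injected via modify_commandline_options. A hedged sketch of that usage (the flag values are placeholders, and it assumes the top-level imports in base_options.py, i.e. util, models and data, resolve in the runtime environment):

# Sketch only: typical entry-point usage; the dataroot and experiment name are placeholders.
from asp.options.train_options import TrainOptions

# e.g. python train.py --dataroot ./datasets/he2ihc --name asp_experiment --gpu_ids 0
opt = TrainOptions().parse()          # parses sys.argv, prints and saves the options
print(opt.name, opt.lr, opt.gan_mode, opt.isTrain)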
asp/util/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ """This package includes a miscellaneous collection of useful helper functions."""
2
+ from asp.util import *
asp/util/fdlutil.py ADDED
@@ -0,0 +1,422 @@
1
+ import importlib.util
2
+ import os
3
+ import sys
4
+ from pylab import *
5
+ import matplotlib as mpl
6
+
7
+ # Use tkAgg when plotting to a window, Agg when to a file
8
+ # #### mpl.use('TkAgg') # Don't use this unless it's an emergency; more trouble than it's worth
9
+ mpl.use('Agg')
10
+
11
+
12
+ def quick_imshow(nrows, ncols=1, images=None, titles=None, colorbar=True, colormap='jet',
13
+ vmax=None, vmin=None, figsize=None, figtitle=None, visibleaxis=True,
14
+ saveas='/home/ubuntu/tempimshow.png', tight=False, dpi=250.0):
15
+ """-------------------------------------------------------------------------
16
+ Desc.: convenience function that makes subplots of imshow
17
+ Args.: nrows - number of rows
18
+ ncols - number of cols
19
+ images - list of images
20
+ titles - list of titles
21
+ vmax - tuple of vmax for the colormap. If scalar,
22
+ the same value is used for all subplots. If one
23
+ of the entries is None, no colormap for that
24
+ subplot will be drawn.
25
+ vmin - tuple of vmin
26
+ Returns: f - the figure handle
27
+ axes - axes or array of axes objects
28
+ caxes - tuple of axes image
29
+ -------------------------------------------------------------------------"""
30
+ if isinstance(nrows, np.ndarray):
31
+ images = nrows
32
+ nrows = 1
33
+ ncols = 1
34
+
35
+ if figsize == None:
36
+ # 1.0 translates to 100 pixels of the figure
37
+ s = 5.0
38
+ if figtitle:
39
+ figsize = (s * ncols, s * nrows + 0.5)
40
+ else:
41
+ figsize = (s * ncols, s * nrows)
42
+
43
+ if nrows == ncols == 1:
44
+ if isinstance(images, list):
45
+ images = images[0]
46
+ f, ax = plt.subplots(figsize=figsize)
47
+ cax = ax.imshow(images, cmap=colormap, vmax=vmax, vmin=vmin)
48
+ if colorbar:
49
+ f.colorbar(cax, ax=ax)
50
+ if titles != None:
51
+ ax.set_title(titles)
52
+ if figtitle != None:
53
+ f.suptitle(figtitle)
54
+ cax.axes.get_xaxis().set_visible(visibleaxis)
55
+ cax.axes.get_yaxis().set_visible(visibleaxis)
56
+ if tight:
57
+ plt.tight_layout()
58
+ if len(saveas) > 0:
59
+ dirname = os.path.dirname(saveas)
60
+ if not os.path.exists(dirname):
61
+ os.makedirs(dirname)
62
+ plt.savefig(saveas)
63
+ return f, ax, cax
64
+
65
+ f, axes = plt.subplots(nrows, ncols, figsize=figsize, dpi=dpi)
66
+ caxes = []
67
+ i = 0
68
+ for ax, img in zip(axes.flat, images):
69
+ if isinstance(vmax, tuple) and isinstance(vmin, tuple):
70
+ if vmax[i] is not None and vmin[i] is not None:
71
+ cax = ax.imshow(img, cmap=colormap, vmax=vmax[i], vmin=vmin[i])
72
+ else:
73
+ cax = ax.imshow(img, cmap=colormap)
74
+ elif isinstance(vmax, tuple) and vmin is None:
75
+ if vmax[i] is not None:
76
+ cax = ax.imshow(img, cmap=colormap, vmax=vmax[i], vmin=0)
77
+ else:
78
+ cax = ax.imshow(img, cmap=colormap)
79
+ elif vmax is None and vmin is None:
80
+ cax = ax.imshow(img, cmap=colormap)
81
+ else:
82
+ cax = ax.imshow(img, cmap=colormap, vmax=vmax, vmin=vmin)
83
+ if titles != None:
84
+ ax.set_title(titles[i])
85
+ if colorbar:
86
+ f.colorbar(cax, ax=ax)
87
+ caxes.append(cax)
88
+ cax.axes.get_xaxis().set_visible(visibleaxis)
89
+ cax.axes.get_yaxis().set_visible(visibleaxis)
90
+ i = i + 1
91
+ if figtitle != None:
92
+ f.suptitle(figtitle)
93
+ if tight:
94
+ plt.tight_layout()
95
+ if len(saveas) > 0:
96
+ dirname = os.path.dirname(saveas)
97
+ if not os.path.exists(dirname):
98
+ os.makedirs(dirname)
99
+ plt.savefig(saveas)
100
+ return f, axes, tuple(caxes)
101
+
102
+
103
+ def update_subplots(images, caxes, f=None, axes=None, indices=(), vmax=None,
104
+ vmin=None):
105
+ """-------------------------------------------------------------------------
106
+ Desc.: update subplots in a figure
107
+ Args.: images - new images to plot
108
+ caxes - caxes returned at figure creation
109
+ indices - specific indices of subplots to be updated
110
+ Returns:
111
+ -------------------------------------------------------------------------"""
112
+ for i in range(len(images)):
113
+ if len(indices) > 0:
114
+ ind = indices[i]
115
+ else:
116
+ ind = i
117
+ img = images[i]
118
+ caxes[ind].set_data(img)
119
+ cbar = caxes[ind].colorbar
120
+ if isinstance(vmax, tuple) and isinstance(vmin, tuple):
121
+ if vmax[i] is not None and vmin[i] is not None:
122
+ cbar.set_clim([vmin[i], vmax[i]])
123
+ else:
124
+ cbar.set_clim([img.min(), img.max()])
125
+ elif isinstance(vmax, tuple) and vmin is None:
126
+ if vmax[i] is not None:
127
+ cbar.set_clim([0, vmax[i]])
128
+ else:
129
+ cbar.set_clim([img.min(), img.max()])
130
+ elif vmax is None and vmin is None:
131
+ cbar.set_clim([img.min(), img.max()])
132
+ else:
133
+ cbar.set_clim([vmin, vmax])
134
+ cbar.update_normal(caxes[ind])
135
+ pause(0.01)
136
+ tight_layout()
137
+
138
+
139
+ def slide_show(image, dt=0.01, vmax=None, vmin=None):
140
+ """
141
+ Slide show for visualizing an image volume. Image is (w, h, d)
142
+ :param image: (w, h, d), slides are 2D images along the depth axis
143
+ :param dt:
144
+ :param vmax:
145
+ :param vmin:
146
+ :return:
147
+ """
148
+ if image.dtype == bool:
149
+ image = image * 1.0  # cast bool volumes to float (in-place *= would fail on a bool array)
150
+ if vmax is None:
151
+ vmax = image.max()
152
+ if vmin is None:
153
+ vmin = image.min()
154
+ plt.ion()
155
+ plt.figure()
156
+ for i in range(image.shape[2]):
157
+ plt.cla()
158
+ cax = plt.imshow(image[:, :, i], cmap='jet', vmin=vmin, vmax=vmax)
159
+ plt.title(str('Slice: %i/%i' % (i, image.shape[2] - 1)))
160
+ if i == 0:
161
+ cf = plt.gcf()
162
+ ca = plt.gca()
163
+ cf.colorbar(cax, ax=ca)
164
+ plt.pause(dt)
165
+ plt.draw()
166
+
167
+
168
+ def quick_collage(images, nrows=3, ncols=2, normalize=False, figsize=(20.0, 10.0), figtitle=None, colorbar=True,
169
+ tight=True, saveas='/home/ubuntu/tempcollage.png'):
170
+ def zero_to_one(x):
171
+ if x.min() == x.max():
172
+ return x - x.min()
173
+ return (x.astype(float) - x.min()) / (x.max() - x.min())
174
+ # Normalize every image
175
+ if isinstance(images, np.ndarray):
176
+ images = [images]
177
+ # Check the shape and make sure everything is float
178
+ img_shp = images[0].shape
179
+ if normalize:
180
+ images = [zero_to_one(image) for image in images]
181
+ vmax, vmin = 1.0, 0.0
182
+ else:
183
+ vmax, vmin = max([img.max() for img in images]), min(
184
+ [img.min() for img in images])
185
+ # Highlight the boundaries
186
+ for i in range(0, len(images) - 1):
187
+ images[i] = np.hstack(
188
+ [images[i], np.full((img_shp[0], 1, img_shp[2]), np.nan)])
189
+ collage = np.hstack(images)
190
+ # Determine slice depth
191
+ depth = collage.shape[2]
192
+ n_slices = nrows * ncols
193
+ z = [int(depth / (n_slices + 1) * i - 1) for i in range(1, (n_slices + 1))]
194
+ titles = ['Slice %d/%d' % (i, depth) for i in z]
195
+ quick_imshow(
196
+ nrows, ncols,
197
+ [collage[:, :, z[i]] for i in range(n_slices)],
198
+ titles=titles,
199
+ figtitle=figtitle,
200
+ figsize=figsize,
201
+ vmax=vmax, vmin=vmin,
202
+ colorbar=colorbar, tight=tight)
203
+ if len(saveas) > 0:
204
+ plt.savefig(saveas)
205
+ plt.close()
206
+
207
+
208
+ def quick_plot(x_data, y_data=None, fmt='', color=None, xlim=None, ylim=None,
209
+ label='', legends=False, x_label='', y_label='', figtitle='', annotation=None, figsize=(20, 10),
210
+ f=None, ax=None, saveas=''):
211
+ if f is None or ax is None:
212
+ f, ax = subplots(figsize=figsize)
213
+ if y_data is None:
214
+ temp = x_data
215
+ x_data = list(range(len(temp)))
216
+ y_data = temp
217
+ ax.plot(x_data, y_data, fmt, label=label, color=color)
218
+ if xlim is not None:
219
+ ax.set_xlim(xlim)
220
+ if ylim is not None:
221
+ ax.set_ylim(ylim)
222
+ if annotation is not None:
223
+ for i in range(len(x_data)):
224
+ annotate(annotation[i], (x_data[i], y_data[i]),
225
+ textcoords='offset points', xytext=(0, 10), ha='center')
226
+ if len(x_label) > 0:
227
+ ax.set_xlabel(x_label)
228
+ if len(y_label) > 0:
229
+ ax.set_ylabel(y_label)
230
+ if len(figtitle) > 0:
231
+ f.suptitle(figtitle)
232
+ if legends:
233
+ ax.legend(loc='center left', bbox_to_anchor=(1.04, 0.5))
234
+ ax.grid()
235
+ if len(saveas) > 0:
236
+ f.savefig(saveas, bbox_inches='tight')
237
+ ax.grid()
238
+ return f, ax
239
+
240
+
241
+ def quick_scatter(x_data, y_data=None, xlim=None, ylim=None,
242
+ label='', legends=False, x_label='', y_label='', figtitle='', annotation=None,
243
+ f=None, ax=None, saveas=''):
244
+ if f is None or ax is None:
245
+ f, ax = subplots()
246
+ if y_data is None:
247
+ temp = x_data
248
+ x_data = list(range(len(temp)))
249
+ y_data = temp
250
+ ax.scatter(x_data, y_data, label=label)
251
+ if xlim is not None:
252
+ ax.set_xlim(xlim)
253
+ if ylim is not None:
254
+ ax.set_ylim(ylim)
255
+ if annotation is not None:
256
+ for i in range(len(x_data)):
257
+ annotate(annotation[i], (x_data[i], y_data[i]),
258
+ textcoords='offset points', xytext=(0, 10), ha='center')
259
+ if len(x_label) > 0:
260
+ ax.set_xlabel(x_label)
261
+ if len(y_label) > 0:
262
+ ax.set_ylabel(y_label)
263
+ if len(figtitle) > 0:
264
+ f.suptitle(figtitle)
265
+ if legends:
266
+ ax.legend()
267
+ ax.grid()
268
+ if len(saveas) > 0:
269
+ f.savefig(saveas)
270
+ return f, ax
271
+
272
+
273
+ def quick_load(file_path, fits_field=1):
274
+ if file_path.endswith('npz'):
275
+ with load(file_path, allow_pickle=True) as f:
276
+ data = f['arr_0']
277
+ # Take care of the case where a dictionary is saved in npz format
278
+ if isinstance(data, ndarray) and data.dtype == 'O':
279
+ data = data.flatten()[0]
280
+ # elif file_path.endswith(('pyc', 'pickle')):
281
+ # data = pickle_load(file_path)
282
+ # elif file_path.endswith('fits.gz'):
283
+ # data = read_fits_data(file_path, fits_field)
284
+ # elif file_path.endswith('h5'):
285
+ # data = read_hdf5_data(file_path)
286
+ else:
287
+ raise NotImplementedError(
288
+ "Only npz, pyc, h5 and fits.gz are supported!")
289
+ return data
290
+
291
+
292
+ def quick_save(file_path, data):
293
+ dir_name = os.path.dirname(file_path)
294
+ if not os.path.exists(dir_name):
295
+ os.makedirs(dir_name)
296
+ # For better disk utilization and compatibility with fits, use int32
297
+ if file_path.endswith('npz'):
298
+ savez_compressed(file_path, data)
299
+ # elif file_path.endswith(('pyc', 'pickle')):
300
+ # save_object(file_path, data)
301
+ # elif file_path.endswith('fits.gz'):
302
+ # if isinstance(data, ndarray) and data.dtype == int:
303
+ # data = data.astype(int32)
304
+ # save_fits_data(file_path, data)
305
+ # elif file_path.endswith('h5'):
306
+ # write_hdf5_data(file_path, data)
307
+ else:
308
+ raise NotImplementedError(
309
+ "Only npz, pyc, h5 and fits.gz are supported!")
310
+
311
+
312
+ def import_module(name, path):
313
+ """
314
+ correct way of importing a module dynamically in python 3.
315
+ :param name: name given to module instance.
316
+ :param path: path to module.
317
+ :return: module: returned module instance.
318
+ """
319
+ spec = importlib.util.spec_from_file_location(name, path)
320
+ module = importlib.util.module_from_spec(spec)
321
+ spec.loader.exec_module(module)
322
+ return module
323
+
324
+
325
+ def obj_from_dict(info, parent=None, default_args=None):
326
+ """Initialize an object from dict.
327
+ The dict must contain the key "type", which indicates the object type; it
328
+ can be either a string or type, such as "list" or ``list``. Remaining
329
+ fields are treated as the arguments for constructing the object.
330
+ Args:
331
+ info (dict): Object types and arguments.
332
+ parent (:class:`module`): Module which may contain expected object
333
+ classes.
334
+ default_args (dict, optional): Default arguments for initializing the
335
+ object.
336
+ Returns:
337
+ any type: Object built from the dict.
338
+ """
339
+ assert isinstance(info, dict) and 'type' in info
340
+ assert isinstance(default_args, dict) or default_args is None
341
+ args = info.copy()
342
+ obj_type = args.pop('type')
343
+ if isinstance(obj_type, str):
344
+ if parent is not None:
345
+ obj_type = getattr(parent, obj_type)
346
+ else:
347
+ obj_type = sys.modules[obj_type]
348
+ elif not isinstance(obj_type, type):
349
+ raise TypeError('type must be a str or valid type, but '
350
+ f'got {type(obj_type)}')
351
+ if default_args is not None:
352
+ for name, value in default_args.items():
353
+ args.setdefault(name, value)
354
+ return obj_type(**args)
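A usage sketch for obj_from_dict (not part of the commit; torch.optim is used purely as an illustrative parent module and the hyperparameters are hypothetical):

```python
import torch
import torch.optim

model = torch.nn.Linear(4, 2)
cfg = dict(type='SGD', lr=0.01, momentum=0.9)
# 'type' selects the class from the parent module; the remaining keys become constructor kwargs.
optimizer = obj_from_dict(cfg, parent=torch.optim,
                          default_args=dict(params=model.parameters()))
print(optimizer)  # SGD with lr=0.01, momentum=0.9
```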
355
+
356
+
357
+ def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
358
+ """
359
+ one padder to pad them all. Documentation? Well okay. A little bit. by Fabian Isensee
360
+ :param image: nd image. can be anything
361
+ :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If
362
+ len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of
363
+ the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape)
364
+ Example:
365
+ image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh?
366
+ image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768).
367
+ :param mode: see np.pad for documentation
368
+ :param return_slicer: if True then this function will also return what coords you will need to use when cropping back
369
+ to original shape
370
+ :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is
371
+ divisible by that number (can also be a list with an entry for each axis). Whatever is missing to match that will
372
+ be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None)
373
+ :param kwargs: see np.pad for documentation
374
+ """
375
+ if kwargs is None:
376
+ kwargs = {}
377
+
378
+ if new_shape is not None:
379
+ old_shape = np.array(image.shape[-len(new_shape):])
380
+ else:
381
+ assert shape_must_be_divisible_by is not None
382
+ assert isinstance(shape_must_be_divisible_by,
383
+ (list, tuple, np.ndarray))
384
+ new_shape = image.shape[-len(shape_must_be_divisible_by):]
385
+ old_shape = new_shape
386
+
387
+ num_axes_nopad = len(image.shape) - len(new_shape)
388
+
389
+ new_shape = [max(new_shape[i], old_shape[i])
390
+ for i in range(len(new_shape))]
391
+
392
+ if not isinstance(new_shape, np.ndarray):
393
+ new_shape = np.array(new_shape)
394
+
395
+ if shape_must_be_divisible_by is not None:
396
+ if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
397
+ shape_must_be_divisible_by = [
398
+ shape_must_be_divisible_by] * len(new_shape)
399
+ else:
400
+ assert len(shape_must_be_divisible_by) == len(new_shape)
401
+
402
+ for i in range(len(new_shape)):
403
+ if new_shape[i] % shape_must_be_divisible_by[i] == 0:
404
+ new_shape[i] -= shape_must_be_divisible_by[i]
405
+
406
+ new_shape = np.array(
407
+ [new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in
408
+ range(len(new_shape))])
409
+
410
+ difference = new_shape - old_shape
411
+ pad_below = difference // 2
412
+ pad_above = difference // 2 + difference % 2
413
+ pad_list = [[0, 0]] * num_axes_nopad + \
414
+ list([list(i) for i in zip(pad_below, pad_above)])
415
+ res = np.pad(image, pad_list, mode, **kwargs)
416
+ if not return_slicer:
417
+ return res
418
+ else:
419
+ pad_list = np.array(pad_list)
420
+ pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1]
421
+ slicer = list(slice(*i) for i in pad_list)
422
+ return res, slicer
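A quick sketch of a typical call (not part of the commit): padding so that each spatial axis is divisible by 32, then cropping back with the returned slicer.

```python
import numpy as np

img = np.random.rand(1, 1, 500, 500)
padded, slicer = pad_nd_image(img, shape_must_be_divisible_by=[32, 32],
                              return_slicer=True)
print(padded.shape)                  # (1, 1, 512, 512)
restored = padded[tuple(slicer)]     # crop back to the original shape
assert restored.shape == img.shape
```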
asp/util/fid.py ADDED
@@ -0,0 +1,288 @@
1
+ """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2
+
3
+ The FID metric calculates the distance between two distributions of images.
4
+ Typically, we have summary statistics (mean & covariance matrix) of one
5
+ of these distributions, while the 2nd distribution is given by a GAN.
6
+
7
+ When run as a stand-alone program, it compares the distribution of
8
+ images that are stored as PNG/JPEG at a specified location with a
9
+ distribution given by summary statistics (in pickle format).
10
+
11
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
12
+ the pool_3 layer of the inception net for generated samples and real world
13
+ samples respectively.
14
+
15
+ See --help to see further details.
16
+
17
+ Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18
+ of Tensorflow
19
+
20
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
21
+
22
+ Licensed under the Apache License, Version 2.0 (the "License");
23
+ you may not use this file except in compliance with the License.
24
+ You may obtain a copy of the License at
25
+
26
+ http://www.apache.org/licenses/LICENSE-2.0
27
+
28
+ Unless required by applicable law or agreed to in writing, software
29
+ distributed under the License is distributed on an "AS IS" BASIS,
30
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ See the License for the specific language governing permissions and
32
+ limitations under the License.
33
+ """
34
+ import os
35
+ import pathlib
36
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
37
+
38
+ import numpy as np
39
+ import torch
40
+ import torchvision.transforms as TF
41
+ from PIL import Image
42
+ from scipy import linalg
43
+ from torch.nn.functional import adaptive_avg_pool2d
44
+
45
+ try:
46
+ from tqdm import tqdm
47
+ except ImportError:
48
+ # If tqdm is not available, provide a mock version of it
49
+ def tqdm(x):
50
+ return x
51
+
52
+ from util.inception import InceptionV3
53
+
54
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
55
+ parser.add_argument('--batch-size', type=int, default=50,
56
+ help='Batch size to use')
57
+ parser.add_argument('--num-workers', type=int,
58
+ help=('Number of processes to use for data loading. '
59
+ 'Defaults to `min(8, num_cpus)`'))
60
+ parser.add_argument('--device', type=str, default=None,
61
+ help='Device to use. Like cuda, cuda:0 or cpu')
62
+ parser.add_argument('--dims', type=int, default=2048,
63
+ choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
64
+ help=('Dimensionality of Inception features to use. '
65
+ 'By default, uses pool3 features'))
66
+ parser.add_argument('path', type=str, nargs=2,
67
+ help=('Paths to the generated images or '
68
+ 'to .npz statistic files'))
69
+
70
+ IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
71
+ 'tif', 'tiff', 'webp'}
72
+
73
+
74
+ class ImagePathDataset(torch.utils.data.Dataset):
75
+ def __init__(self, files, transforms=None):
76
+ self.files = files
77
+ self.transforms = transforms
78
+
79
+ def __len__(self):
80
+ return len(self.files)
81
+
82
+ def __getitem__(self, i):
83
+ path = self.files[i]
84
+ img = Image.open(path).convert('RGB')
85
+ if self.transforms is not None:
86
+ img = self.transforms(img)
87
+ return img
88
+
89
+
90
+ def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
91
+ num_workers=1):
92
+ """Calculates the activations of the pool_3 layer for all images.
93
+
94
+ Params:
95
+ -- files : List of image files paths
96
+ -- model : Instance of inception model
97
+ -- batch_size : Batch size of images for the model to process at once.
98
+ Make sure that the number of samples is a multiple of
99
+ the batch size, otherwise some samples are ignored. This
100
+ behavior is retained to match the original FID score
101
+ implementation.
102
+ -- dims : Dimensionality of features returned by Inception
103
+ -- device : Device to run calculations
104
+ -- num_workers : Number of parallel dataloader workers
105
+
106
+ Returns:
107
+ -- A numpy array of dimension (num images, dims) that contains the
108
+ activations of the given tensor when feeding inception with the
109
+ query tensor.
110
+ """
111
+ model.eval()
112
+
113
+ if batch_size > len(files):
114
+ print(('Warning: batch size is bigger than the data size. '
115
+ 'Setting batch size to data size'))
116
+ batch_size = len(files)
117
+
118
+ dataset = ImagePathDataset(files, transforms=TF.ToTensor())
119
+ dataloader = torch.utils.data.DataLoader(dataset,
120
+ batch_size=batch_size,
121
+ shuffle=False,
122
+ drop_last=False,
123
+ num_workers=num_workers)
124
+
125
+ pred_arr = np.empty((len(files), dims))
126
+
127
+ start_idx = 0
128
+
129
+ for batch in tqdm(dataloader):
130
+ batch = batch.to(device)
131
+
132
+ with torch.no_grad():
133
+ pred = model(batch)[0]
134
+
135
+ # If model output is not scalar, apply global spatial average pooling.
136
+ # This happens if you choose a dimensionality not equal 2048.
137
+ if pred.size(2) != 1 or pred.size(3) != 1:
138
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
139
+
140
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
141
+
142
+ pred_arr[start_idx:start_idx + pred.shape[0]] = pred
143
+
144
+ start_idx = start_idx + pred.shape[0]
145
+
146
+ return pred_arr
147
+
148
+
149
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
150
+ """Numpy implementation of the Frechet Distance.
151
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
152
+ and X_2 ~ N(mu_2, C_2) is
153
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
154
+
155
+ Stable version by Dougal J. Sutherland.
156
+
157
+ Params:
158
+ -- mu1 : Numpy array containing the activations of a layer of the
159
+ inception net (as returned by the function 'get_activations')
160
+ for generated samples.
161
+ -- mu2 : The sample mean over activations, precalculated on a
162
+ representative data set.
163
+ -- sigma1: The covariance matrix over activations for generated samples.
164
+ -- sigma2: The covariance matrix over activations, precalculated on an
165
+ representative data set.
166
+
167
+ Returns:
168
+ -- : The Frechet Distance.
169
+ """
170
+
171
+ mu1 = np.atleast_1d(mu1)
172
+ mu2 = np.atleast_1d(mu2)
173
+
174
+ sigma1 = np.atleast_2d(sigma1)
175
+ sigma2 = np.atleast_2d(sigma2)
176
+
177
+ assert mu1.shape == mu2.shape, \
178
+ 'Training and test mean vectors have different lengths'
179
+ assert sigma1.shape == sigma2.shape, \
180
+ 'Training and test covariances have different dimensions'
181
+
182
+ diff = mu1 - mu2
183
+
184
+ # Product might be almost singular
185
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
186
+ if not np.isfinite(covmean).all():
187
+ msg = ('fid calculation produces singular product; '
188
+ 'adding %s to diagonal of cov estimates') % eps
189
+ print(msg)
190
+ offset = np.eye(sigma1.shape[0]) * eps
191
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
192
+
193
+ # Numerical error might give slight imaginary component
194
+ if np.iscomplexobj(covmean):
195
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
196
+ m = np.max(np.abs(covmean.imag))
197
+ raise ValueError('Imaginary component {}'.format(m))
198
+ covmean = covmean.real
199
+
200
+ tr_covmean = np.trace(covmean)
201
+
202
+ return (diff.dot(diff) + np.trace(sigma1)
203
+ + np.trace(sigma2) - 2 * tr_covmean)
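A quick sanity check (not part of the committed file): identical statistics give a distance of (about) zero, and in one dimension the formula reduces to (mu1 - mu2)^2 + (sqrt(v1) - sqrt(v2))^2.

```python
import numpy as np

mu, sigma = np.zeros(8), np.eye(8)
assert abs(calculate_frechet_distance(mu, sigma, mu, sigma)) < 1e-4

d = calculate_frechet_distance(np.array([0.0]), np.array([[4.0]]),
                               np.array([1.0]), np.array([[1.0]]))
print(d)  # (0 - 1)^2 + (2 - 1)^2 = 2.0
```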
204
+
205
+
206
+ def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
207
+ device='cpu', num_workers=1):
208
+ """Calculation of the statistics used by the FID.
209
+ Params:
210
+ -- files : List of image files paths
211
+ -- model : Instance of inception model
212
+ -- batch_size : The images numpy array is split into batches with
213
+ batch size batch_size. A reasonable batch size
214
+ depends on the hardware.
215
+ -- dims : Dimensionality of features returned by Inception
216
+ -- device : Device to run calculations
217
+ -- num_workers : Number of parallel dataloader workers
218
+
219
+ Returns:
220
+ -- mu : The mean over samples of the activations of the pool_3 layer of
221
+ the inception model.
222
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
223
+ the inception model.
224
+ """
225
+ act = get_activations(files, model, batch_size, dims, device, num_workers)
226
+ mu = np.mean(act, axis=0)
227
+ sigma = np.cov(act, rowvar=False)
228
+ return mu, sigma
229
+
230
+
231
+ def compute_statistics_of_path(path, model, batch_size, dims, device,
232
+ num_workers=1):
233
+ if path.endswith('.npz'):
234
+ with np.load(path) as f:
235
+ m, s = f['mu'][:], f['sigma'][:]
236
+ else:
237
+ path = pathlib.Path(path)
238
+ files = sorted([file for ext in IMAGE_EXTENSIONS
239
+ for file in path.glob('*.{}'.format(ext))])
240
+ m, s = calculate_activation_statistics(files, model, batch_size,
241
+ dims, device, num_workers)
242
+
243
+ return m, s
244
+
245
+
246
+ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
247
+ """Calculates the FID of two paths"""
248
+ for p in paths:
249
+ if not os.path.exists(p):
250
+ raise RuntimeError('Invalid path: %s' % p)
251
+
252
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
253
+
254
+ model = InceptionV3([block_idx]).to(device)
255
+
256
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
257
+ dims, device, num_workers)
258
+ m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
259
+ dims, device, num_workers)
260
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
261
+
262
+ return fid_value
263
+
264
+
265
+ def main():
266
+ args = parser.parse_args()
267
+
268
+ if args.device is None:
269
+ device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
270
+ else:
271
+ device = torch.device(args.device)
272
+
273
+ if args.num_workers is None:
274
+ num_avail_cpus = len(os.sched_getaffinity(0))
275
+ num_workers = min(num_avail_cpus, 8)
276
+ else:
277
+ num_workers = args.num_workers
278
+
279
+ fid_value = calculate_fid_given_paths(args.path,
280
+ args.batch_size,
281
+ device,
282
+ args.dims,
283
+ num_workers)
284
+ print('FID: ', fid_value)
285
+
286
+
287
+ if __name__ == '__main__':
288
+ main()
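Beyond the CLI entry point above, the same computation can be called from Python (a sketch, not part of the commit; the two folders are placeholders, and the asp directory is assumed to be on sys.path so that the module's own `from util.inception import InceptionV3` import resolves):

```python
import torch
from util.fid import calculate_fid_given_paths  # assuming asp/ is on sys.path

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fid_value = calculate_fid_given_paths(['datasets/real_IHC', 'results/fake_IHC'],
                                      batch_size=50, device=device, dims=2048,
                                      num_workers=4)
print('FID:', fid_value)
```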
asp/util/general_utils.py ADDED
@@ -0,0 +1,73 @@
1
+ import sys
2
+ sys.path.append(".")
3
+
4
+ import time, os
5
+ from functools import wraps
6
+ import argparse
7
+ import inspect
8
+ import traceback
9
+
10
+ def time_it(func):
11
+ @wraps(func)
12
+ def wrapper(*args, **kwargs):
13
+ start_time = time.time()
14
+ result = func(*args, **kwargs)
15
+ end_time = time.time()
16
+ elapsed_time = end_time - start_time
17
+ print(f"Function '{func.__name__}' executed in {elapsed_time:.6f} seconds.")
18
+ return result
19
+ return wrapper
20
+
21
+ def try_wrapper(function, filename, log_path):
22
+ try:
23
+ return function()
24
+ except Exception as e:
25
+ error_trace = traceback.format_exc()
26
+
27
+ with open(log_path, 'a') as log_file:
28
+ log_file.write(f"{filename}: {error_trace}\n")
29
+ print(f"Error in {filename}:\n{error_trace}")
30
+
31
+ def parse_args(main_function):
32
+ parser = argparse.ArgumentParser()
33
+
34
+ used_short_versions = set("h")
35
+
36
+ signature = inspect.signature(main_function)
37
+ for param_name, param in signature.parameters.items():
38
+ short_version = param_name[0]
39
+ if short_version in used_short_versions or not short_version.isalpha():
40
+ for char in param_name[1:]:
41
+ short_version = char
42
+ if char.isalpha() and short_version not in used_short_versions:
43
+ break
44
+ else:
45
+ short_version = None
46
+
47
+ if short_version:
48
+ used_short_versions.add(short_version)
49
+ param_call = (f'-{short_version}', f'--{param_name}')
50
+ else:
51
+ param_call = (f'--{param_name}',)
52
+
53
+ if param.default is not inspect.Parameter.empty:
54
+ if param.default is not None:
55
+ param_type = type(param.default)
56
+ else:
57
+ param_type = str
58
+ parser.add_argument(*param_call, type=param_type, default=param.default,
59
+ help=f"Automatically detected argument: {param_name}, default: {param.default}")
60
+ else:
61
+ parser.add_argument(*param_call, required=True,
62
+ help=f"Required argument: {param_name}")
63
+
64
+ args = parser.parse_args()
65
+
66
+ return args
67
+
68
+ def assert_file_exist(*args):
69
+ path = os.path.join(*args)
70
+ if not os.path.exists(path):
71
+ raise Exception(f"File {path} does not exist")
72
+
73
+ return path
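A small sketch of how these helpers combine in a script (not part of the commit; the import path mirrors the file location in this commit, and the function and flags below are hypothetical):

```python
from asp.util.general_utils import parse_args, time_it, assert_file_exist

@time_it
def main(input_path, batch_size=8):
    input_path = assert_file_exist(input_path)
    print(f"Processing {input_path} with batch_size={batch_size}")

if __name__ == '__main__':
    # Flags are derived from main's signature: -i/--input_path (required),
    # -b/--batch_size (int, default 8).
    args = parse_args(main)
    main(**vars(args))
```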
asp/util/get_data.py ADDED
@@ -0,0 +1,110 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import tarfile
4
+ import requests
5
+ from warnings import warn
6
+ from zipfile import ZipFile
7
+ from bs4 import BeautifulSoup
8
+ from os.path import abspath, isdir, join, basename
9
+
10
+
11
+ class GetData(object):
12
+ """A Python script for downloading CycleGAN or pix2pix datasets.
13
+
14
+ Parameters:
15
+ technique (str) -- One of: 'cyclegan' or 'pix2pix'.
16
+ verbose (bool) -- If True, print additional information.
17
+
18
+ Examples:
19
+ >>> from util.get_data import GetData
20
+ >>> gd = GetData(technique='cyclegan')
21
+ >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
22
+
23
+ Alternatively, you can use the bash scripts: 'scripts/download_pix2pix_model.sh'
24
+ and 'scripts/download_cyclegan_model.sh'.
25
+ """
26
+
27
+ def __init__(self, technique='cyclegan', verbose=True):
28
+ url_dict = {
29
+ 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
30
+ 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
31
+ }
32
+ self.url = url_dict.get(technique.lower())
33
+ self._verbose = verbose
34
+
35
+ def _print(self, text):
36
+ if self._verbose:
37
+ print(text)
38
+
39
+ @staticmethod
40
+ def _get_options(r):
41
+ soup = BeautifulSoup(r.text, 'lxml')
42
+ options = [h.text for h in soup.find_all('a', href=True)
43
+ if h.text.endswith(('.zip', 'tar.gz'))]
44
+ return options
45
+
46
+ def _present_options(self):
47
+ r = requests.get(self.url)
48
+ options = self._get_options(r)
49
+ print('Options:\n')
50
+ for i, o in enumerate(options):
51
+ print("{0}: {1}".format(i, o))
52
+ choice = input("\nPlease enter the number of the "
53
+ "dataset above you wish to download:")
54
+ return options[int(choice)]
55
+
56
+ def _download_data(self, dataset_url, save_path):
57
+ if not isdir(save_path):
58
+ os.makedirs(save_path)
59
+
60
+ base = basename(dataset_url)
61
+ temp_save_path = join(save_path, base)
62
+
63
+ with open(temp_save_path, "wb") as f:
64
+ r = requests.get(dataset_url)
65
+ f.write(r.content)
66
+
67
+ if base.endswith('.tar.gz'):
68
+ obj = tarfile.open(temp_save_path)
69
+ elif base.endswith('.zip'):
70
+ obj = ZipFile(temp_save_path, 'r')
71
+ else:
72
+ raise ValueError("Unknown File Type: {0}.".format(base))
73
+
74
+ self._print("Unpacking Data...")
75
+ obj.extractall(save_path)
76
+ obj.close()
77
+ os.remove(temp_save_path)
78
+
79
+ def get(self, save_path, dataset=None):
80
+ """
81
+
82
+ Download a dataset.
83
+
84
+ Parameters:
85
+ save_path (str) -- A directory to save the data to.
86
+ dataset (str) -- (optional). A specific dataset to download.
87
+ Note: this must include the file extension.
88
+ If None, options will be presented for you
89
+ to choose from.
90
+
91
+ Returns:
92
+ save_path_full (str) -- the absolute path to the downloaded data.
93
+
94
+ """
95
+ if dataset is None:
96
+ selected_dataset = self._present_options()
97
+ else:
98
+ selected_dataset = dataset
99
+
100
+ save_path_full = join(save_path, selected_dataset.split('.')[0])
101
+
102
+ if isdir(save_path_full):
103
+ warn("\n'{0}' already exists. Voiding Download.".format(
104
+ save_path_full))
105
+ else:
106
+ self._print('Downloading Data...')
107
+ url = "{0}/{1}".format(self.url, selected_dataset)
108
+ self._download_data(url, save_path=save_path)
109
+
110
+ return abspath(save_path_full)
asp/util/html.py ADDED
@@ -0,0 +1,86 @@
1
+ import dominate
2
+ from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3
+ import os
4
+
5
+
6
+ class HTML:
7
+ """This HTML class allows us to save images and write texts into a single HTML file.
8
+
9
+ It consists of functions such as <add_header> (add a text header to the HTML file),
10
+ <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12
+ """
13
+
14
+ def __init__(self, web_dir, title, refresh=0):
15
+ """Initialize the HTML classes
16
+
17
+ Parameters:
18
+ web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19
+ title (str) -- the webpage name
20
+ refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21
+ """
22
+ self.title = title
23
+ self.web_dir = web_dir
24
+ self.img_dir = os.path.join(self.web_dir, 'images')
25
+ if not os.path.exists(self.web_dir):
26
+ os.makedirs(self.web_dir)
27
+ if not os.path.exists(self.img_dir):
28
+ os.makedirs(self.img_dir)
29
+
30
+ self.doc = dominate.document(title=title)
31
+ if refresh > 0:
32
+ with self.doc.head:
33
+ meta(http_equiv="refresh", content=str(refresh))
34
+
35
+ def get_image_dir(self):
36
+ """Return the directory that stores images"""
37
+ return self.img_dir
38
+
39
+ def add_header(self, text):
40
+ """Insert a header to the HTML file
41
+
42
+ Parameters:
43
+ text (str) -- the header text
44
+ """
45
+ with self.doc:
46
+ h3(text)
47
+
48
+ def add_images(self, ims, txts, links, width=400):
49
+ """add images to the HTML file
50
+
51
+ Parameters:
52
+ ims (str list) -- a list of image paths
53
+ txts (str list) -- a list of image names shown on the website
54
+ links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page
55
+ """
56
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57
+ self.doc.add(self.t)
58
+ with self.t:
59
+ with tr():
60
+ for im, txt, link in zip(ims, txts, links):
61
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
62
+ with p():
63
+ with a(href=os.path.join('images', link)):
64
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
65
+ br()
66
+ p(txt)
67
+
68
+ def save(self):
69
+ """save the current content to the HMTL file"""
70
+ html_file = '%s/index.html' % self.web_dir
71
+ f = open(html_file, 'wt')
72
+ f.write(self.doc.render())
73
+ f.close()
74
+
75
+
76
+ if __name__ == '__main__': # we show an example usage here.
77
+ html = HTML('web/', 'test_html')
78
+ html.add_header('hello world')
79
+
80
+ ims, txts, links = [], [], []
81
+ for n in range(4):
82
+ ims.append('image_%d.png' % n)
83
+ txts.append('text_%d' % n)
84
+ links.append('image_%d.png' % n)
85
+ html.add_images(ims, txts, links)
86
+ html.save()
asp/util/image_pool.py ADDED
@@ -0,0 +1,54 @@
1
+ import random
2
+ import torch
3
+
4
+
5
+ class ImagePool():
6
+ """This class implements an image buffer that stores previously generated images.
7
+
8
+ This buffer enables us to update discriminators using a history of generated images
9
+ rather than the ones produced by the latest generators.
10
+ """
11
+
12
+ def __init__(self, pool_size):
13
+ """Initialize the ImagePool class
14
+
15
+ Parameters:
16
+ pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
17
+ """
18
+ self.pool_size = pool_size
19
+ if self.pool_size > 0: # create an empty pool
20
+ self.num_imgs = 0
21
+ self.images = []
22
+
23
+ def query(self, images):
24
+ """Return an image from the pool.
25
+
26
+ Parameters:
27
+ images: the latest generated images from the generator
28
+
29
+ Returns images from the buffer.
30
+
31
+ With probability 0.5, the buffer returns the input images.
32
+ With probability 0.5, the buffer returns images previously stored in the buffer,
33
+ and inserts the current images into the buffer.
34
+ """
35
+ if self.pool_size == 0: # if the buffer size is 0, do nothing
36
+ return images
37
+ return_images = []
38
+ for image in images:
39
+ image = torch.unsqueeze(image.data, 0)
40
+ if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
41
+ self.num_imgs = self.num_imgs + 1
42
+ self.images.append(image)
43
+ return_images.append(image)
44
+ else:
45
+ p = random.uniform(0, 1)
46
+ if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
47
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
48
+ tmp = self.images[random_id].clone()
49
+ self.images[random_id] = image
50
+ return_images.append(tmp)
51
+ else: # by another 50% chance, the buffer will return the current image
52
+ return_images.append(image)
53
+ return_images = torch.cat(return_images, 0) # collect all the images and return
54
+ return return_images
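A minimal sketch of the intended use when updating the discriminator (not part of the commit; shapes and the import path are illustrative):

```python
import torch
from asp.util.image_pool import ImagePool

pool = ImagePool(pool_size=50)
for step in range(3):
    fake_B = torch.randn(4, 3, 256, 256)   # latest generator outputs
    fake_for_D = pool.query(fake_B)        # mix of current and previously stored fakes
    print(fake_for_D.shape)                # torch.Size([4, 3, 256, 256])
```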
asp/util/inception.py ADDED
@@ -0,0 +1,328 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling features
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
31
+ def __init__(self,
32
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
33
+ resize_input=True,
34
+ normalize_input=True,
35
+ requires_grad=False,
36
+ use_fid_inception=True):
37
+ """Build pretrained InceptionV3
38
+
39
+ Parameters
40
+ ----------
41
+ output_blocks : list of int
42
+ Indices of blocks to return features of. Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input : bool
48
+ If true, bilinearly resizes input to width and height 299 before
49
+ feeding input to model. As the network without fully connected
50
+ layers is fully convolutional, it should be able to handle inputs
51
+ of arbitrary size, so resizing might not be strictly needed
52
+ normalize_input : bool
53
+ If true, scales the input from range (0, 1) to the range the
54
+ pretrained Inception network expects, namely (-1, 1)
55
+ requires_grad : bool
56
+ If true, parameters of the model require gradients. Possibly useful
57
+ for finetuning the network
58
+ use_fid_inception : bool
59
+ If true, uses the pretrained Inception model used in Tensorflow's
60
+ FID implementation. If false, uses the pretrained Inception model
61
+ available in torchvision. The FID Inception model has different
62
+ weights and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get comparable
65
+ results.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, \
75
+ 'Last possible output block index is 3'
76
+
77
+ self.blocks = nn.ModuleList()
78
+
79
+ if use_fid_inception:
80
+ inception = fid_inception_v3()
81
+ else:
82
+ inception = _inception_v3(pretrained=True)
83
+
84
+ # Block 0: input to maxpool1
85
+ block0 = [
86
+ inception.Conv2d_1a_3x3,
87
+ inception.Conv2d_2a_3x3,
88
+ inception.Conv2d_2b_3x3,
89
+ nn.MaxPool2d(kernel_size=3, stride=2)
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block0))
92
+
93
+ # Block 1: maxpool1 to maxpool2
94
+ if self.last_needed_block >= 1:
95
+ block1 = [
96
+ inception.Conv2d_3b_1x1,
97
+ inception.Conv2d_4a_3x3,
98
+ nn.MaxPool2d(kernel_size=3, stride=2)
99
+ ]
100
+ self.blocks.append(nn.Sequential(*block1))
101
+
102
+ # Block 2: maxpool2 to aux classifier
103
+ if self.last_needed_block >= 2:
104
+ block2 = [
105
+ inception.Mixed_5b,
106
+ inception.Mixed_5c,
107
+ inception.Mixed_5d,
108
+ inception.Mixed_6a,
109
+ inception.Mixed_6b,
110
+ inception.Mixed_6c,
111
+ inception.Mixed_6d,
112
+ inception.Mixed_6e,
113
+ ]
114
+ self.blocks.append(nn.Sequential(*block2))
115
+
116
+ # Block 3: aux classifier to final avgpool
117
+ if self.last_needed_block >= 3:
118
+ block3 = [
119
+ inception.Mixed_7a,
120
+ inception.Mixed_7b,
121
+ inception.Mixed_7c,
122
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
123
+ ]
124
+ self.blocks.append(nn.Sequential(*block3))
125
+
126
+ for param in self.parameters():
127
+ param.requires_grad = requires_grad
128
+
129
+ def forward(self, inp):
130
+ """Get Inception feature maps
131
+
132
+ Parameters
133
+ ----------
134
+ inp : torch.autograd.Variable
135
+ Input tensor of shape Bx3xHxW. Values are expected to be in
136
+ range (0, 1)
137
+
138
+ Returns
139
+ -------
140
+ List of torch.autograd.Variable, corresponding to the selected output
141
+ block, sorted ascending by index
142
+ """
143
+ outp = []
144
+ x = inp
145
+
146
+ if self.resize_input:
147
+ x = F.interpolate(x,
148
+ size=(299, 299),
149
+ mode='bilinear',
150
+ align_corners=False)
151
+
152
+ if self.normalize_input:
153
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
154
+
155
+ for idx, block in enumerate(self.blocks):
156
+ x = block(x)
157
+ if idx in self.output_blocks:
158
+ outp.append(x)
159
+
160
+ if idx == self.last_needed_block:
161
+ break
162
+
163
+ return outp
164
+
165
+
166
+ def _inception_v3(*args, **kwargs):
167
+ """Wraps `torchvision.models.inception_v3`
168
+
169
+ Skips default weight initialization if supported by torchvision version.
170
+ See https://github.com/mseitzer/pytorch-fid/issues/28.
171
+ """
172
+ try:
173
+ version = tuple(map(int, torchvision.__version__.split('.')[:2]))
174
+ except ValueError:
175
+ # Just a caution against weird version strings
176
+ version = (0,)
177
+
178
+ if version >= (0, 6):
179
+ kwargs['init_weights'] = False
180
+
181
+ return torchvision.models.inception_v3(*args, **kwargs)
182
+
183
+
184
+ def fid_inception_v3():
185
+ """Build pretrained Inception model for FID computation
186
+
187
+ The Inception model for FID computation uses a different set of weights
188
+ and has a slightly different structure than torchvision's Inception.
189
+
190
+ This method first constructs torchvision's Inception and then patches the
191
+ necessary parts that are different in the FID Inception model.
192
+ """
193
+ inception = _inception_v3(num_classes=1008,
194
+ aux_logits=False,
195
+ pretrained=False)
196
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
197
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
198
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
199
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
200
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
201
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
202
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
203
+ inception.Mixed_7b = FIDInceptionE_1(1280)
204
+ inception.Mixed_7c = FIDInceptionE_2(2048)
205
+
206
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
207
+ inception.load_state_dict(state_dict)
208
+ return inception
209
+
210
+
211
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
212
+ """InceptionA block patched for FID computation"""
213
+ def __init__(self, in_channels, pool_features):
214
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
215
+
216
+ def forward(self, x):
217
+ branch1x1 = self.branch1x1(x)
218
+
219
+ branch5x5 = self.branch5x5_1(x)
220
+ branch5x5 = self.branch5x5_2(branch5x5)
221
+
222
+ branch3x3dbl = self.branch3x3dbl_1(x)
223
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
224
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
225
+
226
+ # Patch: Tensorflow's average pool does not use the padded zeros in
227
+ # its average calculation
228
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
229
+ count_include_pad=False)
230
+ branch_pool = self.branch_pool(branch_pool)
231
+
232
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
233
+ return torch.cat(outputs, 1)
234
+
235
+
236
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
237
+ """InceptionC block patched for FID computation"""
238
+ def __init__(self, in_channels, channels_7x7):
239
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
240
+
241
+ def forward(self, x):
242
+ branch1x1 = self.branch1x1(x)
243
+
244
+ branch7x7 = self.branch7x7_1(x)
245
+ branch7x7 = self.branch7x7_2(branch7x7)
246
+ branch7x7 = self.branch7x7_3(branch7x7)
247
+
248
+ branch7x7dbl = self.branch7x7dbl_1(x)
249
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
250
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
251
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
252
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
253
+
254
+ # Patch: Tensorflow's average pool does not use the padded zeros in
255
+ # its average calculation
256
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
257
+ count_include_pad=False)
258
+ branch_pool = self.branch_pool(branch_pool)
259
+
260
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
261
+ return torch.cat(outputs, 1)
262
+
263
+
264
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
265
+ """First InceptionE block patched for FID computation"""
266
+ def __init__(self, in_channels):
267
+ super(FIDInceptionE_1, self).__init__(in_channels)
268
+
269
+ def forward(self, x):
270
+ branch1x1 = self.branch1x1(x)
271
+
272
+ branch3x3 = self.branch3x3_1(x)
273
+ branch3x3 = [
274
+ self.branch3x3_2a(branch3x3),
275
+ self.branch3x3_2b(branch3x3),
276
+ ]
277
+ branch3x3 = torch.cat(branch3x3, 1)
278
+
279
+ branch3x3dbl = self.branch3x3dbl_1(x)
280
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
281
+ branch3x3dbl = [
282
+ self.branch3x3dbl_3a(branch3x3dbl),
283
+ self.branch3x3dbl_3b(branch3x3dbl),
284
+ ]
285
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
286
+
287
+ # Patch: Tensorflow's average pool does not use the padded zeros in
288
+ # its average calculation
289
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
290
+ count_include_pad=False)
291
+ branch_pool = self.branch_pool(branch_pool)
292
+
293
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
294
+ return torch.cat(outputs, 1)
295
+
296
+
297
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
298
+ """Second InceptionE block patched for FID computation"""
299
+ def __init__(self, in_channels):
300
+ super(FIDInceptionE_2, self).__init__(in_channels)
301
+
302
+ def forward(self, x):
303
+ branch1x1 = self.branch1x1(x)
304
+
305
+ branch3x3 = self.branch3x3_1(x)
306
+ branch3x3 = [
307
+ self.branch3x3_2a(branch3x3),
308
+ self.branch3x3_2b(branch3x3),
309
+ ]
310
+ branch3x3 = torch.cat(branch3x3, 1)
311
+
312
+ branch3x3dbl = self.branch3x3dbl_1(x)
313
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
314
+ branch3x3dbl = [
315
+ self.branch3x3dbl_3a(branch3x3dbl),
316
+ self.branch3x3dbl_3b(branch3x3dbl),
317
+ ]
318
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
319
+
320
+ # Patch: The FID Inception model uses max pooling instead of average
321
+ # pooling. This is likely an error in this specific Inception
322
+ # implementation, as other Inception models use average pooling here
323
+ # (which matches the description in the paper).
324
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
325
+ branch_pool = self.branch_pool(branch_pool)
326
+
327
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
328
+ return torch.cat(outputs, 1)
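A short sketch of extracting pool_3 features with this wrapper (not part of the commit; the first call downloads the pretrained FID weights from FID_WEIGHTS_URL, and the import path mirrors the file location):

```python
import torch
from asp.util.inception import InceptionV3

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3([block_idx]).eval()

images = torch.rand(2, 3, 299, 299)            # values in [0, 1]
with torch.no_grad():
    features = model(images)[0]                # (2, 2048, 1, 1) after the final avgpool
print(features.squeeze(-1).squeeze(-1).shape)  # torch.Size([2, 2048])
```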
asp/util/kid_score.py ADDED
@@ -0,0 +1,450 @@
1
+ #!/usr/bin/env python3
2
+ """Calculates the Kernel Inception Distance (KID) to evalulate GANs
3
+ """
4
+ import os
5
+ import pathlib
6
+ import sys
7
+ from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
8
+
9
+ import numpy as np
10
+ import torch
11
+ from sklearn.metrics.pairwise import polynomial_kernel
12
+ from scipy import linalg
13
+ from PIL import Image
14
+ from torch.nn.functional import adaptive_avg_pool2d
15
+
16
+ try:
17
+ from tqdm import tqdm
18
+ except ImportError:
19
+ # If tqdm is not available, provide a mock version of it
20
+ def tqdm(x): return x
21
+
22
+ # from models.inception import InceptionV3
23
+ # from models.lenet import LeNet5
24
+
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ from torchvision import models
28
+
29
+
30
+ class InceptionV3(nn.Module):
31
+ """Pretrained InceptionV3 network returning feature maps"""
32
+
33
+ # Index of default block of inception to return,
34
+ # corresponds to output of final average pooling
35
+ DEFAULT_BLOCK_INDEX = 3
36
+
37
+ # Maps feature dimensionality to their output blocks indices
38
+ BLOCK_INDEX_BY_DIM = {
39
+ 64: 0, # First max pooling features
40
+ 192: 1, # Second max pooling features
41
+ 768: 2, # Pre-aux classifier features
42
+ 2048: 3 # Final average pooling features
43
+ }
44
+
45
+ def __init__(self,
46
+ output_blocks=[DEFAULT_BLOCK_INDEX],
47
+ resize_input=True,
48
+ normalize_input=True,
49
+ requires_grad=False):
50
+ """Build pretrained InceptionV3
51
+
52
+ Parameters
53
+ ----------
54
+ output_blocks : list of int
55
+ Indices of blocks to return features of. Possible values are:
56
+ - 0: corresponds to output of first max pooling
57
+ - 1: corresponds to output of second max pooling
58
+ - 2: corresponds to output which is fed to aux classifier
59
+ - 3: corresponds to output of final average pooling
60
+ resize_input : bool
61
+ If true, bilinearly resizes input to width and height 299 before
62
+ feeding input to model. As the network without fully connected
63
+ layers is fully convolutional, it should be able to handle inputs
64
+ of arbitrary size, so resizing might not be strictly needed
65
+ normalize_input : bool
66
+ If true, scales the input from range (0, 1) to the range the
67
+ pretrained Inception network expects, namely (-1, 1)
68
+ requires_grad : bool
69
+ If true, parameters of the model require gradients. Possibly useful
70
+ for finetuning the network
71
+ """
72
+ super(InceptionV3, self).__init__()
73
+
74
+ self.resize_input = resize_input
75
+ self.normalize_input = normalize_input
76
+ self.output_blocks = sorted(output_blocks)
77
+ self.last_needed_block = max(output_blocks)
78
+
79
+ assert self.last_needed_block <= 3, \
80
+ 'Last possible output block index is 3'
81
+
82
+ self.blocks = nn.ModuleList()
83
+
84
+ inception = models.inception_v3(pretrained=True)
85
+
86
+ # Block 0: input to maxpool1
87
+ block0 = [
88
+ inception.Conv2d_1a_3x3,
89
+ inception.Conv2d_2a_3x3,
90
+ inception.Conv2d_2b_3x3,
91
+ nn.MaxPool2d(kernel_size=3, stride=2)
92
+ ]
93
+ self.blocks.append(nn.Sequential(*block0))
94
+
95
+ # Block 1: maxpool1 to maxpool2
96
+ if self.last_needed_block >= 1:
97
+ block1 = [
98
+ inception.Conv2d_3b_1x1,
99
+ inception.Conv2d_4a_3x3,
100
+ nn.MaxPool2d(kernel_size=3, stride=2)
101
+ ]
102
+ self.blocks.append(nn.Sequential(*block1))
103
+
104
+ # Block 2: maxpool2 to aux classifier
105
+ if self.last_needed_block >= 2:
106
+ block2 = [
107
+ inception.Mixed_5b,
108
+ inception.Mixed_5c,
109
+ inception.Mixed_5d,
110
+ inception.Mixed_6a,
111
+ inception.Mixed_6b,
112
+ inception.Mixed_6c,
113
+ inception.Mixed_6d,
114
+ inception.Mixed_6e,
115
+ ]
116
+ self.blocks.append(nn.Sequential(*block2))
117
+
118
+ # Block 3: aux classifier to final avgpool
119
+ if self.last_needed_block >= 3:
120
+ block3 = [
121
+ inception.Mixed_7a,
122
+ inception.Mixed_7b,
123
+ inception.Mixed_7c,
124
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
125
+ ]
126
+ self.blocks.append(nn.Sequential(*block3))
127
+
128
+ for param in self.parameters():
129
+ param.requires_grad = requires_grad
130
+
131
+ def forward(self, inp):
132
+ """Get Inception feature maps
133
+
134
+ Parameters
135
+ ----------
136
+ inp : torch.autograd.Variable
137
+ Input tensor of shape Bx3xHxW. Values are expected to be in
138
+ range (0.0, 1.0)
139
+
140
+ Returns
141
+ -------
142
+ List of torch.autograd.Variable, corresponding to the selected output
143
+ block, sorted ascending by index
144
+ """
145
+ outp = []
146
+ x = inp
147
+
148
+ if self.resize_input:
149
+ x = F.interpolate(x,
150
+ size=(299, 299),
151
+ mode='bilinear',
152
+ align_corners=False)
153
+
154
+ if self.normalize_input:
155
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
156
+
157
+ for idx, block in enumerate(self.blocks):
158
+ x = block(x)
159
+ if idx in self.output_blocks:
160
+ outp.append(x)
161
+
162
+ if idx == self.last_needed_block:
163
+ break
164
+
165
+ return outp
166
+
167
+
168
+ def get_activations(files, model, batch_size=50, dims=2048,
169
+ cuda=False, verbose=False):
170
+ """Calculates the activations of the pool_3 layer for all images.
171
+
172
+ Params:
173
+ -- files : List of image files paths
174
+ -- model : Instance of inception model
175
+ -- batch_size : Batch size of images for the model to process at once.
176
+ Make sure that the number of samples is a multiple of
177
+ the batch size, otherwise some samples are ignored. This
178
+ behavior is retained to match the original FID score
179
+ implementation.
180
+ -- dims : Dimensionality of features returned by Inception
181
+ -- cuda : If set to True, use GPU
182
+ -- verbose : If set to True and parameter out_step is given, the number
183
+ of calculated batches is reported.
184
+ Returns:
185
+ -- A numpy array of dimension (num images, dims) that contains the
186
+ activations of the given tensor when feeding inception with the
187
+ query tensor.
188
+ """
189
+ model.eval()
190
+
191
+ is_numpy = True if type(files[0]) == np.ndarray else False
192
+
193
+ if len(files) % batch_size != 0:
194
+ print(('Warning: number of images is not a multiple of the '
195
+ 'batch size. Some samples are going to be ignored.'))
196
+ if batch_size > len(files):
197
+ print(('Warning: batch size is bigger than the data size. '
198
+ 'Setting batch size to data size'))
199
+ batch_size = len(files)
200
+
201
+ n_batches = len(files) // batch_size
202
+ n_used_imgs = n_batches * batch_size
203
+
204
+ pred_arr = np.empty((n_used_imgs, dims))
205
+
206
+ for i in tqdm(range(n_batches)):
207
+ if verbose:
208
+ print('\rPropagating batch %d/%d' % (i + 1, n_batches), end='', flush=True)
209
+ start = i * batch_size
210
+ end = start + batch_size
211
+ if is_numpy:
212
+ images = np.copy(files[start:end]) + 1
213
+ images /= 2.
214
+ else:
215
+ images = [np.array(Image.open(str(f))) for f in files[start:end]]
216
+ images = np.stack(images).astype(np.float32) / 255.
217
+ # Reshape to (n_images, 3, height, width)
218
+ images = images.transpose((0, 3, 1, 2))
219
+
220
+ batch = torch.from_numpy(images).type(torch.FloatTensor)
221
+ if cuda:
222
+ batch = batch.cuda()
223
+ pred = model(batch)[0]
224
+
225
+ # If model output is not scalar, apply global spatial average pooling.
226
+ # This happens if you choose a dimensionality not equal 2048.
227
+ if pred.shape[2] != 1 or pred.shape[3] != 1:
228
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
229
+
230
+ pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)
231
+
232
+ if verbose:
233
+ print('done', np.min(images))
234
+
235
+ return pred_arr
236
+
237
+
238
+ def extract_lenet_features(imgs, net):
239
+ net.eval()
240
+ feats = []
241
+ imgs = imgs.reshape([-1, 100] + list(imgs.shape[1:]))
242
+ if imgs[0].min() < -0.001:
243
+ imgs = (imgs + 1)/2.0
244
+ print(imgs.shape, imgs.min(), imgs.max())
245
+ imgs = torch.from_numpy(imgs)
246
+ for i, images in enumerate(imgs):
247
+ feats.append(net.extract_features(images).detach().cpu().numpy())
248
+ feats = np.vstack(feats)
249
+ return feats
250
+
251
+
252
+ def _compute_activations(path, model, batch_size, dims, cuda, model_type):
253
+ if not type(path) == np.ndarray:
254
+ import glob
255
+ jpg = os.path.join(path, '*.jpg')
256
+ png = os.path.join(path, '*.png')
257
+ path = glob.glob(jpg) + glob.glob(png)
258
+ if len(path) > 50000:
259
+ import random
260
+ random.shuffle(path)
261
+ path = path[:50000]
262
+ if model_type == 'inception':
263
+ act = get_activations(path, model, batch_size, dims, cuda)
264
+ elif model_type == 'lenet':
265
+ act = extract_lenet_features(path, model)
266
+ return act
267
+
268
+
269
+ def calculate_kid_given_paths(paths, batch_size, cuda, dims, model_type='inception'):
270
+ """Calculates the KID of two paths"""
271
+ pths = []
272
+ for p in paths:
273
+ if not os.path.exists(p):
274
+ raise RuntimeError('Invalid path: %s' % p)
275
+ if os.path.isdir(p):
276
+ pths.append(p)
277
+ elif p.endswith('.npy'):
278
+ np_imgs = np.load(p)
279
+ if np_imgs.shape[0] > 50000: np_imgs = np_imgs[np.random.permutation(np.arange(np_imgs.shape[0]))][:50000]
280
+ pths.append(np_imgs)
281
+
282
+ if model_type == 'inception':
283
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
284
+ model = InceptionV3([block_idx])
285
+ elif model_type == 'lenet':
286
+ model = LeNet5()
287
+ model.load_state_dict(torch.load('./models/lenet.pth'))
288
+ if cuda:
289
+ model.cuda()
290
+
291
+ act_true = _compute_activations(pths[0], model, batch_size, dims, cuda, model_type)
292
+ pths = pths[1:]
293
+ results = []
294
+ for j, pth in enumerate(pths):
295
+ print(paths[j+1])
296
+ actj = _compute_activations(pth, model, batch_size, dims, cuda, model_type)
297
+ kid_values = polynomial_mmd_averages(act_true, actj, n_subsets=100, subset_size=min(act_true.shape[0], 100))
298
+ results.append((paths[j+1], kid_values[0].mean(), kid_values[0].std()))
299
+ return results
300
+
301
+ def _sqn(arr):
302
+ flat = np.ravel(arr)
303
+ return flat.dot(flat)
304
+
305
+
306
+ def polynomial_mmd_averages(codes_g, codes_r, n_subsets=50, subset_size=1000,
307
+ ret_var=True, output=sys.stdout, **kernel_args):
308
+ m = min(codes_g.shape[0], codes_r.shape[0])
309
+ mmds = np.zeros(n_subsets)
310
+ if ret_var:
311
+ vars = np.zeros(n_subsets)
312
+ choice = np.random.choice
313
+
314
+ with tqdm(range(n_subsets), desc='MMD', file=output) as bar:
315
+ for i in bar:
316
+ g = codes_g[choice(len(codes_g), subset_size, replace=False)]
317
+ r = codes_r[choice(len(codes_r), subset_size, replace=False)]
318
+ o = polynomial_mmd(g, r, **kernel_args, var_at_m=m, ret_var=ret_var)
319
+ if ret_var:
320
+ mmds[i], vars[i] = o
321
+ else:
322
+ mmds[i] = o
323
+ bar.set_postfix({'mean': mmds[:i+1].mean()})
324
+ return (mmds, vars) if ret_var else mmds
325
+
326
+
327
+ def polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1,
328
+ var_at_m=None, ret_var=True):
329
+ # use k(x, y) = (gamma <x, y> + coef0)^degree
330
+ # default gamma is 1 / dim
331
+ X = codes_g
332
+ Y = codes_r
333
+
334
+ K_XX = polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0)
335
+ K_YY = polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0)
336
+ K_XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0)
337
+
338
+ return _mmd2_and_variance(K_XX, K_XY, K_YY,
339
+ var_at_m=var_at_m, ret_var=ret_var)
340
+
341
+ def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False,
342
+ mmd_est='unbiased', block_size=1024,
343
+ var_at_m=None, ret_var=True):
344
+ # based on
345
+ # https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py
346
+ # but changed to not compute the full kernel matrix at once
347
+ m = K_XX.shape[0]
348
+ assert K_XX.shape == (m, m)
349
+ assert K_XY.shape == (m, m)
350
+ assert K_YY.shape == (m, m)
351
+ if var_at_m is None:
352
+ var_at_m = m
353
+
354
+ # Get the various sums of kernels that we'll use
355
+ # Kts drop the diagonal, but we don't need to compute them explicitly
356
+ if unit_diagonal:
357
+ diag_X = diag_Y = 1
358
+ sum_diag_X = sum_diag_Y = m
359
+ sum_diag2_X = sum_diag2_Y = m
360
+ else:
361
+ diag_X = np.diagonal(K_XX)
362
+ diag_Y = np.diagonal(K_YY)
363
+
364
+ sum_diag_X = diag_X.sum()
365
+ sum_diag_Y = diag_Y.sum()
366
+
367
+ sum_diag2_X = _sqn(diag_X)
368
+ sum_diag2_Y = _sqn(diag_Y)
369
+
370
+ Kt_XX_sums = K_XX.sum(axis=1) - diag_X
371
+ Kt_YY_sums = K_YY.sum(axis=1) - diag_Y
372
+ K_XY_sums_0 = K_XY.sum(axis=0)
373
+ K_XY_sums_1 = K_XY.sum(axis=1)
374
+
375
+ Kt_XX_sum = Kt_XX_sums.sum()
376
+ Kt_YY_sum = Kt_YY_sums.sum()
377
+ K_XY_sum = K_XY_sums_0.sum()
378
+
379
+ if mmd_est == 'biased':
380
+ mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
381
+ + (Kt_YY_sum + sum_diag_Y) / (m * m)
382
+ - 2 * K_XY_sum / (m * m))
383
+ else:
384
+ assert mmd_est in {'unbiased', 'u-statistic'}
385
+ mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m-1))
386
+ if mmd_est == 'unbiased':
387
+ mmd2 -= 2 * K_XY_sum / (m * m)
388
+ else:
389
+ mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m-1))
390
+
391
+ if not ret_var:
392
+ return mmd2
393
+
394
+ Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X
395
+ Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y
396
+ K_XY_2_sum = _sqn(K_XY)
397
+
398
+ dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1)
399
+ dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0)
400
+
401
+ m1 = m - 1
402
+ m2 = m - 2
403
+ zeta1_est = (
404
+ 1 / (m * m1 * m2) * (
405
+ _sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum)
406
+ - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
407
+ + 1 / (m * m * m1) * (
408
+ _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum)
409
+ - 2 / m**4 * K_XY_sum**2
410
+ - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
411
+ + 2 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
412
+ )
413
+ zeta2_est = (
414
+ 1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum)
415
+ - 1 / (m * m1)**2 * (Kt_XX_sum**2 + Kt_YY_sum**2)
416
+ + 2 / (m * m) * K_XY_2_sum
417
+ - 2 / m**4 * K_XY_sum**2
418
+ - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
419
+ + 4 / (m**3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
420
+ )
421
+ var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est
422
+ + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est)
423
+
424
+ return mmd2, var_est
425
+
426
+
427
+ if __name__ == '__main__':
428
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
429
+ parser.add_argument('--true', type=str, required=True,
430
+ help=('Path to the true images'))
431
+ parser.add_argument('--fake', type=str, nargs='+', required=True,
432
+ help=('Path to the generated images'))
433
+ parser.add_argument('--batch-size', type=int, default=50,
434
+ help='Batch size to use')
435
+ parser.add_argument('--dims', type=int, default=2048,
436
+ choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
437
+ help=('Dimensionality of Inception features to use. '
438
+ 'By default, uses pool3 features'))
439
+ parser.add_argument('-c', '--gpu', default='', type=str,
440
+ help='GPU to use (leave blank for CPU only)')
441
+ parser.add_argument('--model', default='inception', type=str,
442
+ help='inception or lenet')
443
+ args = parser.parse_args()
444
+ print(args)
445
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
446
+ paths = [args.true] + args.fake
447
+
448
+ results = calculate_kid_given_paths(paths, args.batch_size, args.gpu != '', args.dims, model_type=args.model)
449
+ for p, m, s in results:
450
+ print('KID mean std (%s): %.4f %.4f' % (p, m, s))
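The same computation is available programmatically (a sketch, not part of the commit; the image folders are placeholders and the import path mirrors the file location):

```python
from asp.util.kid_score import calculate_kid_given_paths

results = calculate_kid_given_paths(['datasets/real_IHC', 'results/fake_IHC'],
                                    batch_size=50, cuda=False, dims=2048,
                                    model_type='inception')
for path, kid_mean, kid_std in results:
    print('KID mean std (%s): %.4f %.4f' % (path, kid_mean, kid_std))
```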
asp/util/perceptual.py ADDED
@@ -0,0 +1,347 @@
1
+ # Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
2
+ #
3
+ # This work is made available under the Nvidia Source Code License-NC.
4
+ # To view a copy of this license, check out LICENSE.md
5
+ # Copyright (C) 2020 NVIDIA Corporation. All rights reserved
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+ from torch import nn
10
+
11
+
12
+ def apply_imagenet_normalization(input):
13
+ r"""Normalize using ImageNet mean and std.
14
+
15
+ Args:
16
+ input (4D tensor NxCxHxW): The input images, assumed to be in [-1, 1].
17
+
18
+ Returns:
19
+ Normalized inputs using the ImageNet normalization.
20
+ """
21
+ # normalize the input back to [0, 1]
22
+ normalized_input = (input + 1) / 2
23
+ # normalize the input using the ImageNet mean and std
24
+ mean = normalized_input.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
25
+ std = normalized_input.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
26
+ output = (normalized_input - mean) / std
27
+ return output
28
+
29
+
30
+ class PerceptualHashValue(nn.Module):
31
+ """Perceptual loss initialization.
32
+
33
+ Args:
34
+ cfg (Config): Configuration file.
35
+ network (str) : The name of the loss network: 'vgg16' | 'vgg19'.
36
+ layers (str or list of str) : The layers used to compute the loss.
37
+ weights (float or list of float : The loss weights of each layer.
38
+ criterion (str): The type of distance function: 'l1' | 'l2'.
39
+ resize (bool) : If ``True``, resize the input images to 224x224.
40
+ resize_mode (str): Algorithm used for resizing.
41
+ instance_normalized (bool): If ``True``, applies instance normalization
42
+ to the feature maps before computing the distance.
43
+ num_scales (int): The loss will be evaluated at original size and
44
+ this many times downsampled sizes.
45
+ """
46
+
47
+ def __init__(self, T=0.005, network='vgg19', layers='relu_4_1', resize=False, resize_mode='bilinear',
48
+ instance_normalized=False):
49
+ super().__init__()
50
+ if isinstance(layers, str):
51
+ layers = [layers]
52
+
53
+ if network == 'vgg19':
54
+ self.model = _vgg19(layers)
55
+ elif network == 'vgg16':
56
+ self.model = _vgg16(layers)
57
+ elif network == 'alexnet':
58
+ self.model = _alexnet(layers)
59
+ elif network == 'inception_v3':
60
+ self.model = _inception_v3(layers)
61
+ elif network == 'resnet50':
62
+ self.model = _resnet50(layers)
63
+ elif network == 'robust_resnet50':
64
+ self.model = _robust_resnet50(layers)
65
+ elif network == 'vgg_face_dag':
66
+ self.model = _vgg_face_dag(layers)
67
+ else:
68
+ raise ValueError('Network %s is not recognized' % network)
69
+
70
+ self.T = T
71
+ self.layers = layers
72
+ self.resize = resize
73
+ self.resize_mode = resize_mode
74
+ self.instance_normalized = instance_normalized
75
+ print('Perceptual Hash Value:')
76
+ print('\tMode: {}'.format(network))
77
+
78
+ def forward(self, inp, target):
79
+ r"""Perceptual loss forward.
80
+
81
+ Args:
82
+ inp (4D tensor) : Input tensor.
83
+ target (4D tensor) : Ground truth tensor, same shape as the input.
84
+
85
+ Returns:
86
+ (list of float) : The perceptual hash value for each requested layer.
87
+ """
88
+ # The feature network should operate in eval mode by default.
89
+ self.model.eval()
90
+ inp, target = \
91
+ apply_imagenet_normalization(inp), \
92
+ apply_imagenet_normalization(target)
93
+ if self.resize:
94
+ inp = F.interpolate(
95
+ inp, mode=self.resize_mode, size=(224, 224),
96
+ align_corners=False)
97
+ target = F.interpolate(
98
+ target, mode=self.resize_mode, size=(224, 224),
99
+ align_corners=False)
100
+
101
+ # Extract perceptual features from both images.
103
+ input_features, target_features = \
104
+ self.model(inp), self.model(target)
105
+
106
+ hpv_list = []
107
+ for layer in self.layers:
108
+ # Example per-layer VGG19 loss values after applying
109
+ # [0.03125, 0.0625, 0.125, 0.25, 1.0] weighting.
110
+ # relu_1_1, 0.014698, 0.47
111
+ # relu_2_1, 0.085817, 1.37
112
+ # relu_3_1, 0.349977, 2.8
113
+ # relu_4_1, 0.544188, 2.176
114
+ # relu_5_1, 0.906261, 0.906
115
+ input_feature = input_features[layer]
116
+ target_feature = target_features[layer].detach()
117
+ if self.instance_normalized:
118
+ input_feature = F.instance_norm(input_feature)
119
+ target_feature = F.instance_norm(target_feature)
120
+
121
+ # We are ignoring the spatial dimensions
122
+ B, C = input_feature.shape[:2]
123
+ inp_avg = torch.mean(input_feature.view(B, C, -1), -1)
124
+ tgt_avg = torch.mean(target_feature.view(B, C, -1), -1)
125
+ abs_dif = torch.abs(inp_avg - tgt_avg)
126
+ hpv = torch.sum(abs_dif > self.T).item() / (B * C)
127
+ hpv_list.append(hpv)
128
+
129
+ return hpv_list
130
+
131
+
132
+ class _PerceptualNetwork(nn.Module):
133
+ r"""The network that extracts features to compute the perceptual loss.
134
+
135
+ Args:
136
+ network (nn.Sequential) : The network that extracts features.
137
+ layer_name_mapping (dict) : The dictionary that
138
+ maps a layer's index to its name.
139
+ layers (list of str): The list of layer names that we are using.
140
+ """
141
+
142
+ def __init__(self, network, layer_name_mapping, layers):
143
+ super().__init__()
144
+ assert isinstance(network, nn.Sequential), \
145
+ 'The network needs to be of type "nn.Sequential".'
146
+ self.network = network
147
+ self.layer_name_mapping = layer_name_mapping
148
+ self.layers = layers
149
+ for param in self.parameters():
150
+ param.requires_grad = False
151
+
152
+ def forward(self, x):
153
+ r"""Extract perceptual features."""
154
+ output = {}
155
+ for i, layer in enumerate(self.network):
156
+ x = layer(x)
157
+ layer_name = self.layer_name_mapping.get(i, None)
158
+ if layer_name in self.layers:
159
+ # If the current layer is used by the perceptual loss.
160
+ output[layer_name] = x
161
+ return output
162
+
163
+
164
+ def _vgg19(layers):
165
+ r"""Get vgg19 layers"""
166
+ network = torchvision.models.vgg19(pretrained=True).features
167
+ layer_name_mapping = {1: 'relu_1_1',
168
+ 3: 'relu_1_2',
169
+ 6: 'relu_2_1',
170
+ 8: 'relu_2_2',
171
+ 11: 'relu_3_1',
172
+ 13: 'relu_3_2',
173
+ 15: 'relu_3_3',
174
+ 17: 'relu_3_4',
175
+ 20: 'relu_4_1',
176
+ 22: 'relu_4_2',
177
+ 24: 'relu_4_3',
178
+ 26: 'relu_4_4',
179
+ 29: 'relu_5_1'}
180
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
181
+
182
+
183
+ def _vgg16(layers):
184
+ r"""Get vgg16 layers"""
185
+ network = torchvision.models.vgg16(pretrained=True).features
186
+ layer_name_mapping = {1: 'relu_1_1',
187
+ 3: 'relu_1_2',
188
+ 6: 'relu_2_1',
189
+ 8: 'relu_2_2',
190
+ 11: 'relu_3_1',
191
+ 13: 'relu_3_2',
192
+ 15: 'relu_3_3',
193
+ 18: 'relu_4_1',
194
+ 20: 'relu_4_2',
195
+ 22: 'relu_4_3',
196
+ 25: 'relu_5_1'}
197
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
198
+
199
+
200
+ def _alexnet(layers):
201
+ r"""Get alexnet layers"""
202
+ network = torchvision.models.alexnet(pretrained=True).features
203
+ layer_name_mapping = {0: 'conv_1',
204
+ 1: 'relu_1',
205
+ 3: 'conv_2',
206
+ 4: 'relu_2',
207
+ 6: 'conv_3',
208
+ 7: 'relu_3',
209
+ 8: 'conv_4',
210
+ 9: 'relu_4',
211
+ 10: 'conv_5',
212
+ 11: 'relu_5'}
213
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
214
+
215
+
216
+ def _inception_v3(layers):
217
+ r"""Get inception v3 layers"""
218
+ inception = torchvision.models.inception_v3(pretrained=True)
219
+ network = nn.Sequential(inception.Conv2d_1a_3x3,
220
+ inception.Conv2d_2a_3x3,
221
+ inception.Conv2d_2b_3x3,
222
+ nn.MaxPool2d(kernel_size=3, stride=2),
223
+ inception.Conv2d_3b_1x1,
224
+ inception.Conv2d_4a_3x3,
225
+ nn.MaxPool2d(kernel_size=3, stride=2),
226
+ inception.Mixed_5b,
227
+ inception.Mixed_5c,
228
+ inception.Mixed_5d,
229
+ inception.Mixed_6a,
230
+ inception.Mixed_6b,
231
+ inception.Mixed_6c,
232
+ inception.Mixed_6d,
233
+ inception.Mixed_6e,
234
+ inception.Mixed_7a,
235
+ inception.Mixed_7b,
236
+ inception.Mixed_7c,
237
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)))
238
+ layer_name_mapping = {3: 'pool_1',
239
+ 6: 'pool_2',
240
+ 14: 'mixed_6e',
241
+ 18: 'pool_3'}
242
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
243
+
244
+
245
+ def _resnet50(layers):
246
+ r"""Get resnet50 layers"""
247
+ resnet50 = torchvision.models.resnet50(pretrained=True)
248
+ network = nn.Sequential(resnet50.conv1,
249
+ resnet50.bn1,
250
+ resnet50.relu,
251
+ resnet50.maxpool,
252
+ resnet50.layer1,
253
+ resnet50.layer2,
254
+ resnet50.layer3,
255
+ resnet50.layer4,
256
+ resnet50.avgpool)
257
+ layer_name_mapping = {4: 'layer_1',
258
+ 5: 'layer_2',
259
+ 6: 'layer_3',
260
+ 7: 'layer_4'}
261
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
262
+
263
+
264
+ def _robust_resnet50(layers):
265
+ r"""Get robust resnet50 layers"""
266
+ resnet50 = torchvision.models.resnet50(pretrained=False)
267
+ state_dict = torch.utils.model_zoo.load_url(
268
+ 'http://andrewilyas.com/ImageNet.pt')
269
+ new_state_dict = {}
270
+ for k, v in state_dict['model'].items():
271
+ if k.startswith('module.model.'):
272
+ new_state_dict[k[13:]] = v
273
+ resnet50.load_state_dict(new_state_dict)
274
+ network = nn.Sequential(resnet50.conv1,
275
+ resnet50.bn1,
276
+ resnet50.relu,
277
+ resnet50.maxpool,
278
+ resnet50.layer1,
279
+ resnet50.layer2,
280
+ resnet50.layer3,
281
+ resnet50.layer4,
282
+ resnet50.avgpool)
283
+ layer_name_mapping = {4: 'layer_1',
284
+ 5: 'layer_2',
285
+ 6: 'layer_3',
286
+ 7: 'layer_4'}
287
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
288
+
289
+
290
+ def _vgg_face_dag(layers):
291
+ r"""Get vgg face layers"""
292
+ network = torchvision.models.vgg16(num_classes=2622)
293
+ state_dict = torch.utils.model_zoo.load_url(
294
+ 'http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/'
295
+ 'vgg_face_dag.pth')
296
+ feature_layer_name_mapping = {
297
+ 0: 'conv1_1',
298
+ 2: 'conv1_2',
299
+ 5: 'conv2_1',
300
+ 7: 'conv2_2',
301
+ 10: 'conv3_1',
302
+ 12: 'conv3_2',
303
+ 14: 'conv3_3',
304
+ 17: 'conv4_1',
305
+ 19: 'conv4_2',
306
+ 21: 'conv4_3',
307
+ 24: 'conv5_1',
308
+ 26: 'conv5_2',
309
+ 28: 'conv5_3'}
310
+ new_state_dict = {}
311
+ for k, v in feature_layer_name_mapping.items():
312
+ new_state_dict['features.' + str(k) + '.weight'] =\
313
+ state_dict[v + '.weight']
314
+ new_state_dict['features.' + str(k) + '.bias'] = \
315
+ state_dict[v + '.bias']
316
+
317
+ classifier_layer_name_mapping = {
318
+ 0: 'fc6',
319
+ 3: 'fc7',
320
+ 6: 'fc8'}
321
+ for k, v in classifier_layer_name_mapping.items():
322
+ new_state_dict['classifier.' + str(k) + '.weight'] = \
323
+ state_dict[v + '.weight']
324
+ new_state_dict['classifier.' + str(k) + '.bias'] = \
325
+ state_dict[v + '.bias']
326
+
327
+ network.load_state_dict(new_state_dict)
328
+
329
+ class Flatten(nn.Module):
330
+ r"""Flatten the tensor"""
331
+
332
+ def forward(self, x):
333
+ r"""Flatten it"""
334
+ return x.view(x.shape[0], -1)
335
+
336
+ layer_name_mapping = {
337
+ 1: 'avgpool',
338
+ 3: 'fc6',
339
+ 4: 'relu_6',
340
+ 6: 'fc7',
341
+ 7: 'relu_7',
342
+ 9: 'fc8'}
343
+ seq_layers = [network.features, network.avgpool, Flatten()]
344
+ for i in range(7):
345
+ seq_layers += [network.classifier[i]]
346
+ network = nn.Sequential(*seq_layers)
347
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
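As a quick illustration of the class above, the following minimal sketch computes the Perceptual Hash Value between two image batches in the [-1, 1] range; the random tensors are stand-ins for real H&E/IHC patches, not data from this repository.

    # Hedged sketch: fraction of channel-averaged VGG19 features whose difference exceeds T, per layer.
    import torch
    from asp.util.perceptual import PerceptualHashValue

    phv = PerceptualHashValue(T=0.005, network='vgg19',
                              layers=['relu_3_1', 'relu_4_1', 'relu_5_1'])
    fake = torch.rand(1, 3, 256, 256) * 2 - 1  # stand-in for a generated IHC patch, range [-1, 1]
    real = torch.rand(1, 3, 256, 256) * 2 - 1  # stand-in for a ground-truth IHC patch
    hpv_per_layer = phv(fake, real)            # one value in [0, 1] per requested layer
    print(hpv_per_layer)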
asp/util/util.py ADDED
@@ -0,0 +1,220 @@
1
+ """This module contains simple helper functions """
2
+ from __future__ import print_function
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import os
7
+ import importlib
8
+ import argparse
9
+ from argparse import Namespace
10
+ import torchvision
11
+ import cv2 as cv
12
+
13
+
14
+ def str2bool(v):
15
+ if isinstance(v, bool):
16
+ return v
17
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
18
+ return True
19
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
20
+ return False
21
+ else:
22
+ raise argparse.ArgumentTypeError('Boolean value expected.')
23
+
24
+
25
+ def copyconf(default_opt, **kwargs):
26
+ conf = Namespace(**vars(default_opt))
27
+ for key in kwargs:
28
+ setattr(conf, key, kwargs[key])
29
+ return conf
30
+
31
+
32
+ def find_class_in_module(target_cls_name, module):
33
+ target_cls_name = target_cls_name.replace('_', '').lower()
34
+ clslib = importlib.import_module(module)
35
+ cls = None
36
+ for name, clsobj in clslib.__dict__.items():
37
+ if name.lower() == target_cls_name:
38
+ cls = clsobj
39
+
40
+ assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
41
+
42
+ return cls
43
+
44
+
45
+ def tensor2im(input_image, imtype=np.uint8):
46
+ """"Converts a Tensor array into a numpy image array.
47
+
48
+ Parameters:
49
+ input_image (tensor) -- the input image tensor array
50
+ imtype (type) -- the desired type of the converted numpy array
51
+ """
52
+ if not isinstance(input_image, np.ndarray):
53
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
54
+ image_tensor = input_image.data
55
+ else:
56
+ return input_image
57
+ image_numpy = image_tensor[0].clamp(-1.0, 1.0).cpu().float().numpy() # convert it into a numpy array
58
+ if image_numpy.shape[0] == 1: # grayscale to RGB
59
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
60
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
61
+ else: # if it is a numpy array, do nothing
62
+ image_numpy = input_image
63
+ return image_numpy.astype(imtype)
64
+
65
+
66
+ def diagnose_network(net, name='network'):
67
+ """Calculate and print the mean of average absolute(gradients)
68
+
69
+ Parameters:
70
+ net (torch network) -- Torch network
71
+ name (str) -- the name of the network
72
+ """
73
+ mean = 0.0
74
+ count = 0
75
+ for param in net.parameters():
76
+ if param.grad is not None:
77
+ mean += torch.mean(torch.abs(param.grad.data))
78
+ count += 1
79
+ if count > 0:
80
+ mean = mean / count
81
+ print(name)
82
+ print(mean)
83
+
84
+
85
+ def save_image(image_numpy, image_path, aspect_ratio=1.0):
86
+ """Save a numpy image to the disk
87
+
88
+ Parameters:
89
+ image_numpy (numpy array) -- input numpy array
90
+ image_path (str) -- the path of the image
91
+ """
92
+
93
+ image_pil = Image.fromarray(image_numpy)
94
+ h, w, _ = image_numpy.shape
95
+
96
+ if aspect_ratio is None:
97
+ pass
98
+ elif aspect_ratio > 1.0:
99
+ image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
100
+ elif aspect_ratio < 1.0:
101
+ image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
102
+ image_pil.save(image_path)
103
+
104
+
105
+ def print_numpy(x, val=True, shp=False):
106
+ """Print the mean, min, max, median, std, and size of a numpy array
107
+
108
+ Parameters:
109
+ val (bool) -- whether to print the values of the numpy array
110
+ shp (bool) -- whether to print the shape of the numpy array
111
+ """
112
+ x = x.astype(np.float64)
113
+ if shp:
114
+ print('shape,', x.shape)
115
+ if val:
116
+ x = x.flatten()
117
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
118
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
119
+
120
+
121
+ def mkdirs(paths):
122
+ """create empty directories if they don't exist
123
+
124
+ Parameters:
125
+ paths (str list) -- a list of directory paths
126
+ """
127
+ if isinstance(paths, list) and not isinstance(paths, str):
128
+ for path in paths:
129
+ mkdir(path)
130
+ else:
131
+ mkdir(paths)
132
+
133
+
134
+ def mkdir(path):
135
+ """create a single empty directory if it didn't exist
136
+
137
+ Parameters:
138
+ path (str) -- a single directory path
139
+ """
140
+ if not os.path.exists(path):
141
+ os.makedirs(path)
142
+
143
+
144
+ def correct_resize_label(t, size):
145
+ device = t.device
146
+ t = t.detach().cpu()
147
+ resized = []
148
+ for i in range(t.size(0)):
149
+ one_t = t[i, :1]
150
+ one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
151
+ one_np = one_np[:, :, 0]
152
+ one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
153
+ resized_t = torch.from_numpy(np.array(one_image)).long()
154
+ resized.append(resized_t)
155
+ return torch.stack(resized, dim=0).to(device)
156
+
157
+
158
+ def correct_resize(t, size, mode=Image.BICUBIC):
159
+ device = t.device
160
+ t = t.detach().cpu()
161
+ resized = []
162
+ for i in range(t.size(0)):
163
+ one_t = t[i:i + 1]
164
+ one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
165
+ resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
166
+ resized.append(resized_t)
167
+ return torch.stack(resized, dim=0).to(device)
168
+
169
+
170
+ def expand_as_one_hot(input, C, ignore_index=None):
171
+ """
172
+ Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
173
+ It is assumed that the batch dimension is present.
174
+ Args:
175
+ input (torch.Tensor): 3D/4D input image
176
+ C (int): number of channels/labels
177
+ ignore_index (int): ignore index to be kept during the expansion
178
+ Returns:
179
+ 4D/5D output torch.Tensor (NxCxSPATIAL)
180
+ """
181
+ # expand the input tensor to Nx1xSPATIAL before scattering
182
+ input = input.unsqueeze(1)
183
+ # create output tensor shape (NxCxSPATIAL)
184
+ shape = list(input.size())
185
+ shape[1] = C
186
+
187
+ if ignore_index is not None:
188
+ # create ignore_index mask for the result
189
+ mask = input.expand(shape) == ignore_index
190
+ # clone the src tensor and zero out ignore_index in the input
191
+ input = input.clone()
192
+ input[input == ignore_index] = 0
193
+ # scatter to get the one-hot tensor
194
+ result = torch.zeros(shape).to(input.device).scatter_(1, input, 1)
195
+ # bring back the ignore_index in the result
196
+ result[mask] = ignore_index
197
+ return result
198
+ else:
199
+ # scatter to get the one-hot tensor
200
+ return torch.zeros(shape).to(input.device).scatter_(1, input, 1)
201
+
202
+ def standardize(ref, I, threshold=50):
203
+ """
204
+ Transform image I to standard brightness.
205
+ Modifies the luminosity channel such that a fixed percentile is saturated.
206
+
207
+ :param I: Image uint8 RGB.
208
+ :param percentile: Percentile for luminosity saturation. At least (100 - percentile)% of pixels should be fully luminous (white).
209
+ :return: Image uint8 RGB with standardized brightness.
210
+ """
211
+ ref_m = cv.cvtColor(ref, cv.COLOR_RGB2LAB)[:, :, 0].astype(float).mean()
212
+
213
+ I_LAB = cv.cvtColor(I, cv.COLOR_RGB2LAB)
214
+ L_float = I_LAB[:, :, 0].astype(float)
215
+ tgt_m = L_float.mean()
216
+ if np.abs(tgt_m - ref_m) > threshold:
217
+ L_float = L_float - tgt_m + ref_m
218
+ I_LAB[:, :, 0] = np.clip(L_float, 0, 255).astype(np.uint8)
219
+ I = cv.cvtColor(I_LAB, cv.COLOR_LAB2RGB)
220
+ return I
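To illustrate the brightness standardization helper above, here is a minimal sketch; the file names are placeholders, not files shipped with this repository.

    # Hedged sketch: shift the LAB L-channel mean of an image toward a reference image.
    import cv2 as cv
    from asp.util.util import standardize

    ref = cv.cvtColor(cv.imread('reference_patch.png'), cv.COLOR_BGR2RGB)  # reference uint8 RGB image
    img = cv.cvtColor(cv.imread('input_patch.png'), cv.COLOR_BGR2RGB)      # image to standardize
    out = standardize(ref, img, threshold=50)  # only corrected if the mean luminosities differ by more than 50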
asp/util/visualizer.py ADDED
@@ -0,0 +1,242 @@
1
+ import numpy as np
2
+ import os
3
+ import sys
4
+ import ntpath
5
+ import time
6
+ from . import util, html
7
+ from subprocess import Popen, PIPE
8
+
9
+ if sys.version_info[0] == 2:
10
+ VisdomExceptionBase = Exception
11
+ else:
12
+ VisdomExceptionBase = ConnectionError
13
+
14
+
15
+ def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
16
+ """Save images to the disk.
17
+
18
+ Parameters:
19
+ webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
20
+ visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
21
+ image_path (str) -- the string is used to create image paths
22
+ aspect_ratio (float) -- the aspect ratio of saved images
23
+ width (int) -- the images will be resized to width x width
24
+
25
+ This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
26
+ """
27
+ image_dir = webpage.get_image_dir()
28
+ short_path = ntpath.basename(image_path[0])
29
+ name = os.path.splitext(short_path)[0]
30
+
31
+ webpage.add_header(name)
32
+ ims, txts, links = [], [], []
33
+
34
+ for label, im_data in visuals.items():
35
+ im = util.tensor2im(im_data)
36
+ image_name = '%s/%s.png' % (label, name)
37
+ os.makedirs(os.path.join(image_dir, label), exist_ok=True)
38
+ save_path = os.path.join(image_dir, image_name)
39
+ util.save_image(im, save_path, aspect_ratio=aspect_ratio)
40
+ ims.append(image_name)
41
+ txts.append(label)
42
+ links.append(image_name)
43
+ webpage.add_images(ims, txts, links, width=width)
44
+
45
+
46
+ class Visualizer():
47
+ """This class includes several functions that can display/save images and print/save logging information.
48
+
49
+ It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
50
+ """
51
+
52
+ def __init__(self, opt):
53
+ """Initialize the Visualizer class
54
+
55
+ Parameters:
56
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
57
+ Step 1: Cache the training/test options
58
+ Step 2: connect to a visdom server
59
+ Step 3: create an HTML object for saving HTML files
60
+ Step 4: create a logging file to store training losses
61
+ """
62
+ self.opt = opt # cache the option
63
+ if opt.display_id is None:
64
+ self.display_id = np.random.randint(100000) * 10 # just a random display id
65
+ else:
66
+ self.display_id = opt.display_id
67
+ self.use_html = opt.isTrain and not opt.no_html
68
+ self.win_size = opt.display_winsize
69
+ self.name = opt.name
70
+ self.port = opt.display_port
71
+ self.saved = False
72
+ if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
73
+ import visdom
74
+ self.plot_data = {}
75
+ self.ncols = opt.display_ncols
76
+ if "tensorboard_base_url" not in os.environ:
77
+ self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
78
+ else:
79
+ self.vis = visdom.Visdom(port=2004,
80
+ base_url=os.environ['tensorboard_base_url'] + '/visdom')
81
+ if not self.vis.check_connection():
82
+ self.create_visdom_connections()
83
+
84
+ if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
85
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
86
+ self.img_dir = os.path.join(self.web_dir, 'images')
87
+ print('create web directory %s...' % self.web_dir)
88
+ util.mkdirs([self.web_dir, self.img_dir])
89
+ # create a logging file to store training losses
90
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
91
+ with open(self.log_name, "a") as log_file:
92
+ now = time.strftime("%c")
93
+ log_file.write('================ Training Loss (%s) ================\n' % now)
94
+
95
+ def reset(self):
96
+ """Reset the self.saved status"""
97
+ self.saved = False
98
+
99
+ def create_visdom_connections(self):
100
+ """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
101
+ cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
102
+ print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
103
+ print('Command: %s' % cmd)
104
+ Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
105
+
106
+ def display_current_results(self, visuals, epoch, save_result):
107
+ """Display current results on visdom; save current results to an HTML file.
108
+
109
+ Parameters:
110
+ visuals (OrderedDict) - - dictionary of images to display or save
111
+ epoch (int) - - the current epoch
112
+ save_result (bool) - - whether to save the current results to an HTML file
113
+ """
114
+ if self.display_id > 0: # show images in the browser using visdom
115
+ ncols = self.ncols
116
+ if ncols > 0: # show all the images in one visdom panel
117
+ ncols = min(ncols, len(visuals))
118
+ h, w = next(iter(visuals.values())).shape[:2]
119
+ table_css = """<style>
120
+ table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
121
+ table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
122
+ </style>""" % (w, h) # create a table css
123
+ # create a table of images.
124
+ title = self.name
125
+ label_html = ''
126
+ label_html_row = ''
127
+ images = []
128
+ idx = 0
129
+ for label, image in visuals.items():
130
+ image_numpy = util.tensor2im(image)
131
+ label_html_row += '<td>%s</td>' % label
132
+ images.append(image_numpy.transpose([2, 0, 1]))
133
+ idx += 1
134
+ if idx % ncols == 0:
135
+ label_html += '<tr>%s</tr>' % label_html_row
136
+ label_html_row = ''
137
+ white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
138
+ while idx % ncols != 0:
139
+ images.append(white_image)
140
+ label_html_row += '<td></td>'
141
+ idx += 1
142
+ if label_html_row != '':
143
+ label_html += '<tr>%s</tr>' % label_html_row
144
+ try:
145
+ self.vis.images(images, ncols, 2, self.display_id + 1,
146
+ None, dict(title=title + ' images'))
147
+ label_html = '<table>%s</table>' % label_html
148
+ self.vis.text(table_css + label_html, win=self.display_id + 2,
149
+ opts=dict(title=title + ' labels'))
150
+ except VisdomExceptionBase:
151
+ self.create_visdom_connections()
152
+
153
+ else: # show each image in a separate visdom panel;
154
+ idx = 1
155
+ try:
156
+ for label, image in visuals.items():
157
+ image_numpy = util.tensor2im(image)
158
+ self.vis.image(
159
+ image_numpy.transpose([2, 0, 1]),
160
+ self.display_id + idx,
161
+ None,
162
+ dict(title=label)
163
+ )
164
+ idx += 1
165
+ except VisdomExceptionBase:
166
+ self.create_visdom_connections()
167
+
168
+ if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
169
+ self.saved = True
170
+ # save images to the disk
171
+ for label, image in visuals.items():
172
+ image_numpy = util.tensor2im(image)
173
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
174
+ util.save_image(image_numpy, img_path)
175
+
176
+ # update website
177
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
178
+ for n in range(epoch, 0, -1):
179
+ webpage.add_header('epoch [%d]' % n)
180
+ ims, txts, links = [], [], []
181
+
182
+ for label, image_numpy in visuals.items():
183
+ image_numpy = util.tensor2im(image_numpy)
184
+ img_path = 'epoch%.3d_%s.png' % (n, label)
185
+ ims.append(img_path)
186
+ txts.append(label)
187
+ links.append(img_path)
188
+ webpage.add_images(ims, txts, links, width=self.win_size)
189
+ webpage.save()
190
+
191
+ def plot_current_losses(self, epoch, counter_ratio, losses):
192
+ """display the current losses on visdom display: dictionary of error labels and values
193
+
194
+ Parameters:
195
+ epoch (int) -- current epoch
196
+ counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
197
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
198
+ """
199
+ if len(losses) == 0:
200
+ return
201
+
202
+ plot_name = '_'.join(list(losses.keys()))
203
+
204
+ if plot_name not in self.plot_data:
205
+ self.plot_data[plot_name] = {'X': [], 'Y': [], 'legend': list(losses.keys())}
206
+
207
+ plot_data = self.plot_data[plot_name]
208
+ plot_id = list(self.plot_data.keys()).index(plot_name)
209
+
210
+ plot_data['X'].append(epoch + counter_ratio)
211
+ plot_data['Y'].append([losses[k] for k in plot_data['legend']])
212
+ try:
213
+ self.vis.line(
214
+ X=np.stack([np.array(plot_data['X'])] * len(plot_data['legend']), 1),
215
+ Y=np.array(plot_data['Y']),
216
+ opts={
217
+ 'title': self.name,
218
+ 'legend': plot_data['legend'],
219
+ 'xlabel': 'epoch',
220
+ 'ylabel': 'loss'},
221
+ win=self.display_id - plot_id)
222
+ except VisdomExceptionBase:
223
+ self.create_visdom_connections()
224
+
225
+ # losses: same format as |losses| of plot_current_losses
226
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
227
+ """print current losses on console; also save the losses to the disk
228
+
229
+ Parameters:
230
+ epoch (int) -- current epoch
231
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
232
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
233
+ t_comp (float) -- computational time per data point (normalized by batch_size)
234
+ t_data (float) -- data loading time per data point (normalized by batch_size)
235
+ """
236
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
237
+ for k, v in losses.items():
238
+ message += '%s: %.3f ' % (k, v)
239
+
240
+ print(message) # print the message
241
+ with open(self.log_name, "a") as log_file:
242
+ log_file.write('%s\n' % message) # save the message
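For context, here is a minimal, hedged sketch of driving this Visualizer from a training loop with visdom and HTML output disabled; the option values are illustrative stand-ins, not the repository's defaults.

    # Hedged sketch: log losses to the console and to <checkpoints_dir>/<name>/loss_log.txt.
    import os
    from collections import OrderedDict
    from types import SimpleNamespace
    from asp.util.visualizer import Visualizer

    opt = SimpleNamespace(display_id=0, isTrain=True, no_html=True,      # display_id=0 disables visdom
                          display_winsize=256, name='demo_experiment',
                          display_port=8097, display_ncols=4,
                          display_server='http://localhost', display_env='main',
                          checkpoints_dir='./checkpoints')
    os.makedirs(os.path.join(opt.checkpoints_dir, opt.name), exist_ok=True)  # loss_log.txt lives here
    visualizer = Visualizer(opt)
    losses = OrderedDict([('G_GAN', 0.73), ('D_real', 0.41)])               # illustrative values
    visualizer.print_current_losses(epoch=1, iters=100, losses=losses, t_comp=0.05, t_data=0.01)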
main.py ADDED
@@ -0,0 +1,90 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+
4
+ import random
5
+ import torch
6
+ from torchvision.transforms.functional import to_pil_image
7
+
8
+ from types import SimpleNamespace
9
+ from PIL import Image
10
+
11
+ from asp.models.cpt_model import CPTModel
12
+ from asp.data.base_dataset import get_transform
13
+ from asp.util.general_utils import parse_args
14
+
15
+ def transform_with_seed(input_img, transform, seed=123456):
16
+ random.seed(seed)
17
+ torch.manual_seed(seed)
18
+ return transform(input_img)
19
+
20
+ def convert_he2ihc(input_he_image_path):
21
+ input_img = Image.open(input_he_image_path).convert('RGB')
22
+
23
+ opt = SimpleNamespace(
24
+ gpu_ids=[0],
25
+ isTrain=False,
26
+ checkpoints_dir="../../checkpoints",
27
+ # name="ASP_pretrained/BCI_her2_lambda_linear",
28
+ name="ASP_pretrained/BCI_her2_zero_uniform",
29
+ preprocess="crop",
30
+ nce_layers="0,4,8,12,16",
31
+ nce_idt=False,
32
+ input_nc=3,
33
+ output_nc=3,
34
+ ngf=64,
35
+ netG="resnet_6blocks",
36
+ normG="instance",
37
+ no_dropout=True,
38
+ init_type="xavier",
39
+ init_gain=0.02,
40
+ no_antialias=False,
41
+ no_antialias_up=False,
42
+ weight_norm="spectral",
43
+ netF="mlp_sample",
44
+ netF_nc=256,
45
+ no_flip=True,
46
+ load_size=1024,
47
+ crop_size=1024,
48
+ direction="AtoB",
49
+ flip_equivariance=False,
50
+ epoch="latest",
51
+ verbose=True
52
+ )
53
+ model = CPTModel(opt)
54
+
55
+ transform = get_transform(opt)
56
+
57
+ model.setup(opt)
58
+ model.parallelize()
59
+ model.eval()
60
+
61
+ A = transform_with_seed(input_img, transform)
62
+ model.set_input({
63
+ "A": A.unsqueeze(0),
64
+ "A_paths": input_he_image_path,
65
+ "B": A.unsqueeze(0),
66
+ "B_paths": input_he_image_path,
67
+ })
68
+ model.test()
69
+ visuals = model.get_current_visuals()
70
+
71
+ output_img = to_pil_image(visuals['fake_B'].detach().cpu().squeeze(0))
72
+ print("np.shape(output_img)", np.shape(output_img))
73
+
74
+ return output_img
75
+
76
+ def main():
77
+ demo = gr.Interface(
78
+ fn=convert_he2ihc,
79
+ inputs=gr.Image(type="filepath"),
80
+ outputs=gr.Image(),
81
+ title="H&E to IHC, BIC HER2"
82
+ )
83
+
84
+ demo.launch()
85
+
86
+ if __name__ == "__main__":
87
+ args = parse_args(main)
88
+ main(**vars(args))
89
+
90
+ # python main.py -i ../../data/BCI_dataset/BCI_dataset/HE/test/00003_test_3+.png
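The conversion function can also be called directly, without the Gradio interface. A minimal sketch, assuming a GPU and the pretrained checkpoint referenced by checkpoints_dir are available; the output filename is illustrative.

    # Hedged sketch: run a single H&E patch through the pretrained generator.
    out = convert_he2ihc('../../data/BCI_dataset/BCI_dataset/HE/test/00003_test_3+.png')
    out.save('00003_test_3+_fake_IHC.png')  # output name chosen for illustration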
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ dominate
2
+ packaging
3
+ opencv-python
4
+ GPUtil
5
+
6
+ gradio