Spaces:
Build error
Build error
File size: 5,222 Bytes
6d314be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
from .base_dataset import BaseDataset, get_params, get_transform
from PIL import Image
import util.util as util
import os
import torch
class Pix2pixDataset(BaseDataset):
    """Dataset pairing label maps with real images (and optional instance maps).

    Subclasses supply the file lists by overriding ``get_paths``; this base
    class sorts, truncates and sanity-checks them, then loads and transforms
    one sample per ``__getitem__`` call.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register dataset-specific CLI options on ``parser`` and return it."""
        parser.add_argument('--no_pairing_check', action='store_true',
                            help='If specified, skip sanity check of correct label-image file pairing')
        return parser

    def initialize(self, opt):
        """Collect, sort, truncate and (optionally) pair-check the file lists.

        Args:
            opt: Parsed options. Reads ``no_instance``, ``max_dataset_size``
                and ``no_pairing_check``.
        """
        self.opt = opt

        label_paths, image_paths, instance_paths = self.get_paths(opt)

        # Natural sort so that e.g. "img2" precedes "img10" in every list and
        # the lists line up index-by-index.
        util.natural_sort(label_paths)
        util.natural_sort(image_paths)
        if not opt.no_instance:
            util.natural_sort(instance_paths)

        label_paths = label_paths[:opt.max_dataset_size]
        image_paths = image_paths[:opt.max_dataset_size]
        instance_paths = instance_paths[:opt.max_dataset_size]

        if not opt.no_pairing_check:
            for path1, path2 in zip(label_paths, image_paths):
                assert self.paths_match(path1, path2), \
                    "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." % (
                        path1, path2)

        self.label_paths = label_paths
        self.image_paths = image_paths
        self.instance_paths = instance_paths
        self.dataset_size = len(self.label_paths)

    def get_paths(self, opt):
        """Return ``(label_paths, image_paths, instance_paths)`` for ``opt``.

        Abstract: every concrete dataset must override this.

        Raises:
            NotImplementedError: always, in this base class. An explicit
                raise (instead of ``assert False``) keeps the check active
                under ``python -O``.
        """
        raise NotImplementedError(
            "A subclass of Pix2pixDataset must override self.get_paths(self, opt)")

    def paths_match(self, path1, path2):
        """Return True when both paths have the same base filename sans extension."""
        filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
        filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
        return filename1_without_ext == filename2_without_ext

    @staticmethod
    def _transform_instance(instance, transform_label):
        """Turn an instance-map PIL image into a tensor via ``transform_label``.

        Single-channel ('L') maps are multiplied by 255 and cast to long,
        presumably undoing a [0, 1] scaling applied by the transform. Any
        other mode is returned exactly as ``transform_label`` produces it.
        """
        if instance.mode == 'L':
            return (transform_label(instance) * 255).long()
        return transform_label(instance)

    def __getitem__(self, index):
        """Load and transform sample ``index``.

        Returns:
            dict with keys ``'label'``, ``'instance'`` (the int ``0`` when
            ``opt.no_instance`` is set), ``'image'`` and ``'path'`` (the
            real-image path).
        """
        # Label map
        label_path = self.label_paths[index]
        label = Image.open(label_path)
        params = get_params(self.opt, label.size)
        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
        label_tensor = transform_label(label) * 255.0
        label_tensor[label_tensor == 255] = self.opt.label_nc  # 'unknown' is opt.label_nc

        # Real image
        image_path = self.image_paths[index]
        # Honour --no_pairing_check here too; otherwise datasets whose
        # label/image names legitimately differ fail on every sample.
        if not self.opt.no_pairing_check:
            assert self.paths_match(label_path, image_path), \
                "The label_path %s and image_path %s don't match." % \
                (label_path, image_path)
        image = Image.open(image_path)
        image = image.convert('RGB')
        # Reuse the label's params so both transforms are configured identically.
        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)

        # Optional instance map
        if self.opt.no_instance:
            instance_tensor = 0
        else:
            instance_path = self.instance_paths[index]
            instance = Image.open(instance_path)
            instance_tensor = self._transform_instance(instance, transform_label)

        input_dict = {'label': label_tensor,
                      'instance': instance_tensor,
                      'image': image_tensor,
                      'path': image_path,
                      }

        # Give subclasses a chance to modify the final output dict in place.
        self.postprocess(input_dict)

        return input_dict

    def postprocess(self, input_dict):
        """Hook for subclasses to adjust a sample dict; the base version is a no-op.

        Callers in this class ignore the return value, so overrides must
        modify ``input_dict`` in place.
        """
        return input_dict

    def __len__(self):
        """Return the number of samples after ``max_dataset_size`` truncation."""
        return self.dataset_size

    def get_input_by_names(self, image_path, image, label_img, instance_img=None):
        """Build a one-sample batch from in-memory inputs instead of files.

        Every returned tensor has a leading batch dimension of size 1.

        Args:
            image_path: Stored under ``'path'`` in the result; not opened here.
            image: Real image fed to the image transform. Presumably an
                RGB ``PIL.Image`` (``__getitem__`` converts to RGB before
                the same transform). TODO confirm with callers.
            label_img: Label map accepted by ``Image.fromarray``.
                Presumably a 2-D uint8 array; verify against callers.
            instance_img: Optional instance map accepted by
                ``Image.fromarray``. Required when ``opt.no_instance`` is
                False; ignored otherwise.

        Returns:
            dict with keys ``'label'``, ``'instance'``, ``'image'``, ``'path'``.

        Raises:
            ValueError: if instance maps are enabled but ``instance_img`` is None.
        """
        label = Image.fromarray(label_img)
        params = get_params(self.opt, label.size)
        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
        label_tensor = transform_label(label) * 255.0
        label_tensor[label_tensor == 255] = self.opt.label_nc  # 'unknown' is opt.label_nc
        label_tensor.unsqueeze_(0)

        # Real image, transformed with the same params as the label.
        transform_image = get_transform(self.opt, params)
        image_tensor = transform_image(image)
        image_tensor.unsqueeze_(0)

        # Optional instance map. Previously this branch was missing, so
        # instance_tensor was unbound whenever opt.no_instance was False.
        if self.opt.no_instance:
            instance_tensor = torch.Tensor([0])
        else:
            if instance_img is None:
                raise ValueError(
                    "instance_img is required when opt.no_instance is False")
            instance = Image.fromarray(instance_img)
            instance_tensor = self._transform_instance(instance, transform_label)
            instance_tensor.unsqueeze_(0)

        input_dict = {'label': label_tensor,
                      'instance': instance_tensor,
                      'image': image_tensor,
                      'path': image_path,
                      }

        # Give subclasses a chance to modify the final output dict in place.
        self.postprocess(input_dict)

        return input_dict
|