import copy
import os
import random

import numpy as np
import numpy.ma as ma
import scipy.io as scio
import scipy.misc
import torch
import torch.utils.data as data
import torchvision.datasets as dset
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageFilter
from torchvision import transforms


class SegDataset(data.Dataset):
    """Randomly sampled segmentation dataset of RGB images and label maps.

    Every entry of the list file is a path prefix relative to ``root_dir``;
    the files ``<root>/<entry>-color.png`` and ``<root>/<entry>-label.png``
    are read for a sample. Entries starting with ``data/`` are additionally
    recorded as "real" frames and serve as backgrounds for entries starting
    with ``data_syn``.

    Samples are drawn at random on every access: the index passed to
    ``__getitem__`` is ignored and ``len(dataset)`` is the caller-chosen
    ``length`` rather than the number of entries in the list file.
    """

    def __init__(self, root_dir, txtlist, use_noise, length):
        """Read the sample list and set up the colour transforms.

        Args:
            root_dir: Directory that the list entries are relative to.
            txtlist: Path of a text file with one sample prefix per line.
            use_noise: When true, apply colour jitter and random flips.
            length: Value reported by ``__len__`` (samples per epoch).
        """
        self.use_noise = use_noise
        self.root = root_dir
        self.length = length

        self.path = []
        self.real_path = []
        # ``with`` guarantees the handle is released even if reading fails.
        with open(txtlist) as list_file:
            for raw_line in list_file:
                entry = raw_line.rstrip('\n')
                if not entry:
                    # A blank line would turn into the bogus file name
                    # "<root>/-label.png" once it gets sampled.
                    continue
                self.path.append(entry)
                if entry.startswith('data/'):
                    self.real_path.append(entry)

        self.data_len = len(self.path)
        self.back_len = len(self.real_path)
        self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # Not read anywhere in this class; kept as a public attribute in case
        # external code relies on it.
        self.back_front = np.ones((480, 640), dtype=int)

    def __getitem__(self, idx):
        """Return one randomly chosen ``(rgb, target)`` training pair.

        Args:
            idx: Ignored; the sample is picked at random.

        Returns:
            A tuple ``(rgb, target)``: ``rgb`` is a float32 tensor in
            channel-first layout after ``self.norm``; ``target`` is the label
            map as an int64 tensor.

        Raises:
            ValueError: If the sample list is empty, or a synthetic sample is
                drawn while no real (``data/``) entries are available.
        """
        entry = self.path[self._random_index(self.data_len, 'sample')]
        label = self._load_label(entry)
        # Decode the colour image once; each branch below augments it.
        color = self._load_color(entry)

        if entry.startswith('data_syn'):
            # Synthetic frames are always augmented and composited,
            # independently of ``use_noise``.
            rgb, label = self._composite_on_real_background(color, label)
        elif self.use_noise:
            rgb = np.array(self.trancolor(color))
        else:
            rgb = np.array(color)

        if self.use_noise:
            rgb, label = self._random_flip(rgb, label)

        # HWC -> CHW. ``astype`` copies, which also turns the flipped views
        # into arrays that ``torch.from_numpy`` accepts.
        rgb = np.transpose(rgb, (2, 0, 1))
        # NOTE(review): pixel values are still on the 0-255 scale here while
        # the mean/std look like [0, 1]-range statistics - confirm intended.
        rgb = self.norm(torch.from_numpy(rgb.astype(np.float32)))
        target = torch.from_numpy(label.astype(np.int64))

        return rgb, target

    def __len__(self):
        """Return the configured number of samples per epoch."""
        return self.length

    @staticmethod
    def _random_index(count, what):
        """Return a uniform random index in ``[0, count)``."""
        if count <= 0:
            raise ValueError('no {0} entries available to sample from'.format(what))
        return random.randrange(count)

    def _load_color(self, entry):
        """Open ``<root>/<entry>-color.png`` and return it as an RGB image."""
        with Image.open('{0}/{1}-color.png'.format(self.root, entry)) as img:
            return img.convert("RGB")

    def _load_label(self, entry):
        """Read ``<root>/<entry>-label.png`` into a numpy array."""
        with Image.open('{0}/{1}-label.png'.format(self.root, entry)) as img:
            return np.array(img)

    def _composite_on_real_background(self, color, label):
        """Augment a synthetic frame and fill its unlabeled pixels from a real one.

        Args:
            color: RGB image of the synthetic frame.
            label: Label array of the synthetic frame; 0 marks pixels that
                receive the background.

        Returns:
            ``(rgb, label)`` with ``rgb`` as a float HWC array and ``label``
            merged with the background frame's labels.
        """
        color = ImageEnhance.Brightness(color).enhance(1.5).filter(ImageFilter.GaussianBlur(radius=0.8))
        rgb = np.array(self.trancolor(color))

        # The index is drawn over the real entries, so it must index
        # ``real_path`` rather than the full sample list.
        back_entry = self.real_path[self._random_index(self.back_len, 'real background')]
        back = np.array(self.trancolor(self._load_color(back_entry)))
        back_label = self._load_label(back_entry)

        mask = label == 0
        rgb = rgb + np.random.normal(loc=0.0, scale=5.0, size=rgb.shape)
        # Broadcast the 2-D mask over the channel axis of the HWC image.
        rgb = back * mask[..., np.newaxis] + rgb
        label = back_label * mask + label
        return rgb, label

    @staticmethod
    def _random_flip(rgb, label):
        """Apply the same random flip (horizontal, vertical, both or none) to both arrays."""
        choice = random.randint(0, 3)
        if choice in (0, 2):
            rgb, label = np.fliplr(rgb), np.fliplr(label)
        if choice in (1, 2):
            rgb, label = np.flipud(rgb), np.flipud(label)
        return rgb, label