|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
import numpy.ma as ma
|
|
import torch.utils.data as data
|
|
import copy
|
|
from torchvision import transforms
|
|
import scipy.io as scio
|
|
import torchvision.datasets as dset
|
|
import random
|
|
import scipy.misc
|
|
import scipy.io as scio
|
|
import os
|
|
from PIL import ImageEnhance
|
|
from PIL import ImageFilter
|
|
|
|
class SegDataset(data.Dataset):
    """Segmentation dataset that draws a random frame on every access.

    Each line of *txtlist* is a sample prefix relative to *root_dir*; a sample
    is made of ``<prefix>-color.png`` and ``<prefix>-label.png``. Prefixes that
    start with ``'data/'`` are treated as real frames and serve as backgrounds
    for prefixes that start with ``'data_syn'``.

    Args:
        root_dir: directory the listed prefixes are relative to.
        txtlist: path of a text file listing one sample prefix per line.
        use_noise: enable colour jitter and random flips.
        length: value reported by ``len()``. ``__getitem__`` ignores its index
            and samples randomly, so this only controls how many samples are
            drawn per pass over the dataset.
    """

    def __init__(self, root_dir, txtlist, use_noise, length):
        self.use_noise = use_noise
        self.root = root_dir
        # path: every listed prefix; real_path: the 'data/' subset used as backgrounds.
        self.path, self.real_path = self._read_list(txtlist)

        self.length = length
        self.data_len = len(self.path)
        self.back_len = len(self.real_path)

        self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # 480x640 array of ones; not read by any method of this class, kept
        # as a public attribute for external users.
        self.back_front = np.ones((480, 640), dtype=int)

    @staticmethod
    def _read_list(txtlist):
        """Return ``(paths, real_paths)`` parsed from the list file *txtlist*.

        The trailing newline of each line is stripped and blank lines are
        skipped (they could never name a loadable sample). ``real_paths`` is
        the subset of ``paths`` that starts with ``'data/'``.
        """
        paths = []
        with open(txtlist) as input_file:
            for line in input_file:
                line = line.rstrip('\n')
                if line.strip():
                    paths.append(line)
        real_paths = [p for p in paths if p.startswith('data/')]
        return paths, real_paths

    def _load_rgb(self, prefix):
        """Open ``<root>/<prefix>-color.png`` and return it as an RGB PIL image."""
        with Image.open('{0}/{1}-color.png'.format(self.root, prefix)) as img:
            # convert() loads the pixels, so the result outlives the closed file.
            return img.convert("RGB")

    def _load_label(self, prefix):
        """Read ``<root>/<prefix>-label.png`` into a numpy array."""
        with Image.open('{0}/{1}-label.png'.format(self.root, prefix)) as img:
            return np.array(img)

    def _paste_on_background(self, rgb, label):
        """Composite a synthetic frame onto a randomly chosen real ('data/') frame.

        Args:
            rgb: (H, W, 3) uint8 array of the synthetic frame.
            label: label map of the synthetic frame, where 0 marks background.
                Assumed 2-D (H, W) and aligned with ``rgb`` -- TODO confirm.

        Returns:
            ``(rgb, label)``:
                rgb: float64 (H, W, 3) image. Background pixels have the real
                    frame's colour added, and Gaussian noise (sigma 5) is added
                    to every pixel of the synthetic image.
                label: label map whose background pixels take the real
                    frame's labels.

        Raises:
            ValueError: from ``random.randint`` when fewer than 10 real
                prefixes are listed.
        """
        # back_len counts real_path, so the background must be drawn from
        # real_path; indexing self.path here could pick a synthetic frame.
        back_prefix = self.real_path[random.randint(0, self.back_len - 10)]
        back = np.array(self.trancolor(self._load_rgb(back_prefix)))
        back_label = self._load_label(back_prefix)

        mask = label == 0  # True on background pixels of the synthetic frame
        noisy = rgb + np.random.normal(loc=0.0, scale=5.0, size=rgb.shape)
        rgb = back * mask[..., np.newaxis] + noisy
        label = back_label * mask + label
        return rgb, label

    def __getitem__(self, idx):
        """Return a randomly drawn ``(rgb, target)`` pair; ``idx`` is ignored.

        The frame is drawn uniformly from ``self.path[0 : data_len - 9]``.
        Synthetic frames (prefix ``'data_syn'``) are brightened, blurred,
        colour-jittered and pasted onto a real background. With ``use_noise``,
        real frames are colour-jittered, and image and label are flipped
        together at random.

        Returns:
            rgb: float32 tensor of shape (3, H, W), normalised with ``self.norm``.
            target: int64 tensor holding the label map.

        Raises:
            ValueError: from ``random.randint`` when fewer than 10 prefixes
                are listed.
        """
        # NOTE(review): the last 9 listed entries are never sampled; the
        # offset is preserved from the original code -- confirm intent.
        prefix = self.path[random.randint(0, self.data_len - 10)]
        label = self._load_label(prefix)

        if prefix.startswith('data_syn'):
            image = self._load_rgb(prefix)
            image = ImageEnhance.Brightness(image).enhance(1.5).filter(ImageFilter.GaussianBlur(radius=0.8))
            rgb = np.array(self.trancolor(image))
            rgb, label = self._paste_on_background(rgb, label)
        elif self.use_noise:
            rgb = np.array(self.trancolor(self._load_rgb(prefix)))
        else:
            rgb = np.array(self._load_rgb(prefix))

        if self.use_noise:
            choice = random.randint(0, 3)  # 0: left-right, 1: up-down, 2: both, 3: none
            if choice in (0, 2):
                rgb, label = np.fliplr(rgb), np.fliplr(label)
            if choice in (1, 2):
                rgb, label = np.flipud(rgb), np.flipud(label)

        # Flips return negative-stride views, which torch.from_numpy rejects;
        # ascontiguousarray copies into a fresh C-ordered buffer.
        chw = np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)), dtype=np.float32)
        # NOTE(review): pixel values are on a 0-255 scale while self.norm uses
        # ImageNet statistics meant for 0-1 inputs; unchanged to keep model
        # inputs identical to the original pipeline.
        rgb = self.norm(torch.from_numpy(chw))
        target = torch.from_numpy(np.ascontiguousarray(label, dtype=np.int64))
        return rgb, target

    def __len__(self):
        """Return the configured number of samples per pass (``length``)."""
        return self.length
|
|
|
|
|