import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from PIL import Image
from torchvision import transforms, utils

# File-name patterns scanned in an image directory. There is no dot before the
# extension, so e.g. "xjpg" also matches (kept as-is for compatibility).
_IMAGE_PATTERNS = ('*jpg', '*png')


def _list_images(image_dir):
    """Return the sorted names of files directly in ``image_dir`` matching ``_IMAGE_PATTERNS``.

    Replaces the undocumented ``glob.glob1`` (deprecated in Python 3.13).
    The directory part is escaped so glob metacharacters in it are taken
    literally. Hidden (dot-prefixed) files are skipped, as ``glob.glob1`` did.
    Names are returned relative to ``image_dir``.
    """
    root = glob.escape(os.fspath(image_dir))
    names = [os.path.basename(path)
             for pattern in _IMAGE_PATTERNS
             for path in glob.glob(os.path.join(root, pattern))]
    names.sort()
    return names


def _load_splits(image_dir, label_dir, training_set, train_split):
    """Load the image list and/or seed array and keep the requested split.

    The first ``int(train_split * N)`` entries form the training split and the
    remaining entries form the held-out split. When both sources are given,
    ``N`` is the number of images and the same cut index is applied to the
    seeds.

    Returns:
        ``(image_list, seeds, length)``. ``image_list`` and ``seeds`` are None
        when their source is None. ``length`` is the image count when
        ``image_dir`` is given, else the seed count, else 0.
    """
    def pick(items, cut):
        return items[:cut] if training_set else items[cut:]

    image_list = None
    seeds = None
    train_len = None
    if image_dir is not None:
        all_images = _list_images(image_dir)
        train_len = int(train_split * len(all_images))
        image_list = pick(all_images, train_len)
    if label_dir is not None:
        # np.load keeps its default allow_pickle=False here.
        all_seeds = np.load(label_dir)
        if train_len is None:
            train_len = int(train_split * len(all_seeds))
        seeds = pick(all_seeds, train_len)

    # When images are present they define the length, even if the split is
    # empty. Falling back to the seed count there would make __getitem__
    # index past the end of image_list.
    if image_list is not None:
        length = len(image_list)
    elif seeds is not None:
        length = len(seeds)
    else:
        length = 0
    return image_list, seeds, length


def _load_rgb(path):
    """Open the image at ``path``, convert it to RGB, and close the file.

    Grayscale images become 3-channel (each channel equal to the gray value).
    RGBA/LA/palette/CMYK images are reduced to 3 channels, so they match the
    3-channel ``Normalize`` used by the datasets below.
    """
    with Image.open(path) as image:
        return image.convert('RGB')


class MyDataSet(data.Dataset):
    """Dataset of resized, normalized images and/or seeded latent codes.

    Args:
        image_dir: Directory scanned (non-recursively) for ``*jpg`` / ``*png``
            files, or None for a latent-only dataset.
        label_dir: Path of a file read with ``np.load``. Entry ``i`` is passed to
            ``torch.manual_seed`` before sampling sample ``i``'s latent and
            noise. None disables latent sampling.
        output_size: Size given to ``transforms.Resize``.
        noise_in: Optional sequence of tensors. For each one, a noise tensor
            shaped like ``noise.size()[1:]`` is sampled. If None, a single
            ``(1, 1)`` noise tensor is sampled.
        training_set: Keep the training prefix (True) or the held-out suffix
            (False) of the sorted entries.
        video_data: Stored on the instance; not used by this class's
            ``__getitem__``.
        train_split: Fraction of entries that form the training split.

    ``__getitem__`` returns ``img`` when ``label_dir`` is None, ``(z, n)`` when
    only ``label_dir`` is given, and ``(z, img, n)`` when both are given.
    ``img`` is normalized to ``[-1, 1]`` (mean 0.5, std 0.5 on ``ToTensor``
    output) and ``z`` has shape ``(512,)``.
    """

    def __init__(self, image_dir=None, label_dir=None, output_size=(256, 256), noise_in=None,
                 training_set=True, video_data=False, train_split=0.9):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.resize = transforms.Compose([
            transforms.Resize(output_size),
            transforms.ToTensor()
        ])
        self.noise_in = noise_in
        self.video_data = video_data
        self.random_rotation = transforms.Compose([
            transforms.Resize(output_size),
            transforms.RandomPerspective(distortion_scale=0.05, p=1.0),
            transforms.ToTensor()
        ])
        self.image_list, self.seeds, self.length = _load_splits(
            image_dir, label_dir, training_set, train_split)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        img = None
        if self.image_dir is not None:
            image = _load_rgb(os.path.join(self.image_dir, self.image_list[idx]))
            img = self.normalize(self.resize(image))

        if self.label_dir is None:
            return img

        # Re-seeding makes the latent/noise for a given index reproducible.
        # NOTE: this resets torch's global RNG state for whatever runs next.
        torch.manual_seed(self.seeds[idx])
        z = torch.randn(1, 512)[0]
        if self.noise_in is None:
            n = [torch.randn(1, 1)]
        else:
            n = [torch.randn(noise.size())[0] for noise in self.noise_in]
        if img is None:
            return z, n
        return z, img, n


class Car_DataSet(data.Dataset):
    """Letterboxed 512x512 image dataset and/or seeded latent codes.

    Images are resized to 384x512 (height x width) and padded with 64 black
    rows above and below, giving 512x512. The letterbox is fixed: in this
    class ``output_size`` does not affect it and is accepted only for
    signature compatibility.

    Args:
        image_dir: Directory scanned (non-recursively) for ``*jpg`` / ``*png``
            files, or None for a latent-only dataset.
        label_dir: Path of a file read with ``np.load``. Entry ``i`` is passed to
            ``torch.manual_seed`` before sampling sample ``i``'s latent and
            noise. None disables latent sampling.
        output_size: Unused (see above).
        noise_in: Optional sequence of tensors. For each one,
            ``torch.randn_like(noise[0])`` is sampled. If None, a single
            ``(1, 1)`` noise tensor is sampled.
        training_set: Keep the training prefix (True) or the held-out suffix
            (False) of the sorted entries.
        video_data: If True, ``img`` is the letterboxed frame concatenated
            along dim 0 with a randomly perspective-warped copy of it.
        train_split: Fraction of entries that form the training split.

    ``__getitem__`` returns ``img`` when ``label_dir`` is None, ``(z, n)`` when
    only ``label_dir`` is given, and ``(z, img, n)`` when both are given.
    """

    def __init__(self, image_dir=None, label_dir=None, output_size=(512, 512), noise_in=None,
                 training_set=True, video_data=False, train_split=0.9):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.resize = transforms.Compose([
            transforms.Resize((384, 512)),
            transforms.Pad(padding=(0, 64, 0, 64)),
            transforms.ToTensor()
        ])
        self.noise_in = noise_in
        self.video_data = video_data
        # The warped frame starts from the same letterboxed geometry as
        # self.resize. Then its empty border can be filled pixel-for-pixel
        # from the unwarped frame in __getitem__.
        self.random_rotation = transforms.Compose([
            transforms.Resize((384, 512)),
            transforms.Pad(padding=(0, 64, 0, 64)),
            transforms.RandomPerspective(distortion_scale=0.05, p=1.0),
            transforms.ToTensor()
        ])
        self.image_list, self.seeds, self.length = _load_splits(
            image_dir, label_dir, training_set, train_split)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        img = None
        if self.image_dir is not None:
            image = _load_rgb(os.path.join(self.image_dir, self.image_list[idx]))
            img = self.normalize(self.resize(image))
            if self.video_data:
                img_2 = self.normalize(self.random_rotation(image))
                # Pixels the warp left empty are 0, i.e. -1 after normalize.
                # Take those (and any pure-black pixels) from the unwarped frame.
                img_2 = torch.where(img_2 > -1, img_2, img)
                img = torch.cat([img, img_2], dim=0)

        if self.label_dir is None:
            return img

        # Re-seeding makes the latent/noise for a given index reproducible.
        # NOTE: this resets torch's global RNG state for whatever runs next.
        torch.manual_seed(self.seeds[idx])
        z = torch.randn(1, 512)[0]
        if self.noise_in is None:
            # With no noise template, sample the same single (1, 1) noise
            # tensor as MyDataSet instead of failing on None.
            n = [torch.randn(1, 1)]
        else:
            n = [torch.randn_like(noise[0]) for noise in self.noise_in]
        if img is None:
            return z, n
        return z, img, n