"""Disabled legacy CelebAMask-HQ face-parsing dataset and mask-builder script.

Every line of the original module is commented out, so importing this file
defines nothing and has no side effects. The preserved code below contains:

* ``FaceMask`` -- a ``torch.utils.data.Dataset`` that lists the files in
  ``<rootpth>/CelebA-HQ-img``, resizes each image to 512x512, loads the
  matching ``<rootpth>/Mask_less/<name>.png`` label (converted to mode 'P'),
  runs the ``trans_train`` pipeline (ColorJitter, HorizontalFlip,
  RandomScale, RandomCrop from the project-local ``.transform`` module) when
  ``mode == 'scripts'``, and returns ``(normalized image tensor,
  int64 label array with a leading axis)``.
* A ``__main__`` script that, for image ids ``0 .. 29999`` in 15 folders of
  2000, merges the per-attribute masks
  ``CelebAMask-HQ-mask-anno/<i>/<jjjjj>_<att>.png`` into a single 512x512
  label map (0 where no attribute matched; the k-th entry of ``atts``,
  1-based, writes label k, later attributes overwriting earlier ones) and
  saves it as ``<mask_path>/<j>.png``.

NOTE(review): this file had been flattened onto a single physical line, and
a fragment of the ``atts`` list (``'mouth', 'u_lip',``) had leaked out of the
comment as a stray, no-op tuple expression. The code below has been re-split
into one commented line per statement; indentation was reconstructed from
the Python syntax and should be double-checked before re-enabling.

NOTE(review): ``sep_mask == 225`` presumably matches the palette index that
white pixels map to after ``convert('P')`` -- confirm against the data.
``mode='scripts'`` also looks like a search-and-replace of ``'train'``
(``trans_train`` keeps the old name) -- confirm before re-enabling.

TODO(review): either delete this module or restore it as live code; dead
commented-out code is a maintenance hazard.
"""

# #!/usr/bin/python
# # -*- encoding: utf-8 -*-
#
# import torch
# from torch.utils.data import Dataset
# import torchvision.transforms as transforms
#
# import os.path as osp
# import os
# from PIL import Image
# import numpy as np
# import json
# import cv2
#
# from .transform import *
#
#
# class FaceMask(Dataset):
#     def __init__(self, rootpth, cropsize=(640, 480), mode='scripts', *args, **kwargs):
#         super(FaceMask, self).__init__(*args, **kwargs)
#         assert mode in ('scripts', 'val', 'test')
#         self.mode = mode
#         self.ignore_lb = 255
#         self.rootpth = rootpth
#
#         self.imgs = os.listdir(os.path.join(self.rootpth, 'CelebA-HQ-img'))
#
#         # pre-processing
#         self.to_tensor = transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
#         ])
#         self.trans_train = Compose([
#             ColorJitter(
#                 brightness=0.5,
#                 contrast=0.5,
#                 saturation=0.5),
#             HorizontalFlip(),
#             RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
#             RandomCrop(cropsize)
#         ])
#
#     def __getitem__(self, idx):
#         impth = self.imgs[idx]
#         img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth))
#         img = img.resize((512, 512), Image.BILINEAR)
#         label = Image.open(osp.join(self.rootpth, 'Mask_less', impth[:-3]+'png')).convert('P')
#         # print(np.unique(np.array(label)))
#         if self.mode == 'scripts':
#             im_lb = dict(im=img, lb=label)
#             im_lb = self.trans_train(im_lb)
#             img, label = im_lb['im'], im_lb['lb']
#         img = self.to_tensor(img)
#         label = np.array(label).astype(np.int64)[np.newaxis, :]
#         return img, label
#
#     def __len__(self):
#         return len(self.imgs)
#
#
# if __name__ == "__main__":
#     face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
#     face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
#     mask_path = '/home/zll/data/CelebAMask-HQ/mask'
#     counter = 0
#     total = 0
#     for i in range(15):
#         # files = os.listdir(osp.join(face_sep_mask, str(i)))
#
#         atts = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip',
#                 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']
#
#         for j in range(i*2000, (i+1)*2000):
#
#             mask = np.zeros((512, 512))
#
#             for l, att in enumerate(atts, 1):
#                 total += 1
#                 file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
#                 path = osp.join(face_sep_mask, str(i), file_name)
#
#                 if os.path.exists(path):
#                     counter += 1
#                     sep_mask = np.array(Image.open(path).convert('P'))
#                     # print(np.unique(sep_mask))
#
#                     mask[sep_mask == 225] = l
#             cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
#             print(j)
#
#     print(counter, total)