# XavierJiezou's picture
# Upload folder using huggingface_hub
# 48ed5ae verified
from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset
import torchvision.transforms as T
import random
import torch
import json
class MMCelebAHQ(Dataset):
    """Face images paired with per-class segmentation-mask conditions and text.

    Expected layout under ``root`` (as read by the code below)::

        <root>/face/<i>.jpg   face image for sample ``i``
        <root>/mask/<i>.png   label mask whose pixel values are class ids
        <root>/text.json      mapping ``"<i>.jpg"`` -> list of caption strings

    Each item is a dict with the keys ``image``, ``condition``,
    ``condition_type``, ``description`` and ``position_delta``; ``pil_image``
    and ``position_scale`` are added conditionally (see ``__getitem__``).
    """

    def __init__(
        self,
        root="data/mmcelebahq",
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "depth",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        position_scale=1.0,
        num_samples: int = 27000,
        num_classes: int = 19,
    ):
        """Index the dataset files and load the caption table.

        Args:
            root: Dataset directory containing ``face/``, ``mask/`` and
                ``text.json``.
            condition_size: Stored on the instance; not used by the methods
                of this class.
            target_size: Stored on the instance; not used by the methods of
                this class.
            condition_type: Returned verbatim under ``"condition_type"``.
            drop_text_prob: Probability of replacing the caption with ``""``.
            drop_image_prob: Probability, drawn independently for every class
                mask, of replacing that class mask with an all-zero mask.
            return_pil_image: When true, add a ``"pil_image"`` entry to each
                item.
            position_scale: Returned under ``"position_scale"`` when it is
                not ``1.0``.
            num_samples: Number of samples indexed, i.e. files ``0`` to
                ``num_samples - 1``. Defaults to the previously hard-coded
                27000.
            num_classes: Number of class ids (``0`` to ``num_classes - 1``)
                expanded into binary masks. Defaults to the previously
                hard-coded 19.
        """
        super().__init__()
        self.root = root
        # Must be set before get_face_mask_prompt(), which reads it.
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.face_paths, self.mask_paths, self.prompts = self.get_face_mask_prompt()
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.position_scale = position_scale
        self.to_tensor = T.ToTensor()

    def get_face_mask_prompt(self):
        """Build the face/mask path lists and load the caption mapping.

        Paths are generated from the sample index; their existence is not
        checked here.

        Returns:
            A ``(face_paths, mask_paths, prompts)`` tuple where the first two
            are lists of ``self.num_samples`` paths and ``prompts`` is the
            parsed content of ``<root>/text.json``.
        """
        indices = range(self.num_samples)
        face_paths = [os.path.join(self.root, "face", f"{i}.jpg") for i in indices]
        mask_paths = [os.path.join(self.root, "mask", f"{i}.png") for i in indices]
        # Explicit encoding: JSON text is UTF-8, independent of the locale.
        with open(os.path.join(self.root, "text.json"), mode="r", encoding="utf-8") as f:
            prompts = json.load(f)
        return face_paths, mask_paths, prompts

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.face_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and assemble its training dict.

        Returns:
            dict with

            * ``"image"``: the RGB face converted by ``ToTensor``.
            * ``"condition"``: ``torch.stack`` of ``1 + num_classes`` tensors:
              the RGB-converted mask followed by one binary mask per class
              id (each possibly blanked, see ``drop_image_prob``).
            * ``"condition_type"``: ``self.condition_type``.
            * ``"description"``: one randomly chosen, stripped caption, or
              ``""`` when the text is dropped.
            * ``"position_delta"``: ``np.array([0, 0])``.
            * ``"pil_image"``: only when ``return_pil_image`` is set.
            * ``"position_scale"``: only when the scale is not ``1.0``.

        Raises:
            KeyError: if ``text.json`` has no entry for ``"<idx>.jpg"``.
        """
        with Image.open(self.face_paths[idx]) as face:
            image = face.convert("RGB")
        prompts = self.prompts[f"{idx}.jpg"]
        description = random.choices(prompts, k=1)[0].strip()

        # random.random() is in [0, 1), so this is currently always True and
        # position_scale is always forwarded. The draw is kept so the order
        # and count of RNG calls (and thus seeded runs) stay unchanged.
        enable_scale = random.random() < 1
        position_scale = self.position_scale if enable_scale else 1.0

        # Get the condition image
        position_delta = np.array([0, 0])
        # Decode the mask file once and derive both views from it.
        with Image.open(self.mask_paths[idx]) as mask_img:
            mask = np.array(mask_img)
            mask_rgb = mask_img.convert("RGB")

        # Entry 0 is the full mask; it is never dropped.
        mask_list = [self.to_tensor(mask_rgb)]
        for class_id in range(self.num_classes):
            local_mask = np.zeros_like(mask)
            # One draw per class: each class mask is blanked independently.
            drop_image = random.random() < self.drop_image_prob
            if not drop_image:
                local_mask[mask == class_id] = 255
            local_mask_rgb = Image.fromarray(local_mask).convert("RGB")
            mask_list.append(self.to_tensor(local_mask_rgb))
        condition_img = torch.stack(mask_list, dim=0)

        # Randomly drop text
        drop_text = random.random() < self.drop_text_prob
        if drop_text:
            description = ""

        item = {
            "image": self.to_tensor(image),
            "condition": condition_img,
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
        }
        if self.return_pil_image:
            # NOTE(review): the second element is the stacked condition
            # tensor, not a PIL image — confirm this is what consumers expect.
            item["pil_image"] = [image, condition_img]
        if position_scale != 1.0:
            item["position_scale"] = position_scale
        return item