import json
import os
import random

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset

class MMCelebAHQ(Dataset):
    """Multi-Modal CelebA-HQ: yields a face image, a stacked per-class
    segmentation-mask condition, and a randomly sampled text caption."""

    def __init__(
        self,
        root="data/mmcelebahq",
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "depth",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        position_scale=1.0,
    ):
        self.root = root
        self.face_paths, self.mask_paths, self.prompts = self.get_face_mask_prompt()
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.position_scale = position_scale
        self.to_tensor = T.ToTensor()

    def get_face_mask_prompt(self):
        # The split is hardcoded to the first 27,000 face/mask pairs.
        face_paths = [
            os.path.join(self.root, "face", f"{i}.jpg") for i in range(27000)
        ]
        mask_paths = [
            os.path.join(self.root, "mask", f"{i}.png") for i in range(27000)
        ]
        with open(os.path.join(self.root, "text.json"), mode="r") as f:
            prompts = json.load(f)
        return face_paths, mask_paths, prompts
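
    # Assumed on-disk layout (inferred from this file, not verified):
    #   <root>/face/{i}.jpg  - RGB face images
    #   <root>/mask/{i}.png  - segmentation masks with class ids 0..18
    #   <root>/text.json     - {"0.jpg": ["caption one", ...], "1.jpg": [...], ...}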

    def __len__(self):
        return len(self.face_paths)

    def __getitem__(self, idx):
        image = Image.open(self.face_paths[idx]).convert("RGB")

        # Sample one caption for this face.
        prompts = self.prompts[f"{idx}.jpg"]
        description = random.choice(prompts).strip()

        # NOTE: random.random() < 1 is always True, so the scaled branch below
        # is currently dead, position_scale is always used as configured, and
        # condition_size is computed but never used in this method.
        enable_scale = random.random() < 1
        if not enable_scale:
            condition_size = int(self.condition_size * self.position_scale)
            position_scale = 1.0
        else:
            condition_size = self.condition_size
            position_scale = self.position_scale

        # Build the condition: the full RGB mask plus one binary map per
        # semantic class (19 classes), stacked into a [20, 3, H, W] tensor.
        position_delta = np.array([0, 0])
        mask = np.array(Image.open(self.mask_paths[idx]))
        mask_list = [self.to_tensor(Image.open(self.mask_paths[idx]).convert("RGB"))]
        for i in range(19):
            local_mask = np.zeros_like(mask)
            local_mask[mask == i] = 255
            # Each class map is dropped (zeroed) independently with drop_image_prob.
            if random.random() < self.drop_image_prob:
                local_mask = np.zeros_like(mask)
            local_mask_rgb = Image.fromarray(local_mask).convert("RGB")
            mask_list.append(self.to_tensor(local_mask_rgb))
        condition_img = torch.stack(mask_list, dim=0)

        # Randomly drop the text prompt.
        if random.random() < self.drop_text_prob:
            description = ""

        return {
            "image": self.to_tensor(image),
            "condition": condition_img,
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
            **({"position_scale": position_scale} if position_scale != 1.0 else {}),
        }
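

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the
# data/mmcelebahq layout documented above exists locally; the batch size is
# arbitrary and H/W depend on the stored image resolution.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = MMCelebAHQ(root="data/mmcelebahq")
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    batch = next(iter(loader))
    print(batch["image"].shape)      # [2, 3, H, W] face images
    print(batch["condition"].shape)  # [2, 20, 3, H, W] stacked mask maps
    print(batch["description"])      # list of 2 caption strings (may be "")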