from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset
import torchvision.transforms as T
import random
import torch
import json


class MMCelebAHQ(Dataset):
    """Multi-Modal CelebA-HQ faces paired with per-class segmentation-mask
    conditions and text prompts, read from {root}/face, {root}/mask, and
    {root}/text.json."""

    def __init__(
        self,
        root="data/mmcelebahq",
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "depth",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        position_scale=1.0,
    ):
        self.root = root
        self.face_paths, self.mask_paths, self.prompts = self.get_face_mask_prompt()
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.position_scale = position_scale

        self.to_tensor = T.ToTensor()

    def get_face_mask_prompt(self):
        # Faces are {root}/face/{i}.jpg and masks {root}/mask/{i}.png for
        # i in 0..26999; prompts in text.json are keyed by "{i}.jpg".
        face_paths = [
            os.path.join(self.root, "face", f"{i}.jpg") for i in range(0, 27000)
        ]
        mask_paths = [
            os.path.join(self.root, "mask", f"{i}.png") for i in range(0, 27000)
        ]
        with open(os.path.join(self.root, "text.json"), mode="r") as f:
            prompts = json.load(f)
        return face_paths, mask_paths, prompts

    def __len__(self):
        return len(self.face_paths)

    def __getitem__(self, idx):
        image = Image.open(self.face_paths[idx]).convert("RGB")
        # Pick one caption at random for this image.
        prompts = self.prompts[f"{idx}.jpg"]
        description = random.choice(prompts).strip()
        # enable_scale is always True (random.random() < 1), so the fallback
        # branch that bakes position_scale into condition_size never runs;
        # condition_size itself is not used further in this method.
        enable_scale = random.random() < 1
        if not enable_scale:
            condition_size = int(self.condition_size * self.position_scale)
            position_scale = 1.0
        else:
            condition_size = self.condition_size
            position_scale = self.position_scale

        # Build the condition tensor from the segmentation mask: one RGB copy of
        # the full mask followed by 19 per-class binary maps.
        position_delta = np.array([0, 0])

        mask = np.array(Image.open(self.mask_paths[idx]))
        mask_list = [self.to_tensor(Image.open(self.mask_paths[idx]).convert("RGB"))]
        for i in range(19):
            # Binary map for semantic class i (255 where the pixel belongs to it).
            local_mask = np.zeros_like(mask)
            local_mask[mask == i] = 255

            # Randomly zero out this class map with probability drop_image_prob.
            drop_image = random.random() < self.drop_image_prob
            if drop_image:
                local_mask = np.zeros_like(mask)

            local_mask_rgb = Image.fromarray(local_mask).convert("RGB")
            local_mask_tensor = self.to_tensor(local_mask_rgb)
            mask_list.append(local_mask_tensor)
        # Shape: (20, 3, H, W), the full mask plus the 19 per-class maps.
        condition_img = torch.stack(mask_list, dim=0)

        # Randomly drop the text prompt with probability drop_text_prob.
        drop_text = random.random() < self.drop_text_prob
        if drop_text:
            description = ""

        return {
            "image": self.to_tensor(image),
            "condition": condition_img,
            # "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
            **({"position_scale": position_scale} if position_scale != 1.0 else {}),
        }
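

# Minimal usage sketch, assuming the default layout expected above:
# data/mmcelebahq/face/{i}.jpg, data/mmcelebahq/mask/{i}.png, and
# data/mmcelebahq/text.json. Wraps the dataset in a standard PyTorch
# DataLoader and inspects one batch.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = MMCelebAHQ(root="data/mmcelebahq")
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

    batch = next(iter(loader))
    # "image" is (B, 3, H, W); "condition" stacks the full mask plus the 19
    # per-class maps, giving (B, 20, 3, H, W).
    print(batch["image"].shape, batch["condition"].shape)
    print(batch["description"])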