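"""Kandinsky 3 inpainting pipeline.

Combines a T5 text encoder, a diffusion UNet, and a MoVQ image autoencoder to
fill in the masked region of an image from a text prompt.
"""
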
from typing import Optional, Union, List

import PIL.Image
import numpy as np

import torch
import torchvision.transforms as T
from einops import repeat

from kandinsky3.model.unet import UNet
from kandinsky3.movq import MoVQ
from kandinsky3.condition_encoders import T5TextConditionEncoder
from kandinsky3.condition_processors import T5TextConditionProcessor
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
from kandinsky3.utils import resize_image_for_diffusion, resize_mask_for_diffusion


class Kandinsky3InpaintingPipeline:
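    """Text-guided image inpainting with Kandinsky 3.

    Encodes the masked input image with MoVQ, conditions a diffusion UNet on
    T5 text embeddings, and decodes the sampled latents back to pixels.
    """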

    def __init__(
            self,
            device_map: Union[str, torch.device, dict],
            dtype_map: Union[str, torch.dtype, dict],
            unet: UNet,
            null_embedding: torch.Tensor,
            t5_processor: T5TextConditionProcessor,
            t5_encoder: T5TextConditionEncoder,
            movq: MoVQ,
    ):
        self.device_map = device_map
        self.dtype_map = dtype_map
        self.to_pil = T.ToPILImage()
        self.to_tensor = T.ToTensor()

        self.unet = unet
        self.null_embedding = null_embedding
        self.t5_processor = t5_processor
        self.t5_encoder = t5_encoder
        self.movq = movq

    def shared_step(self, batch: dict) -> dict:
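        """Encode the prompts and the masked image into the conditioning
        tensors consumed by the diffusion sampler."""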
        image = batch['image']
        condition_model_input = batch['text']
        negative_condition_model_input = batch['negative_text']

        mask = batch['mask']

        if 'masked_image' in batch:
            masked_latent = batch['masked_image']
        elif self.unet.in_layer.in_channels == 9:
            # Derive the masked image by zeroing out the region to inpaint
            # (mask == 0 marks pixels to be filled in).
            masked_latent = image.masked_fill((1 - mask).bool(), 0)
        else:
            raise ValueError(
                "batch must contain 'masked_image', or the UNet must accept 9 "
                "input channels so the masked image can be derived here"
            )

        # Encode the masked image into MoVQ latent space and downsample the
        # mask to the latent resolution.
        with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
            masked_latent = self.movq.encode(masked_latent)
        mask = torch.nn.functional.interpolate(mask, size=(masked_latent.shape[2], masked_latent.shape[3]))

        # Encode the positive and (optional) negative prompts with T5 under
        # the text-encoder autocast dtype.
        with torch.cuda.amp.autocast(dtype=self.dtype_map['text_encoder']):
            context, context_mask = self.t5_encoder(condition_model_input)

            if negative_condition_model_input is not None:
                negative_context, negative_context_mask = self.t5_encoder(negative_condition_model_input)
            else:
                negative_context, negative_context_mask = None, None

        return {
            'context': context,
            'context_mask': context_mask,
            'negative_context': negative_context,
            'negative_context_mask': negative_context_mask,
            'image': image,
            'masked_latent': masked_latent,
            'mask': mask
        }

    def prepare_batch(
            self,
            text: str,
            negative_text: Optional[str],
            image: PIL.Image.Image,
            mask: np.ndarray,
    ) -> dict:
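        """Tokenize the prompts and convert the image and mask into batched
        tensors on the devices expected by the text encoder and MoVQ."""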
        condition_model_input, negative_condition_model_input = self.t5_processor.encode(
            text=text, negative_text=negative_text
        )
        batch = {
            # Rescale image pixels from [0, 1] to [-1, 1].
            'image': self.to_tensor(resize_image_for_diffusion(image.convert("RGB"))) * 2 - 1,
            # Invert the mask so that 1 marks pixels to keep and 0 marks the
            # region to inpaint.
            'mask': 1 - self.to_tensor(resize_mask_for_diffusion(mask)),
            'text': condition_model_input,
            'negative_text': negative_condition_model_input
        }
        batch['mask'] = batch['mask'].type(self.dtype_map['movq'])

        # Add batch dimensions and move each tensor to the device of the
        # module that consumes it.
        batch['image'] = batch['image'].unsqueeze(0).to(self.device_map['movq'])
        batch['text']['input_ids'] = batch['text']['input_ids'].unsqueeze(0).to(self.device_map['text_encoder'])
        batch['text']['attention_mask'] = batch['text']['attention_mask'].unsqueeze(0).to(
            self.device_map['text_encoder'])
        batch['mask'] = batch['mask'].unsqueeze(0).to(self.device_map['movq'])

        if negative_condition_model_input is not None:
            # Match the positive prompt: add a batch dimension before moving
            # to the text-encoder device.
            batch['negative_text']['input_ids'] = batch['negative_text']['input_ids'].unsqueeze(0).to(
                self.device_map['text_encoder'])
            batch['negative_text']['attention_mask'] = batch['negative_text']['attention_mask'].unsqueeze(0).to(
                self.device_map['text_encoder'])

        return batch

    def __call__(
            self,
            text: str,
            image: PIL.Image.Image,
            mask: np.ndarray,
            negative_text: Optional[str] = None,
            images_num: int = 1,
            bs: int = 1,
            steps: int = 50,
            guidance_weight_text: float = 4.0,
            eta: float = 1.0,
    ) -> List[PIL.Image.Image]:
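        """Inpaint the masked region of `image` according to `text`.

        Generates `images_num` samples in minibatches of at most `bs`, using
        `steps` denoising steps and classifier-free guidance with weight
        `guidance_weight_text`; returns the results as PIL images.
        """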

        with torch.no_grad():
            batch = self.prepare_batch(text, negative_text, image, mask)
            processed = self.shared_step(batch)
            # Cosine beta schedule over 1000 training steps, thinned to
            # `steps` evenly spaced timesteps for sampling.
            betas = get_named_beta_schedule('cosine', 1000)
            base_diffusion = BaseDiffusion(betas, percentile=0.95)
            times = list(range(999, 0, -1000 // steps))

            pil_images = []
            # Generate images_num samples in minibatches of size bs, plus a
            # remainder minibatch.
            k, m = images_num // bs, images_num % bs
            for minibatch in [bs] * k + [m]:
                if minibatch == 0:
                    continue

                # Broadcast the single prompt embedding across the minibatch.
                bs_context = repeat(processed['context'], '1 n d -> b n d', b=minibatch)
                bs_context_mask = repeat(processed['context_mask'], '1 n -> b n', b=minibatch)

                if processed['negative_context'] is not None:
                    bs_negative_context = repeat(processed['negative_context'], '1 n d -> b n d', b=minibatch)
                    bs_negative_context_mask = repeat(processed['negative_context_mask'], '1 n -> b n', b=minibatch)
                else:
                    bs_negative_context, bs_negative_context_mask = None, None

                mask = processed['mask'].repeat_interleave(minibatch, dim=0)
                masked_latent = processed['masked_latent'].repeat_interleave(minibatch, dim=0)

                # The outer no_grad() already disables autograd, so only the
                # autocast context is needed around sampling.
                with torch.cuda.amp.autocast(dtype=self.dtype_map['unet']):
                    images = base_diffusion.p_sample_loop(
                        self.unet, (minibatch, 4, masked_latent.shape[2], masked_latent.shape[3]), times,
                        self.device_map['unet'],
                        bs_context, bs_context_mask, self.null_embedding, guidance_weight_text, eta,
                        negative_context=bs_negative_context, negative_context_mask=bs_negative_context_mask,
                        mask=mask, masked_latent=masked_latent, gan=False
                    )

                with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
                    # Decode latents to pixels in two chunks to lower peak
                    # memory, then rescale from [-1, 1] to [0, 1].
                    images = torch.cat([self.movq.decode(image) for image in images.chunk(2)])
                    images = torch.clip((images + 1.) / 2., 0., 1.).cpu()

                pil_images += [self.to_pil(image) for image in images]

        return pil_images
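

# A minimal usage sketch, not executed on import. It assumes the repo exposes
# a `get_inpainting_pipeline` factory that builds this pipeline with
# pretrained weights; the device/dtype maps and file names are illustrative.
if __name__ == '__main__':
    from PIL import Image

    from kandinsky3 import get_inpainting_pipeline  # assumed factory

    device_map = {'unet': 'cuda:0', 'text_encoder': 'cuda:0', 'movq': 'cuda:0'}
    dtype_map = {'unet': torch.float16, 'text_encoder': torch.float16, 'movq': torch.float32}
    pipe = get_inpainting_pipeline(device_map, dtype_map)

    source = Image.open('example.png')
    # Mask convention here: 1 marks the region to repaint, 0 keeps pixels.
    inpaint_mask = np.zeros((source.height, source.width), dtype=np.float32)
    inpaint_mask[100:300, 100:300] = 1.0
    result = pipe('a ginger cat sitting on a sofa', source, inpaint_mask,
                  images_num=1, steps=50)
    result[0].save('inpainted.png')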