xwwu committed 501c549 (1 parent: 9fd3a00)

Upload 3 files

Files changed (3):
  1. app-2.py +173 -0
  2. pipeline_controlnet_blip_diffusion.py +653 -0
  3. requirements.txt +7 -0
app-2.py ADDED
@@ -0,0 +1,173 @@
+ import gradio as gr
+ import sys
+ import torch
+
+ from PIL import Image
+ import numpy as np
+ from io import BytesIO
+ import os
+
+ from diffusers.utils import load_image
+ from diffusers import ControlNetModel
+ from diffusers.image_processor import VaeImageProcessor
+ from pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
+
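+ # Load the BLIP-Diffusion ControlNet pipeline and swap its bundled ControlNet for the
+ # SD 1.5 inpainting ControlNet; run on GPU when available.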
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
+     "Salesforce/blipdiffusion-controlnet"
+ )
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint")
+
+ blip_diffusion_pipe.controlnet = controlnet
+ blip_diffusion_pipe.to(device)
+
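+ # Build the ControlNet inpaint conditioning tensor: pixels under the mask are set to -1.0,
+ # which the inpainting ControlNet treats as the region to repaint.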
+ def make_inpaint_condition(image, image_mask):
+     image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+     image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
+     assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
+     image[image_mask > 0.5] = -1  # set as masked pixel
+     image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+     image = torch.from_numpy(image)
+     return image
+
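+ # CSS tweaks: keep the upload/sketch canvases and the output gallery at a fixed 500px height.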
+ css = '''
+ .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
+ .image_upload{min-height:500px}
+ .image_upload [data-testid="image"], .image_upload [data-testid="image"] > div{min-height: 500px}
+ .image_upload [data-testid="target"], .image_upload [data-testid="target"] > div{min-height: 500px}
+ .image_upload .touch-none{display: flex}
+ #output_image{min-height:500px;max-height:500px;}
+ '''
+
+
+ def create_demo():
+     # load information from users
+     HEIGHT, WIDTH = 512, 512
+     with gr.Blocks(theme=gr.themes.Default(font=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"],
+                                            primary_hue="lime",
+                                            secondary_hue="emerald",
+                                            neutral_hue="slate",
+                                            ), css=css) as demo:
+         gr.Markdown('# BLIP-Diffusion')
+         with gr.Accordion('Instructions', open=False):
+             gr.Markdown('1. Upload src image and draw mask')
+             gr.Markdown('2. Upload tgt image')
+             gr.Markdown('3. Input name of tgt object and description')
+             gr.Markdown('4. Click `Generate` when it is ready!')
+
+         with gr.Group():
+             with gr.Box():
+                 with gr.Column():
+                     with gr.Row() as main_blocks:
+                         #
+                         with gr.Column() as step_1:
+                             gr.Markdown('### Source Input and Add Mask')
+                             image = gr.Image(source='upload',
+                                              shape=[HEIGHT, WIDTH],
+                                              type='pil',  # or 'numpy'
+                                              elem_classes="image_upload",
+                                              label='Source Image',
+                                              tool='sketch',
+                                              brush_radius=60).style(height=500)
+                             src_input = image
+                             text_prompt = gr.Textbox(label='Prompt')
+                             run_button = gr.Button(value='Generate', variant="primary")
+                         #
+                         with gr.Column() as step_2:
+                             gr.Markdown('### Target Input')
+                             target = gr.Image(source='upload',
+                                               shape=[HEIGHT, WIDTH],
+                                               type='pil',  # or 'numpy'
+                                               elem_classes="image_upload",
+                                               label='Target Image'
+                                               ).style(height=500)
+                             tgt_input = target
+                             style_subject = gr.Textbox(label='Target Object')
+
+                     with gr.Row() as output_blocks:
+                         with gr.Column() as output_step:
+                             gr.Markdown('### Output')
+                             output_image = gr.Gallery(
+                                 label="Generated images",
+                                 show_label=False,
+                                 elem_id="output_image",
+                             ).style(height=500, container=True)
+
+                     with gr.Accordion('Advanced options', open=False):
+                         num_inference_steps = gr.Slider(label='Steps',
+                                                         minimum=1,
+                                                         maximum=100,
+                                                         value=50,
+                                                         step=1)
+                         guidance_scale = gr.Slider(label='Text Guidance Scale',
+                                                    minimum=0.1,
+                                                    maximum=30.0,
+                                                    value=7.5,
+                                                    step=0.1)
+                         seed = gr.Slider(label='Seed',
+                                          minimum=-1,
+                                          maximum=2147483647,
+                                          step=1,
+                                          randomize=True)
+
+         # Model
+         inputs = [
+             src_input,
+             tgt_input,
+             text_prompt,
+             style_subject,
+             num_inference_steps,
+             guidance_scale,
+             seed,
+         ]
+
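+         # Gradio callback: runs BLIP-Diffusion with the inpainting ControlNet. The sketch drawn
+         # on the source image is the inpaint mask; the target image is the subject reference.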
+         def generate(src_input,
+                      tgt_input,
+                      text_prompt,
+                      style_subject,
+                      num_inference_steps,
+                      guidance_scale,
+                      seed,
+                      ):
+             if src_input is None or tgt_input is None:
+                 raise gr.Error("You must upload an image first.")
+             # model part
+             tgt_subject = style_subject
+             generator = torch.Generator(device="cpu").manual_seed(int(seed))
+             init_image = src_input['image']
+             cldm_cond_image = src_input['mask']
+             control_image = make_inpaint_condition(init_image, cldm_cond_image)
+             style_image = tgt_input
+
+             negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
+
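+             # Positional arguments map to the custom pipeline's signature:
+             # (prompt, reference_image, condtioning_image, source_subject_category, target_subject_category)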
+             output = blip_diffusion_pipe(
+                 text_prompt,
+                 style_image,
+                 control_image,
+                 style_subject,
+                 tgt_subject,
+                 generator=generator,
+                 image=init_image,
+                 mask_image=cldm_cond_image,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+                 neg_prompt=negative_prompt,
+                 height=HEIGHT,
+                 width=WIDTH,
+             ).images
+             return {output_image: output}
+
+         run_button.click(fn=generate, inputs=inputs, outputs=[output_image])
+     return demo
+
+ if __name__ == '__main__':
+     demo = create_demo()
+     demo.queue().launch()
+
+
pipeline_controlnet_blip_diffusion.py ADDED
@@ -0,0 +1,653 @@
+ # Copyright 2023 Salesforce.com, inc.
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
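+ #
+ # NOTE: local copy of diffusers' BlipDiffusionControlNetPipeline, extended with
+ # inpainting-style image/mask handling (mask latents, image latents and per-step noise
+ # re-injection adapted from StableDiffusionInpaintPipeline, per the "Copied from" notes below).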
+ from typing import List, Optional, Union
+
+ import PIL.Image
+ import torch
+ from transformers import CLIPTokenizer
+
+ from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+ from diffusers.schedulers import PNDMScheduler
+ from diffusers.utils import (
+     logging,
+     replace_example_docstring,
+ )
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
+ from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
+ from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ EXAMPLE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> from diffusers.pipelines import BlipDiffusionControlNetPipeline
+         >>> from diffusers.utils import load_image
+         >>> from controlnet_aux import CannyDetector
+         >>> import torch
+
+         >>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
+         ...     "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
+         ... ).to("cuda")
+
+         >>> style_subject = "flower"
+         >>> tgt_subject = "teapot"
+         >>> text_prompt = "on a marble table"
+
+         >>> cldm_cond_image = load_image(
+         ...     "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
+         ... ).resize((512, 512))
+         >>> canny = CannyDetector()
+         >>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
+         >>> style_image = load_image(
+         ...     "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
+         ... )
+         >>> guidance_scale = 7.5
+         >>> num_inference_steps = 50
+         >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
+
+
+         >>> output = blip_diffusion_pipe(
+         ...     text_prompt,
+         ...     style_image,
+         ...     cldm_cond_image,
+         ...     style_subject,
+         ...     tgt_subject,
+         ...     guidance_scale=guidance_scale,
+         ...     num_inference_steps=num_inference_steps,
+         ...     neg_prompt=negative_prompt,
+         ...     height=512,
+         ...     width=512,
+         ... ).images
+         >>> output[0].save("image.png")
+         ```
+ """
+
+
+ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
+     """
+     Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     Args:
+         tokenizer ([`CLIPTokenizer`]):
+             Tokenizer for the text encoder
+         text_encoder ([`ContextCLIPTextModel`]):
+             Text encoder to encode the text prompt
+         vae ([`AutoencoderKL`]):
+             VAE model to map the latents to the image
+         unet ([`UNet2DConditionModel`]):
+             Conditional U-Net architecture to denoise the image embedding.
+         scheduler ([`PNDMScheduler`]):
+             A scheduler to be used in combination with `unet` to generate image latents.
+         qformer ([`Blip2QFormerModel`]):
+             QFormer model to get multi-modal embeddings from the text and image.
+         controlnet ([`ControlNetModel`]):
+             ControlNet model to get the conditioning image embedding.
+         image_processor ([`BlipImageProcessor`]):
+             Image Processor to preprocess and postprocess the image.
+         ctx_begin_pos (int, `optional`, defaults to 2):
+             Position of the context token in the text encoder.
+     """
+
+     model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
+     def __init__(
+         self,
+         tokenizer: CLIPTokenizer,
+         text_encoder: ContextCLIPTextModel,
+         vae: AutoencoderKL,
+         unet: UNet2DConditionModel,
+         scheduler: PNDMScheduler,
+         qformer: Blip2QFormerModel,
+         controlnet: ControlNetModel,
+         image_processor: BlipImageProcessor,
+         ctx_begin_pos: int = 2,
+         mean: List[float] = None,
+         std: List[float] = None,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             tokenizer=tokenizer,
+             text_encoder=text_encoder,
+             vae=vae,
+             unet=unet,
+             scheduler=scheduler,
+             qformer=qformer,
+             controlnet=controlnet,
+             image_processor=image_processor,
+         )
+         # copy control net
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.init_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+         self.mask_processor = VaeImageProcessor(
+             vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+         )
+         self.control_image_processor = VaeImageProcessor(
+             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+         )
+         self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
+
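+     # The QFormer maps the reference image plus the source subject name to subject embeddings,
+     # which encode_prompt injects as context tokens into the CLIP text encoder.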
+     def get_query_embeddings(self, input_image, src_subject):
+         return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
+
+     # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
+     def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
+         rv = []
+         for prompt, tgt_subject in zip(prompts, tgt_subjects):
+             prompt = f"a {tgt_subject} {prompt.strip()}"
+             # a trick to amplify the prompt
+             rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
+
+         return rv
+
+     # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
+     def prepare_latents_old(
+         self,
+         batch_size,
+         num_channels,
+         height,
+         width,
+         dtype,
+         device,
+         generator,
+         latents=None,
+         image=None):
+         shape = (batch_size, num_channels, height, width)
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         if latents is None:
+             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         else:
+             latents = latents.to(device=device, dtype=dtype)
+
+         # scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * self.scheduler.init_noise_sigma
+         return latents
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
+     def prepare_latents(
+         self,
+         batch_size,
+         num_channels_latents,
+         height,
+         width,
+         dtype,
+         device,
+         generator,
+         latents=None,
+         image=None,
+         timestep=None,
+         is_strength_max=True,
+         return_noise=False,
+         return_image_latents=False,
+     ):
+         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         if (image is None or timestep is None) and not is_strength_max:
+             raise ValueError(
+                 "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+                 "However, either the image or the noise timestep has not been provided."
+             )
+
+         if return_image_latents or (latents is None and not is_strength_max):
+             image = image.to(device=device, dtype=dtype)
+
+             if image.shape[1] == 4:
+                 image_latents = image
+             else:
+                 image_latents = self._encode_vae_image(image=image, generator=generator)
+             image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+         if latents is None:
+             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+             # if strength is 1. then initialise the latents to noise, else initial to image + noise
+             latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+             # if pure noise then scale the initial latents by the Scheduler's init sigma
+             latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+         else:
+             noise = latents.to(device)
+             latents = noise * self.scheduler.init_noise_sigma
+
+         outputs = (latents,)
+
+         if return_noise:
+             outputs += (noise,)
+
+         if return_image_latents:
+             outputs += (image_latents,)
+
+         return outputs
+
+     def encode_prompt(self, query_embeds, prompt, device=None):
+         device = device or self._execution_device
+
+         # embeddings for prompt, with query_embeds as context
+         max_len = self.text_encoder.text_model.config.max_position_embeddings
+         max_len -= self.qformer.config.num_query_tokens
+
+         tokenized_prompt = self.tokenizer(
+             prompt,
+             padding="max_length",
+             truncation=True,
+             max_length=max_len,
+             return_tensors="pt",
+         ).to(device)
+
+         batch_size = query_embeds.shape[0]
+         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
+
+         text_embeddings = self.text_encoder(
+             input_ids=tokenized_prompt.input_ids,
+             ctx_embeddings=query_embeds,
+             ctx_begin_pos=ctx_begin_pos,
+         )[0]
+
+         return text_embeddings
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+     def get_timesteps(self, num_inference_steps, strength, device):
+         # get the original timestep using init_timestep
+         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+         t_start = max(num_inference_steps - init_timestep, 0)
+         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+         return timesteps, num_inference_steps - t_start
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
+     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+         if isinstance(generator, list):
+             image_latents = [
+                 self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                 for i in range(image.shape[0])
+             ]
+             image_latents = torch.cat(image_latents, dim=0)
+         else:
+             image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+         image_latents = self.vae.config.scaling_factor * image_latents
+
+         return image_latents
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
+     def prepare_mask_latents(
+         self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+     ):
+         # resize the mask to latents shape as we concatenate the mask to the latents
+         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+         # and half precision
+         mask = torch.nn.functional.interpolate(
+             mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+         )
+         mask = mask.to(device=device, dtype=dtype)
+
+         masked_image = masked_image.to(device=device, dtype=dtype)
+
+         if masked_image.shape[1] == 4:
+             masked_image_latents = masked_image
+         else:
+             masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+         if mask.shape[0] < batch_size:
+             if not batch_size % mask.shape[0] == 0:
+                 raise ValueError(
+                     "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                     f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                     " of masks that you pass is divisible by the total requested batch size."
+                 )
+             mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+         if masked_image_latents.shape[0] < batch_size:
+             if not batch_size % masked_image_latents.shape[0] == 0:
+                 raise ValueError(
+                     "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                     f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                     " Make sure the number of images that you pass is divisible by the total requested batch size."
+                 )
+             masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+         masked_image_latents = (
+             torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+         )
+
+         # aligning device to prevent device errors when concating it with the latent model input
+         masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+         return mask, masked_image_latents
+
+     # Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+     def prepare_control_image(
+         self,
+         image,
+         width,
+         height,
+         batch_size,
+         num_images_per_prompt,
+         device,
+         dtype,
+         do_classifier_free_guidance=False,
+     ):
+         '''
+         image = self.control_image_processor.preprocess(
+             image,
+             height=height,
+             width=width,
+             #size={"width": width, "height": height},
+             do_rescale=True,
+             do_center_crop=False,
+             do_normalize=False,
+             return_tensors="pt",
+         )["pixel_values"].to(device)
+         '''
+         image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+         image_batch_size = image.shape[0]
+
+         if image_batch_size == 1:
+             repeat_by = batch_size
+         else:
+             # image batch size is the same as prompt batch size
+             repeat_by = num_images_per_prompt
+
+         image = image.repeat_interleave(repeat_by, dim=0)
+
+         image = image.to(device=device, dtype=dtype)
+
+         if do_classifier_free_guidance:
+             image = torch.cat([image] * 2)
+
+         return image
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_DOC_STRING)
+     def __call__(
+         self,
+         prompt: List[str],
+         reference_image: PIL.Image.Image,
+         condtioning_image: PIL.Image.Image,
+         source_subject_category: List[str],
+         target_subject_category: List[str],
+         image: PipelineImageInput = None,
+         mask_image: PipelineImageInput = None,
+         latents: Optional[torch.FloatTensor] = None,
+         guidance_scale: float = 7.5,
+         height: int = 512,
+         width: int = 512,
+         num_inference_steps: int = 50,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         neg_prompt: Optional[str] = "",
+         prompt_strength: float = 1.0,
+         strength: float = 1.0,
+         num_images_per_prompt: Optional[int] = 1,
+         prompt_reps: int = 20,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+     ):
+         """
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             prompt (`List[str]`):
+                 The prompt or prompts to guide the image generation.
+             reference_image (`PIL.Image.Image`):
+                 The reference image to condition the generation on.
+             condtioning_image (`PIL.Image.Image`):
+                 The conditioning canny edge image to condition the generation on.
+             source_subject_category (`List[str]`):
+                 The source subject category.
+             target_subject_category (`List[str]`):
+                 The target subject category.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by random sampling.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             height (`int`, *optional*, defaults to 512):
+                 The height of the generated image.
+             width (`int`, *optional*, defaults to 512):
+                 The width of the generated image.
+             seed (`int`, *optional*, defaults to 42):
+                 The seed to use for random generation.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             neg_prompt (`str`, *optional*, defaults to ""):
+                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                 if `guidance_scale` is less than `1`).
+             prompt_strength (`float`, *optional*, defaults to 1.0):
+                 The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
+                 to amplify the prompt.
+             prompt_reps (`int`, *optional*, defaults to 20):
+                 The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
+         Examples:
+
+         Returns:
+             [`~pipelines.ImagePipelineOutput`] or `tuple`
+         """
+         device = self._execution_device
+
+         reference_image = self.image_processor.preprocess(
+             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
+         )["pixel_values"]
+         reference_image = reference_image.to(device)
+
+         if isinstance(prompt, str):
+             prompt = [prompt]
+         if isinstance(source_subject_category, str):
+             source_subject_category = [source_subject_category]
+         if isinstance(target_subject_category, str):
+             target_subject_category = [target_subject_category]
+
+         batch_size = len(prompt)
+
+         prompt = self._build_prompt(
+             prompts=prompt,
+             tgt_subjects=target_subject_category,
+             prompt_strength=prompt_strength,
+             prompt_reps=prompt_reps,
+         )
+         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
+         text_embeddings = self.encode_prompt(query_embeds, prompt, device)
+         # 3. unconditional embedding
+         do_classifier_free_guidance = guidance_scale > 1.0
+         if do_classifier_free_guidance:
+             max_length = self.text_encoder.text_model.config.max_position_embeddings
+
+             uncond_input = self.tokenizer(
+                 [neg_prompt] * batch_size,
+                 padding="max_length",
+                 max_length=max_length,
+                 return_tensors="pt",
+             )
+             uncond_embeddings = self.text_encoder(
+                 input_ids=uncond_input.input_ids.to(device),
+                 ctx_embeddings=None,
+             )[0]
+             # For classifier free guidance, we need to do two forward passes.
+             # Here we concatenate the unconditional and text embeddings into a single batch
+             # to avoid doing two forward passes
+             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+         # 4. Set condition image
+         cond_image = self.prepare_control_image(
+             image=condtioning_image,
+             width=width,
+             height=height,
+             batch_size=batch_size,
+             num_images_per_prompt=1,
+             device=device,
+             dtype=self.controlnet.dtype,
+             do_classifier_free_guidance=do_classifier_free_guidance,
+         )
+
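+         # mask_processor binarizes the mask; masked_image keeps only the pixels outside the
+         # mask (mask < 0.5), i.e. the region that should be preserved during inpainting.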
+         # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
+         # set init image
+         init_image = self.init_processor.preprocess(image, height=height, width=width)
+         init_image = init_image.to(dtype=torch.float32)
+
+         mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+         masked_image = init_image * (mask < 0.5)
+         _, _, height, width = init_image.shape
+
+         # 5. Set timesteps
+         extra_set_kwargs = {}
+         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+         timesteps, num_inference_steps = self.get_timesteps(
+             num_inference_steps=num_inference_steps, strength=strength, device=device
+         )
+         # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+         # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+         is_strength_max = strength == 1.0
+
+         # 6. Prepare latent variables
+         num_channels_latents = self.vae.config.latent_channels
+         num_channels_unet = self.unet.config.in_channels
+         return_image_latents = num_channels_unet == 4
+
+         # latents
+         scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
+         '''
+         latents = self.prepare_latents(
+             batch_size=batch_size,
+             num_channels=self.unet.config.in_channels,
+             height=height // scale_down_factor,
+             width=width // scale_down_factor,
+             generator=generator,
+             latents=latents,
+             dtype=self.unet.dtype,
+             device=device,
+             image=init_image,
+         )
+         '''
+         latents_outputs = self.prepare_latents(
+             batch_size,
+             num_channels_latents,
+             height,
+             width,
+             text_embeddings.dtype,
+             device,
+             generator,
+             latents,
+             image=init_image,
+             timestep=latent_timestep,
+             is_strength_max=is_strength_max,
+             return_noise=True,
+             return_image_latents=return_image_latents,
+         )
+
+         if return_image_latents:
+             latents, noise, image_latents = latents_outputs
+         else:
+             latents, noise = latents_outputs
+
+         # 7. Prepare mask latent variables
+         mask, masked_image_latents = self.prepare_mask_latents(
+             mask,
+             masked_image,
+             batch_size,
+             height,
+             width,
+             text_embeddings.dtype,
+             device,
+             generator,
+             do_classifier_free_guidance,
+         )
+
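+         # The loop below runs over self.scheduler.timesteps, which matches `timesteps` from
+         # get_timesteps only when strength == 1.0 (the value the Gradio app uses). With a
+         # 4-channel UNet, the unmasked region is preserved each step by re-noising image_latents
+         # to the next timestep and blending them back in with init_mask.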
+         # 8. Denoising loop
+         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+             # expand the latents if we are doing classifier free guidance
+             do_classifier_free_guidance = guidance_scale > 1.0
+
+             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+             down_block_res_samples, mid_block_res_sample = self.controlnet(
+                 latent_model_input,
+                 t,
+                 encoder_hidden_states=text_embeddings,
+                 controlnet_cond=cond_image,
+                 return_dict=False,
+             )
+
+             noise_pred = self.unet(
+                 latent_model_input,
+                 timestep=t,
+                 encoder_hidden_states=text_embeddings,
+                 down_block_additional_residuals=down_block_res_samples,
+                 mid_block_additional_residual=mid_block_res_sample,
+             )["sample"]
+
+             # perform guidance
+             if do_classifier_free_guidance:
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+             latents = self.scheduler.step(
+                 noise_pred,
+                 t,
+                 latents,
+             )["prev_sample"]
+
+             if num_channels_unet == 4:
+                 init_latents_proper = image_latents
+                 if do_classifier_free_guidance:
+                     init_mask, _ = mask.chunk(2)
+                 else:
+                     init_mask = mask
+
+                 if i < len(timesteps) - 1:
+                     noise_timestep = timesteps[i + 1]
+                     init_latents_proper = self.scheduler.add_noise(
+                         init_latents_proper, noise, torch.tensor([noise_timestep])
+                     )
+
+                 latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+
+         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+         image = self.image_processor.postprocess(image, output_type=output_type)
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (image,)
+
+         return ImagePipelineOutput(images=image)
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers
+ torch==2.0.1
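+ # diffusers comes from the GitHub source so that diffusers.pipelines.blip_diffusion
+ # (imported by pipeline_controlnet_blip_diffusion.py) is available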
+ -e git+https://github.com/huggingface/diffusers.git#egg=diffusers
+ pillow
+ numpy
+ gradio
+ accelerate