TorchRik committed
Commit e523134 · verified · 1 Parent(s): 81bf000

Upload combined_stable_diffusion.py with huggingface_hub

Files changed (1)
  1. combined_stable_diffusion.py +475 -0
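The commit message above indicates the file was uploaded with the huggingface_hub client. As a rough sketch (not taken from this repository), an upload like this can be made with HfApi.upload_file; the repository id below is a placeholder:

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="combined_stable_diffusion.py",  # local file to upload
    path_in_repo="combined_stable_diffusion.py",     # destination path inside the repo
    repo_id="<user>/<repo>",                         # placeholder repository id
    commit_message="Upload combined_stable_diffusion.py with huggingface_hub",
)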
combined_stable_diffusion.py ADDED
@@ -0,0 +1,475 @@
import torch
from diffusers import DiffusionPipeline, DDPMScheduler
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.image_processor import VaeImageProcessor
from huggingface_hub import PyTorchModelHubMixin
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    FusedAttnProcessor2_0,
    XFormersAttnProcessor,
)


class CombinedStableDiffusionXL(
    DiffusionPipeline,
    PyTorchModelHubMixin,
):
    """
    A Stable Diffusion XL wrapper that combines an original and a fine-tuned UNet and
    provides text-to-image synthesis, noise scheduling, latent-space manipulation,
    and image decoding.
    """

    def __init__(
        self,
        original_unet: torch.nn.Module,
        fine_tuned_unet: torch.nn.Module,
        scheduler: DDPMScheduler,
        vae: torch.nn.Module,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        text_encoder_2: CLIPTextModelWithProjection,
    ) -> None:
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            original_unet=original_unet,
            fine_tuned_unet=fine_tuned_unet,
            scheduler=scheduler,
            vae=vae,
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor
        )
        self.resolution = 1024

    def _get_negative_prompts(
        self, batch_size: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_ids_1 = self.tokenizer(
            [""] * batch_size,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        input_ids_2 = self.tokenizer_2(
            [""] * batch_size,
            max_length=self.tokenizer_2.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids
        return input_ids_1, input_ids_2

    def _get_encoder_hidden_states(
        self,
        tokenized_prompts_1: torch.Tensor,
        tokenized_prompts_2: torch.Tensor,
        do_classifier_free_guidance: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        text_input_ids_list = [
            tokenized_prompts_1,
            tokenized_prompts_2,
        ]
        batch_size = text_input_ids_list[0].size(0)

        if do_classifier_free_guidance:
            negative_prompts = [
                embed.to(text_input_ids_list[0].device)
                for embed in self._get_negative_prompts(batch_size)
            ]

            text_input_ids_list = [
                torch.cat(
                    [
                        negative_prompt,
                        text_input,
                    ]
                )
                for text_input, negative_prompt in zip(
                    text_input_ids_list, negative_prompts
                )
            ]
        prompt_embeds_list = []

        text_encoders = [self.text_encoder, self.text_encoder_2]
        for text_encoder, text_input_ids in zip(text_encoders, text_input_ids_list):
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
                return_dict=False,
            )
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds[-1][-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
        return prompt_embeds, pooled_prompt_embeds

    def _get_unet_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: int,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """
        Return the UNet noise prediction.

        Args:
            latent_model_input (torch.Tensor): UNet latent input
            timestep (int): noise scheduler timestep
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): Text encoder prompt embeddings and pooled embeddings
        Returns:
            torch.Tensor: noise prediction
        """
        unet = self.original_unet if self._use_original_unet else self.fine_tuned_unet

        prompt_embeds, pooled_prompt_embeds = encoder_hidden_states
        target_size = torch.tensor(
            [
                [self.resolution, self.resolution]
                for _ in range(latent_model_input.size(0))
            ],
            device=latent_model_input.device,
            dtype=torch.float32,
        )
        add_time_ids = torch.cat(
            [target_size, torch.zeros_like(target_size), target_size], dim=1
        )

        unet_added_conditions = {
            "time_ids": add_time_ids,
            "text_embeds": pooled_prompt_embeds,
        }

        return unet(
            latent_model_input,
            timestep,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs=unet_added_conditions,
        ).sample

    def get_noise_prediction(
        self,
        latents: torch.Tensor,
        timestep_index: int,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
        do_classifier_free_guidance: bool = False,
        detach_main_path: bool = False,
    ) -> torch.Tensor:
        """
        Return the noise prediction.

        Args:
            latents (torch.Tensor): Image latents
            timestep_index (int): noise scheduler timestep index
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): Text encoder hidden states
            do_classifier_free_guidance (bool): Whether to perform classifier-free guidance
            detach_main_path (bool): Whether to detach the gradient of the text-conditioned prediction

        Returns:
            torch.Tensor: noise prediction
        """
        timestep = self.scheduler.timesteps[timestep_index]

        latent_model_input = self.scheduler.scale_model_input(
            sample=torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
            timestep=timestep,
        )

        noise_pred = self._get_unet_prediction(
            latent_model_input=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
        )

        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            if detach_main_path:
                noise_pred_text = noise_pred_text.detach()

            noise_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        return noise_pred

    def sample_next_latents(
        self,
        latents: torch.Tensor,
        timestep_index: int,
        noise_pred: torch.Tensor,
        return_pred_original: bool = False,
    ) -> torch.Tensor:
        """
        Return the next latents prediction.

        Args:
            latents (torch.Tensor): Image latents
            timestep_index (int): noise scheduler timestep index
            noise_pred (torch.Tensor): noise prediction
            return_pred_original (bool): Whether to return the predicted original sample

        Returns:
            torch.Tensor: latent prediction
        """
        timestep = self.scheduler.timesteps[timestep_index]
        sample = self.scheduler.step(
            model_output=noise_pred, timestep=timestep, sample=latents
        )
        return (
            sample.pred_original_sample if return_pred_original else sample.prev_sample
        )

    def predict_next_latents(
        self,
        latents: torch.Tensor,
        timestep_index: int,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
        return_pred_original: bool = False,
        do_classifier_free_guidance: bool = False,
        detach_main_path: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts the next latent states during the diffusion process.

        Args:
            latents (torch.Tensor): Current latent states.
            timestep_index (int): Index of the current timestep.
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): Encoder hidden states from the text encoders.
            return_pred_original (bool): Whether to return the predicted original sample.
            do_classifier_free_guidance (bool): Whether to perform classifier-free guidance.
            detach_main_path (bool): Whether to detach the gradient of the text-conditioned prediction.

        Returns:
            tuple: Next latents and predicted noise tensor.
        """
        noise_pred = self.get_noise_prediction(
            latents=latents,
            timestep_index=timestep_index,
            encoder_hidden_states=encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
            detach_main_path=detach_main_path,
        )

        latents = self.sample_next_latents(
            latents=latents,
            noise_pred=noise_pred,
            timestep_index=timestep_index,
            return_pred_original=return_pred_original,
        )

        return latents, noise_pred

    def get_latents(self, batch_size: int, device: torch.device) -> torch.Tensor:
        latent_resolution = int(self.resolution) // self.vae_scale_factor
        return torch.randn(
            (
                batch_size,
                self.original_unet.config.in_channels,
                latent_resolution,
                latent_resolution,
            ),
            device=device,
        )

    def do_k_diffusion_steps(
        self,
        start_timestep_index: int,
        end_timestep_index: int,
        latents: torch.Tensor,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
        return_pred_original: bool = False,
        do_classifier_free_guidance: bool = False,
        detach_main_path: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Performs multiple diffusion steps between the specified timestep indices.

        Args:
            start_timestep_index (int): Starting timestep index.
            end_timestep_index (int): Ending timestep index.
            latents (torch.Tensor): Initial latents.
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): Encoder hidden states.
            return_pred_original (bool): Whether to return the predicted original sample.
            do_classifier_free_guidance (bool): Whether to perform classifier-free guidance.
            detach_main_path (bool): Whether to detach the gradient of the text-conditioned prediction.

        Returns:
            tuple: Resulting latents and encoder hidden states.
        """
        assert start_timestep_index <= end_timestep_index

        for timestep_index in range(start_timestep_index, end_timestep_index - 1):
            latents, _ = self.predict_next_latents(
                latents=latents,
                timestep_index=timestep_index,
                encoder_hidden_states=encoder_hidden_states,
                return_pred_original=False,
                do_classifier_free_guidance=do_classifier_free_guidance,
                detach_main_path=detach_main_path,
            )
        res, _ = self.predict_next_latents(
            latents=latents,
            timestep_index=end_timestep_index - 1,
            encoder_hidden_states=encoder_hidden_states,
            return_pred_original=return_pred_original,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
        return res, encoder_hidden_states

    def upcast_vae(self):
        dtype = self.vae.dtype
        self.vae.to(dtype=torch.float32)
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
                FusedAttnProcessor2_0,
            ),
        )
        if use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(dtype)
            self.vae.decoder.conv_in.to(dtype)
            self.vae.decoder.mid_block.to(dtype)

    @torch.no_grad()
    def __call__(
        self,
        prompt: str | list[str],
        num_inference_steps=40,
        original_unet_steps=35,
        resolution=1024,
        guidance_scale=5,
        output_type: str = "pil",
        return_dict: bool = True,
        generator=None,
    ):
        self.guidance_scale = guidance_scale
        self.resolution = resolution
        batch_size = 1 if isinstance(prompt, str) else len(prompt)

        tokenized_prompts_1 = self.tokenizer(
            prompt,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        tokenized_prompts_2 = self.tokenizer_2(
            prompt,
            max_length=self.tokenizer_2.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        original_encoder_hidden_states = self._get_encoder_hidden_states(
            tokenized_prompts_1=tokenized_prompts_1,
            tokenized_prompts_2=tokenized_prompts_2,
            do_classifier_free_guidance=True,
        )
        fine_tuned_encoder_hidden_states = self._get_encoder_hidden_states(
            tokenized_prompts_1=tokenized_prompts_1,
            tokenized_prompts_2=tokenized_prompts_2,
            do_classifier_free_guidance=False,
        )

        latent_resolution = int(resolution) // self.vae_scale_factor
        latents = torch.randn(
            (
                batch_size,
                self.original_unet.config.in_channels,
                latent_resolution,
                latent_resolution,
            ),
            generator=generator,
            device=self.device,
        )

        self.scheduler.set_timesteps(
            num_inference_steps,
            device=self.device,
        )

        self._use_original_unet = True
        latents, _ = self.do_k_diffusion_steps(
            start_timestep_index=0,
            end_timestep_index=original_unet_steps,
            latents=latents,
            encoder_hidden_states=original_encoder_hidden_states,
            return_pred_original=False,
            do_classifier_free_guidance=True,
        )

        self._use_original_unet = False
        latents, _ = self.do_k_diffusion_steps(
            start_timestep_index=original_unet_steps,
            end_timestep_index=num_inference_steps,
            latents=latents,
            encoder_hidden_states=fine_tuned_encoder_hidden_states,
            return_pred_original=False,
            do_classifier_free_guidance=False,
        )

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
            elif latents.dtype != self.vae.dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    self.vae = self.vae.to(latents.dtype)

            # unscale/denormalize the latents
            # denormalize with the mean and std if available and not None
            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
            if has_latents_mean and has_latents_std:
                latents_mean = (
                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents_std = (
                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
            else:
                latents = latents / self.vae.config.scaling_factor

            image = self.vae.decode(latents, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            image = latents

        if not output_type == "latent":
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
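
Below is a minimal usage sketch for the class defined above. It assumes the frozen components come from a stock SDXL checkpoint and that a separately fine-tuned UNet is available; the checkpoint id, the fine-tuned UNet location, and the prompt are illustrative placeholders, not part of the committed file.

import torch
from diffusers import DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel

# Reuse components from a standard SDXL pipeline (example checkpoint id).
base = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
# Load a fine-tuned UNet from a placeholder path or repo id.
fine_tuned_unet = UNet2DConditionModel.from_pretrained("<path-or-repo-of-fine-tuned-unet>")

pipe = CombinedStableDiffusionXL(
    original_unet=base.unet,
    fine_tuned_unet=fine_tuned_unet,
    scheduler=DDPMScheduler.from_config(base.scheduler.config),
    vae=base.vae,
    tokenizer=base.tokenizer,
    tokenizer_2=base.tokenizer_2,
    text_encoder=base.text_encoder,
    text_encoder_2=base.text_encoder_2,
).to("cuda")

# The first `original_unet_steps` denoising steps run the original UNet with
# classifier-free guidance; the remaining steps run the fine-tuned UNet without it.
image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=40,
    original_unet_steps=35,
    guidance_scale=5,
).images[0]
image.save("out.png")

Because the class also mixes in PyTorchModelHubMixin, an assembled instance could in principle be saved to or reloaded from the Hub with the mixin's helpers, though how this particular repository's weights were saved is not shown on this page.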