omer11a committed
Commit c89d950 · 1 Parent(s): 86a22a6

Updated diffusers code

Files changed (1):
  1. pipeline_stable_diffusion_xl_opt.py +462 -198

pipeline_stable_diffusion_xl_opt.py CHANGED
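This commit migrates the custom SDXL pipeline to the newer diffusers API: `callback`/`callback_steps` are deprecated in favor of `callback_on_step_end`, timestep handling moves to a `retrieve_timesteps` helper, LoRA loading switches to `StableDiffusionXLLoraLoaderMixin`, and optional IP-Adapter inputs are added. A minimal sketch of how a caller might use the updated signature follows; the checkpoint id, device, and import path are assumptions for illustration, and it assumes the denoising loop honors `callback_on_step_end` as the updated docstring describes. The project-specific attention `editor` used by `update_loss` must be configured separately and is not shown in this diff.

    import torch
    from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline  # the file changed below

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    # Note: this custom pipeline expects a project-specific attention editor on `pipe.editor`
    # (used by `update_loss`); its construction is outside this diff.

    def on_step_end(pipeline, step, timestep, callback_kwargs):
        # New-style callback: receives the tensors named in `callback_on_step_end_tensor_inputs`
        # and, following the upstream diffusers convention, returns the (possibly modified) dict.
        return callback_kwargs

    image = pipe(
        prompt="a photo of an astronaut riding a horse",
        num_inference_steps=30,
        callback_on_step_end=on_step_end,                # replaces the deprecated `callback`
        callback_on_step_end_tensor_inputs=["latents"],
    ).images[0]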
@@ -13,14 +13,24 @@
13
  # limitations under the License.
14
 
15
  import inspect
16
- import os
17
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
 
19
  import torch
20
- from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
21
 
22
- from diffusers.image_processor import VaeImageProcessor
23
- from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
24
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
25
  from diffusers.models.attention_processor import (
26
  AttnProcessor2_0,
@@ -28,22 +38,33 @@ from diffusers.models.attention_processor import (
28
  LoRAXFormersAttnProcessor,
29
  XFormersAttnProcessor,
30
  )
 
31
  from diffusers.schedulers import KarrasDiffusionSchedulers
32
  from diffusers.utils import (
33
- is_accelerate_available,
34
- is_accelerate_version,
35
  is_invisible_watermark_available,
 
36
  logging,
37
- randn_tensor,
38
  replace_example_docstring,
 
 
39
  )
40
- from diffusers.pipeline_utils import DiffusionPipeline
41
- from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
 
42
 
43
 
44
  if is_invisible_watermark_available():
45
  from .watermark import StableDiffusionXLWatermarker
46
47
 
48
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
 
@@ -79,7 +100,58 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
79
  return noise_cfg
80
 
81
 
82
- class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
83
  r"""
84
  Pipeline for text-to-image generation using Stable Diffusion XL.
85
 
@@ -87,11 +159,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
87
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
88
 
89
  In addition the pipeline inherits the following loading methods:
90
- - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
91
  - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
92
 
93
  as well as the following saving methods:
94
- - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
95
 
96
  Args:
97
  vae ([`AutoencoderKL`]):
@@ -116,8 +188,34 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
116
  scheduler ([`SchedulerMixin`]):
117
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
118
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
119
  """
120
121
  def __init__(
122
  self,
123
  vae: AutoencoderKL,
@@ -127,6 +225,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
127
  tokenizer_2: CLIPTokenizer,
128
  unet: UNet2DConditionModel,
129
  scheduler: KarrasDiffusionSchedulers,
 
 
130
  force_zeros_for_empty_prompt: bool = True,
131
  add_watermarker: Optional[bool] = None,
132
  ):
@@ -140,10 +240,13 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
140
  tokenizer_2=tokenizer_2,
141
  unet=unet,
142
  scheduler=scheduler,
 
 
143
  )
144
  self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
145
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
146
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
147
  self.default_sample_size = self.unet.config.sample_size
148
 
149
  add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -186,36 +289,6 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
186
  """
187
  self.vae.disable_tiling()
188
 
189
- def enable_model_cpu_offload(self, gpu_id=0):
190
- r"""
191
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
192
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
193
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
194
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
195
- """
196
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
197
- from accelerate import cpu_offload_with_hook
198
- else:
199
- raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
200
-
201
- device = torch.device(f"cuda:{gpu_id}")
202
-
203
- if self.device.type != "cpu":
204
- self.to("cpu", silence_dtype_warnings=True)
205
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
206
-
207
- model_sequence = (
208
- [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
209
- )
210
- model_sequence.extend([self.unet, self.vae])
211
-
212
- hook = None
213
- for cpu_offloaded_model in model_sequence:
214
- _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
215
-
216
- # We'll offload the last model manually.
217
- self.final_offload_hook = hook
218
-
219
  def encode_prompt(
220
  self,
221
  prompt: str,
@@ -230,6 +303,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
230
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
231
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
232
  lora_scale: Optional[float] = None,
 
233
  ):
234
  r"""
235
  Encodes the prompt into text encoder hidden states.
@@ -269,17 +343,33 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
269
  input argument.
270
  lora_scale (`float`, *optional*):
271
  A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
 
 
 
272
  """
273
  device = device or self._execution_device
274
 
275
  # set lora scale so that monkey patched LoRA
276
  # function of text encoder can correctly access it
277
- if lora_scale is not None and isinstance(self, LoraLoaderMixin):
278
  self._lora_scale = lora_scale
279
 
280
- if prompt is not None and isinstance(prompt, str):
281
- batch_size = 1
282
- elif prompt is not None and isinstance(prompt, list):
283
  batch_size = len(prompt)
284
  else:
285
  batch_size = prompt_embeds.shape[0]
@@ -292,6 +382,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
292
 
293
  if prompt_embeds is None:
294
  prompt_2 = prompt_2 or prompt
 
 
295
  # textual inversion: process multi-vector tokens if necessary
296
  prompt_embeds_list = []
297
  prompts = [prompt, prompt_2]
@@ -319,29 +411,15 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
319
  f" {tokenizer.model_max_length} tokens: {removed_text}"
320
  )
321
 
322
- prompt_embeds = text_encoder(
323
- text_input_ids.to(device),
324
- output_hidden_states=True,
325
- )
326
 
327
  # We are only ALWAYS interested in the pooled output of the final text encoder
328
  pooled_prompt_embeds = prompt_embeds[0]
329
- ### TODO: remove
330
- null_text_inputs = tokenizer(
331
- ['a realistic photo of an empty background'] * batch_size,
332
- padding="max_length",
333
- max_length=tokenizer.model_max_length,
334
- truncation=True,
335
- return_tensors="pt",
336
- )
337
- null_input_ids = null_text_inputs.input_ids
338
- null_prompt_embeds = text_encoder(
339
- null_input_ids.to(device),
340
- output_hidden_states=True,
341
- )
342
- pooled_prompt_embeds = null_prompt_embeds[0]
343
- ### TODO: remove
344
- prompt_embeds = prompt_embeds.hidden_states[-2]
345
 
346
  prompt_embeds_list.append(prompt_embeds)
347
 
@@ -356,14 +434,18 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
356
  negative_prompt = negative_prompt or ""
357
  negative_prompt_2 = negative_prompt_2 or negative_prompt
358
359
  uncond_tokens: List[str]
360
  if prompt is not None and type(prompt) is not type(negative_prompt):
361
  raise TypeError(
362
  f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
363
  f" {type(prompt)}."
364
  )
365
- elif isinstance(negative_prompt, str):
366
- uncond_tokens = [negative_prompt, negative_prompt_2]
367
  elif batch_size != len(negative_prompt):
368
  raise ValueError(
369
  f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -399,7 +481,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
399
 
400
  negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
401
 
402
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
 
 
 
 
403
  bs_embed, seq_len, _ = prompt_embeds.shape
404
  # duplicate text embeddings for each generation per prompt, using mps friendly method
405
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -408,7 +494,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
408
  if do_classifier_free_guidance:
409
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
410
  seq_len = negative_prompt_embeds.shape[1]
411
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
412
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
413
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
414
 
@@ -420,8 +511,32 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
420
  bs_embed * num_images_per_prompt, -1
421
  )
422
 
423
  return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
424
425
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
426
  def prepare_extra_step_kwargs(self, generator, eta):
427
  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -453,18 +568,24 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
453
  negative_prompt_embeds=None,
454
  pooled_prompt_embeds=None,
455
  negative_pooled_prompt_embeds=None,
 
456
  ):
457
  if height % 8 != 0 or width % 8 != 0:
458
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
459
 
460
- if (callback_steps is None) or (
461
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
462
- ):
463
  raise ValueError(
464
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
465
  f" {type(callback_steps)}."
466
  )
467
 
468
  if prompt is not None and prompt_embeds is not None:
469
  raise ValueError(
470
  f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -531,11 +652,13 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
531
  latents = latents * self.scheduler.init_noise_sigma
532
  return latents
533
 
534
- def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
 
 
535
  add_time_ids = list(original_size + crops_coords_top_left + target_size)
536
 
537
  passed_add_embed_dim = (
538
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
539
  )
540
  expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
541
 
@@ -567,7 +690,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
567
  self.vae.decoder.conv_in.to(dtype)
568
  self.vae.decoder.mid_block.to(dtype)
569
 
570
- def update_loss(self, latents, i, t, prompt_embeds, cross_attention_kwargs, add_text_embeds, add_time_ids):
571
  def forward_pass(latent_model_input):
572
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
573
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
@@ -575,7 +698,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
575
  latent_model_input,
576
  t,
577
  encoder_hidden_states=prompt_embeds,
578
- cross_attention_kwargs=cross_attention_kwargs,
 
579
  added_cond_kwargs=added_cond_kwargs,
580
  return_dict=False,
581
  )
@@ -583,6 +707,94 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
583
 
584
  return self.editor.update_loss(forward_pass, latents, i)
585
586
  @torch.no_grad()
587
  @replace_example_docstring(EXAMPLE_DOC_STRING)
588
  def __call__(
@@ -592,6 +804,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
592
  height: Optional[int] = None,
593
  width: Optional[int] = None,
594
  num_inference_steps: int = 50,
 
595
  denoising_end: Optional[float] = None,
596
  guidance_scale: float = 5.0,
597
  negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -604,15 +817,21 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
604
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
605
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
606
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
 
607
  output_type: Optional[str] = "pil",
608
  return_dict: bool = True,
609
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
610
- callback_steps: int = 1,
611
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
612
  guidance_rescale: float = 0.0,
613
  original_size: Optional[Tuple[int, int]] = None,
614
  crops_coords_top_left: Tuple[int, int] = (0, 0),
615
  target_size: Optional[Tuple[int, int]] = None,
616
  ):
617
  r"""
618
  Function invoked when calling the pipeline for generation.
@@ -625,12 +844,22 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
625
  The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
626
  used in both text-encoders
627
  height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
628
- The height in pixels of the generated image.
 
 
 
629
  width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
630
- The width in pixels of the generated image.
 
 
 
631
  num_inference_steps (`int`, *optional*, defaults to 50):
632
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
633
  expense of slower inference.
 
 
 
 
634
  denoising_end (`float`, *optional*):
635
  When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
636
  completed before it is intentionally prematurely terminated. As a result, the returned sample will
@@ -677,30 +906,25 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
677
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
678
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
679
  input argument.
 
680
  output_type (`str`, *optional*, defaults to `"pil"`):
681
  The output format of the generate image. Choose between
682
  [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
683
  return_dict (`bool`, *optional*, defaults to `True`):
684
  Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
685
  of a plain tuple.
686
- callback (`Callable`, *optional*):
687
- A function that will be called every `callback_steps` steps during inference. The function will be
688
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
689
- callback_steps (`int`, *optional*, defaults to 1):
690
- The frequency at which the `callback` function will be called. If not specified, the callback will be
691
- called at every step.
692
  cross_attention_kwargs (`dict`, *optional*):
693
  A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
694
  `self.processor` in
695
  [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
696
- guidance_rescale (`float`, *optional*, defaults to 0.7):
697
  Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
698
  Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
699
  [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
700
  Guidance rescale factor should fix overexposure when using zero terminal SNR.
701
  original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
702
  If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
703
- `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
704
  explained in section 2.2 of
705
  [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
706
  crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
@@ -710,8 +934,32 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
710
  [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
711
  target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
712
  For most cases, `target_size` should be set to the desired height and width of the generated image. If
713
- not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
714
  section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
715
 
716
  Examples:
717
 
@@ -720,6 +968,23 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
720
  [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
721
  `tuple`. When returning a tuple, the first element is a list with the generated images.
722
  """
723
  # 0. Default height and width to unet
724
  height = height or self.default_sample_size * self.vae_scale_factor
725
  width = width or self.default_sample_size * self.vae_scale_factor
@@ -740,8 +1005,15 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
740
  negative_prompt_embeds,
741
  pooled_prompt_embeds,
742
  negative_pooled_prompt_embeds,
 
743
  )
744
745
  # 2. Define call parameters
746
  if prompt is not None and isinstance(prompt, str):
747
  batch_size = 1
@@ -752,15 +1024,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
752
 
753
  device = self._execution_device
754
 
755
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
756
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
757
- # corresponds to doing no classifier free guidance.
758
- do_classifier_free_guidance = guidance_scale > 1.0
759
-
760
  # 3. Encode input prompt
761
- text_encoder_lora_scale = (
762
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
763
  )
 
764
  (
765
  prompt_embeds,
766
  negative_prompt_embeds,
@@ -771,20 +1039,19 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
771
  prompt_2=prompt_2,
772
  device=device,
773
  num_images_per_prompt=num_images_per_prompt,
774
- do_classifier_free_guidance=do_classifier_free_guidance,
775
  negative_prompt=negative_prompt,
776
  negative_prompt_2=negative_prompt_2,
777
  prompt_embeds=prompt_embeds,
778
  negative_prompt_embeds=negative_prompt_embeds,
779
  pooled_prompt_embeds=pooled_prompt_embeds,
780
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
781
- lora_scale=text_encoder_lora_scale,
 
782
  )
783
 
784
  # 4. Prepare timesteps
785
- self.scheduler.set_timesteps(num_inference_steps, device=device)
786
-
787
- timesteps = self.scheduler.timesteps
788
 
789
  # 5. Prepare latent variables
790
  num_channels_latents = self.unet.config.in_channels
@@ -804,165 +1071,162 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
804
 
805
  # 7. Prepare added time ids & embeddings
806
  add_text_embeds = pooled_prompt_embeds
807
  add_time_ids = self._get_add_time_ids(
808
- original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
 
 
 
 
809
  )
810
 
811
- if do_classifier_free_guidance:
812
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
813
  add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
814
- add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
815
 
816
  prompt_embeds = prompt_embeds.to(device)
817
  add_text_embeds = add_text_embeds.to(device)
818
  add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
819
 
820
  # 8. Denoising loop
821
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
822
 
823
- # 7.1 Apply denoising_end
824
- if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
825
  discrete_timestep_cutoff = int(
826
  round(
827
  self.scheduler.config.num_train_timesteps
828
- - (denoising_end * self.scheduler.config.num_train_timesteps)
829
  )
830
  )
831
  num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
832
  timesteps = timesteps[:num_inference_steps]
833
834
  latents = latents.half()
835
  prompt_embeds = prompt_embeds.half()
836
  with self.progress_bar(total=num_inference_steps) as progress_bar:
837
  for i, t in enumerate(timesteps):
838
- latents = self.update_loss(latents, i, t, prompt_embeds, cross_attention_kwargs, add_text_embeds, add_time_ids)
839
 
840
  # expand the latents if we are doing classifier free guidance
841
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
842
 
843
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
844
 
845
  # predict the noise residual
846
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
 
847
  noise_pred = self.unet(
848
  latent_model_input,
849
  t,
850
  encoder_hidden_states=prompt_embeds,
851
- cross_attention_kwargs=cross_attention_kwargs,
 
852
  added_cond_kwargs=added_cond_kwargs,
853
  return_dict=False,
854
  )[0]
855
 
856
  # perform guidance
857
- if do_classifier_free_guidance:
858
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
859
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
860
 
861
- if do_classifier_free_guidance and guidance_rescale > 0.0:
862
  # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
863
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
864
 
865
  # compute the previous noisy sample x_t -> x_t-1
866
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
867
868
  # call the callback, if provided
869
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
870
  progress_bar.update()
871
  if callback is not None and i % callback_steps == 0:
872
- callback(i, t, latents)
 
873
 
874
- # make sure the VAE is in float32 mode, as it overflows in float16
875
- if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
876
- self.upcast_vae()
877
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
878
 
879
  if not output_type == "latent":
880
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
 
 
 
 
881
  else:
882
  image = latents
883
- return StableDiffusionXLPipelineOutput(images=image)
884
 
885
- # apply watermark if available
886
- if self.watermark is not None:
887
- image = self.watermark.apply_watermark(image)
 
888
 
889
- image = self.image_processor.postprocess(image, output_type=output_type)
890
 
891
- # Offload last model to CPU
892
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
893
- self.final_offload_hook.offload()
894
 
895
  if not return_dict:
896
  return (image,)
897
 
898
  return StableDiffusionXLPipelineOutput(images=image)
899
-
900
- # Overrride to properly handle the loading and unloading of the additional text encoder.
901
- def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
902
- # We could have accessed the unet config from `lora_state_dict()` too. We pass
903
- # it here explicitly to be able to tell that it's coming from an SDXL
904
- # pipeline.
905
- state_dict, network_alphas = self.lora_state_dict(
906
- pretrained_model_name_or_path_or_dict,
907
- unet_config=self.unet.config,
908
- **kwargs,
909
- )
910
- self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
911
-
912
- text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
913
- if len(text_encoder_state_dict) > 0:
914
- self.load_lora_into_text_encoder(
915
- text_encoder_state_dict,
916
- network_alphas=network_alphas,
917
- text_encoder=self.text_encoder,
918
- prefix="text_encoder",
919
- lora_scale=self.lora_scale,
920
- )
921
-
922
- text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
923
- if len(text_encoder_2_state_dict) > 0:
924
- self.load_lora_into_text_encoder(
925
- text_encoder_2_state_dict,
926
- network_alphas=network_alphas,
927
- text_encoder=self.text_encoder_2,
928
- prefix="text_encoder_2",
929
- lora_scale=self.lora_scale,
930
- )
931
-
932
- @classmethod
933
- def save_lora_weights(
934
- self,
935
- save_directory: Union[str, os.PathLike],
936
- unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
937
- text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
938
- text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
939
- is_main_process: bool = True,
940
- weight_name: str = None,
941
- save_function: Callable = None,
942
- safe_serialization: bool = True,
943
- ):
944
- state_dict = {}
945
-
946
- def pack_weights(layers, prefix):
947
- layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
948
- layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
949
- return layers_state_dict
950
-
951
- state_dict.update(pack_weights(unet_lora_layers, "unet"))
952
-
953
- if text_encoder_lora_layers and text_encoder_2_lora_layers:
954
- state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
955
- state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
956
-
957
- self.write_lora_layers(
958
- state_dict=state_dict,
959
- save_directory=save_directory,
960
- is_main_process=is_main_process,
961
- weight_name=weight_name,
962
- save_function=save_function,
963
- safe_serialization=safe_serialization,
964
- )
965
-
966
- def _remove_text_encoder_monkey_patch(self):
967
- self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
968
- self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
 
13
  # limitations under the License.
14
 
15
  import inspect
 
16
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
 
18
  import torch
19
+ from transformers import (
20
+ CLIPImageProcessor,
21
+ CLIPTextModel,
22
+ CLIPTextModelWithProjection,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ )
26
 
27
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
28
+ from diffusers.loaders import (
29
+ FromSingleFileMixin,
30
+ IPAdapterMixin,
31
+ StableDiffusionXLLoraLoaderMixin,
32
+ TextualInversionLoaderMixin,
33
+ )
34
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
35
  from diffusers.models.attention_processor import (
36
  AttnProcessor2_0,
 
38
  LoRAXFormersAttnProcessor,
39
  XFormersAttnProcessor,
40
  )
41
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
42
  from diffusers.schedulers import KarrasDiffusionSchedulers
43
  from diffusers.utils import (
44
+ USE_PEFT_BACKEND,
45
+ deprecate,
46
  is_invisible_watermark_available,
47
+ is_torch_xla_available,
48
  logging,
 
49
  replace_example_docstring,
50
+ scale_lora_layers,
51
+ unscale_lora_layers,
52
  )
53
+ from diffusers.utils.torch_utils import randn_tensor
54
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
55
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
56
 
57
 
58
  if is_invisible_watermark_available():
59
  from .watermark import StableDiffusionXLWatermarker
60
 
61
+ if is_torch_xla_available():
62
+ import torch_xla.core.xla_model as xm
63
+
64
+ XLA_AVAILABLE = True
65
+ else:
66
+ XLA_AVAILABLE = False
67
+
68
 
69
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
70
 
 
100
  return noise_cfg
101
 
102
 
103
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104
+ def retrieve_timesteps(
105
+ scheduler,
106
+ num_inference_steps: Optional[int] = None,
107
+ device: Optional[Union[str, torch.device]] = None,
108
+ timesteps: Optional[List[int]] = None,
109
+ **kwargs,
110
+ ):
111
+ """
112
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
113
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
114
+
115
+ Args:
116
+ scheduler (`SchedulerMixin`):
117
+ The scheduler to get timesteps from.
118
+ num_inference_steps (`int`):
119
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
120
+ `timesteps` must be `None`.
121
+ device (`str` or `torch.device`, *optional*):
122
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
123
+ timesteps (`List[int]`, *optional*):
124
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
125
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
126
+ must be `None`.
127
+
128
+ Returns:
129
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
130
+ second element is the number of inference steps.
131
+ """
132
+ if timesteps is not None:
133
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
134
+ if not accepts_timesteps:
135
+ raise ValueError(
136
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
137
+ f" timestep schedules. Please check whether you are using the correct scheduler."
138
+ )
139
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
140
+ timesteps = scheduler.timesteps
141
+ num_inference_steps = len(timesteps)
142
+ else:
143
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
144
+ timesteps = scheduler.timesteps
145
+ return timesteps, num_inference_steps
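The new `retrieve_timesteps` helper can also be exercised on its own. A small sketch, assuming a scheduler whose `set_timesteps` accepts a `timesteps=` argument for the custom branch (scheduler choice and import path are illustrative, not part of this commit):

    from diffusers import EulerDiscreteScheduler
    from pipeline_stable_diffusion_xl_opt import retrieve_timesteps  # defined above

    scheduler = EulerDiscreteScheduler(num_train_timesteps=1000)

    # Default branch: the scheduler builds its own spacing for 30 steps.
    ts, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
    assert n == 30 and len(ts) == 30

    # Custom branch: pass an explicit, descending timestep list instead of a step count.
    # Schedulers whose `set_timesteps` lacks a `timesteps=` argument raise a ValueError here.
    ts, n = retrieve_timesteps(scheduler, device="cpu", timesteps=[999, 749, 499, 249, 0])
    assert n == 5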
146
+
147
+
148
+ class StableDiffusionXLPipeline(
149
+ DiffusionPipeline,
150
+ FromSingleFileMixin,
151
+ StableDiffusionXLLoraLoaderMixin,
152
+ TextualInversionLoaderMixin,
153
+ IPAdapterMixin,
154
+ ):
155
  r"""
156
  Pipeline for text-to-image generation using Stable Diffusion XL.
157
 
 
159
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
160
 
161
  In addition the pipeline inherits the following loading methods:
162
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
163
  - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
164
 
165
  as well as the following saving methods:
166
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
167
 
168
  Args:
169
  vae ([`AutoencoderKL`]):
 
188
  scheduler ([`SchedulerMixin`]):
189
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
190
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
191
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
192
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
193
+ `stabilityai/stable-diffusion-xl-base-1-0`.
194
+ add_watermarker (`bool`, *optional*):
195
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
196
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
197
+ watermarker will be used.
198
  """
199
 
200
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
201
+ _optional_components = [
202
+ "tokenizer",
203
+ "tokenizer_2",
204
+ "text_encoder",
205
+ "text_encoder_2",
206
+ "image_encoder",
207
+ "feature_extractor",
208
+ ]
209
+ _callback_tensor_inputs = [
210
+ "latents",
211
+ "prompt_embeds",
212
+ "negative_prompt_embeds",
213
+ "add_text_embeds",
214
+ "add_time_ids",
215
+ "negative_pooled_prompt_embeds",
216
+ "negative_add_time_ids",
217
+ ]
218
+
219
  def __init__(
220
  self,
221
  vae: AutoencoderKL,
 
225
  tokenizer_2: CLIPTokenizer,
226
  unet: UNet2DConditionModel,
227
  scheduler: KarrasDiffusionSchedulers,
228
+ image_encoder: CLIPVisionModelWithProjection = None,
229
+ feature_extractor: CLIPImageProcessor = None,
230
  force_zeros_for_empty_prompt: bool = True,
231
  add_watermarker: Optional[bool] = None,
232
  ):
 
240
  tokenizer_2=tokenizer_2,
241
  unet=unet,
242
  scheduler=scheduler,
243
+ image_encoder=image_encoder,
244
+ feature_extractor=feature_extractor,
245
  )
246
  self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
247
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
248
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
249
+
250
  self.default_sample_size = self.unet.config.sample_size
251
 
252
  add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
 
289
  """
290
  self.vae.disable_tiling()
291
 
292
  def encode_prompt(
293
  self,
294
  prompt: str,
 
303
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
304
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
305
  lora_scale: Optional[float] = None,
306
+ clip_skip: Optional[int] = None,
307
  ):
308
  r"""
309
  Encodes the prompt into text encoder hidden states.
 
343
  input argument.
344
  lora_scale (`float`, *optional*):
345
  A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
346
+ clip_skip (`int`, *optional*):
347
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
348
+ the output of the pre-final layer will be used for computing the prompt embeddings.
349
  """
350
  device = device or self._execution_device
351
 
352
  # set lora scale so that monkey patched LoRA
353
  # function of text encoder can correctly access it
354
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
355
  self._lora_scale = lora_scale
356
 
357
+ # dynamically adjust the LoRA scale
358
+ if self.text_encoder is not None:
359
+ if not USE_PEFT_BACKEND:
360
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
361
+ else:
362
+ scale_lora_layers(self.text_encoder, lora_scale)
363
+
364
+ if self.text_encoder_2 is not None:
365
+ if not USE_PEFT_BACKEND:
366
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
367
+ else:
368
+ scale_lora_layers(self.text_encoder_2, lora_scale)
369
+
370
+ prompt = [prompt] if isinstance(prompt, str) else prompt
371
+
372
+ if prompt is not None:
373
  batch_size = len(prompt)
374
  else:
375
  batch_size = prompt_embeds.shape[0]
 
382
 
383
  if prompt_embeds is None:
384
  prompt_2 = prompt_2 or prompt
385
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
386
+
387
  # textual inversion: process multi-vector tokens if necessary
388
  prompt_embeds_list = []
389
  prompts = [prompt, prompt_2]
 
411
  f" {tokenizer.model_max_length} tokens: {removed_text}"
412
  )
413
 
414
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
 
 
 
415
 
416
  # We are only ALWAYS interested in the pooled output of the final text encoder
417
  pooled_prompt_embeds = prompt_embeds[0]
418
+ if clip_skip is None:
419
+ prompt_embeds = prompt_embeds.hidden_states[-2]
420
+ else:
421
+ # "2" because SDXL always indexes from the penultimate layer.
422
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
423
 
424
  prompt_embeds_list.append(prompt_embeds)
425
 
 
434
  negative_prompt = negative_prompt or ""
435
  negative_prompt_2 = negative_prompt_2 or negative_prompt
436
 
437
+ # normalize str to list
438
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
439
+ negative_prompt_2 = (
440
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
441
+ )
442
+
443
  uncond_tokens: List[str]
444
  if prompt is not None and type(prompt) is not type(negative_prompt):
445
  raise TypeError(
446
  f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
447
  f" {type(prompt)}."
448
  )
 
 
449
  elif batch_size != len(negative_prompt):
450
  raise ValueError(
451
  f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
 
481
 
482
  negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
483
 
484
+ if self.text_encoder_2 is not None:
485
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
486
+ else:
487
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
488
+
489
  bs_embed, seq_len, _ = prompt_embeds.shape
490
  # duplicate text embeddings for each generation per prompt, using mps friendly method
491
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
 
494
  if do_classifier_free_guidance:
495
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
496
  seq_len = negative_prompt_embeds.shape[1]
497
+
498
+ if self.text_encoder_2 is not None:
499
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
500
+ else:
501
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
502
+
503
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
504
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
505
 
 
511
  bs_embed * num_images_per_prompt, -1
512
  )
513
 
514
+ if self.text_encoder is not None:
515
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
516
+ # Retrieve the original scale by scaling back the LoRA layers
517
+ unscale_lora_layers(self.text_encoder, lora_scale)
518
+
519
+ if self.text_encoder_2 is not None:
520
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
521
+ # Retrieve the original scale by scaling back the LoRA layers
522
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
523
+
524
  return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
525
 
526
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
527
+ def encode_image(self, image, device, num_images_per_prompt):
528
+ dtype = next(self.image_encoder.parameters()).dtype
529
+
530
+ if not isinstance(image, torch.Tensor):
531
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
532
+
533
+ image = image.to(device=device, dtype=dtype)
534
+ image_embeds = self.image_encoder(image).image_embeds
535
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
536
+
537
+ uncond_image_embeds = torch.zeros_like(image_embeds)
538
+ return image_embeds, uncond_image_embeds
539
+
540
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
541
  def prepare_extra_step_kwargs(self, generator, eta):
542
  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
 
568
  negative_prompt_embeds=None,
569
  pooled_prompt_embeds=None,
570
  negative_pooled_prompt_embeds=None,
571
+ callback_on_step_end_tensor_inputs=None,
572
  ):
573
  if height % 8 != 0 or width % 8 != 0:
574
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
575
 
576
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
 
 
577
  raise ValueError(
578
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
579
  f" {type(callback_steps)}."
580
  )
581
 
582
+ if callback_on_step_end_tensor_inputs is not None and not all(
583
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
584
+ ):
585
+ raise ValueError(
586
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
587
+ )
588
+
589
  if prompt is not None and prompt_embeds is not None:
590
  raise ValueError(
591
  f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
 
652
  latents = latents * self.scheduler.init_noise_sigma
653
  return latents
654
 
655
+ def _get_add_time_ids(
656
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
657
+ ):
658
  add_time_ids = list(original_size + crops_coords_top_left + target_size)
659
 
660
  passed_add_embed_dim = (
661
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
662
  )
663
  expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
664
 
 
690
  self.vae.decoder.conv_in.to(dtype)
691
  self.vae.decoder.mid_block.to(dtype)
692
 
693
+ def update_loss(self, latents, i, t, prompt_embeds, timestep_cond, add_text_embeds, add_time_ids):
694
  def forward_pass(latent_model_input):
695
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
696
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
698
  latent_model_input,
699
  t,
700
  encoder_hidden_states=prompt_embeds,
701
+ timestep_cond=timestep_cond,
702
+ cross_attention_kwargs=self.cross_attention_kwargs,
703
  added_cond_kwargs=added_cond_kwargs,
704
  return_dict=False,
705
  )
 
707
 
708
  return self.editor.update_loss(forward_pass, latents, i)
709
 
710
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
711
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
712
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
713
+
714
+ The suffixes after the scaling factors represent the stages where they are being applied.
715
+
716
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
717
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
718
+
719
+ Args:
720
+ s1 (`float`):
721
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
722
+ mitigate "oversmoothing effect" in the enhanced denoising process.
723
+ s2 (`float`):
724
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
725
+ mitigate "oversmoothing effect" in the enhanced denoising process.
726
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
727
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
728
+ """
729
+ if not hasattr(self, "unet"):
730
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
731
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
732
+
733
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
734
+ def disable_freeu(self):
735
+ """Disables the FreeU mechanism if enabled."""
736
+ self.unet.disable_freeu()
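The FreeU hooks above simply delegate to the UNet. A short usage sketch; the scaling values are assumptions, roughly what the FreeU authors suggest for SDXL, and should be tuned per model:

    import torch
    from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline  # import path is an assumption

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)  # attenuate skip features, amplify backbone features
    # ... run generation as usual ...
    pipe.disable_freeu()  # restore the stock UNet behavior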
737
+
738
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
739
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
740
+ """
741
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
742
+
743
+ Args:
744
+ w (`torch.Tensor`):
745
+ generate embedding vectors at these guidance scale values
746
+ embedding_dim (`int`, *optional*, defaults to 512):
747
+ dimension of the embeddings to generate
748
+ dtype:
749
+ data type of the generated embeddings
750
+
751
+ Returns:
752
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
753
+ """
754
+ assert len(w.shape) == 1
755
+ w = w * 1000.0
756
+
757
+ half_dim = embedding_dim // 2
758
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
759
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
760
+ emb = w.to(dtype)[:, None] * emb[None, :]
761
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
762
+ if embedding_dim % 2 == 1: # zero pad
763
+ emb = torch.nn.functional.pad(emb, (0, 1))
764
+ assert emb.shape == (w.shape[0], embedding_dim)
765
+ return emb
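This embedding is only relevant when the UNet was trained with `time_cond_proj_dim` (e.g. distilled, LCM-style weights); in that case classifier-free guidance is skipped (see the `do_classifier_free_guidance` property below) and, following the upstream diffusers convention, the guidance weight is embedded and passed to the UNet as `timestep_cond`. A small shape sketch (batch size and embedding width are illustrative):

    import torch
    # `pipe` is a loaded StableDiffusionXLPipeline, as in the sketch near the top of this diff.
    w = torch.full((4,), 7.5 - 1.0)  # upstream uses guidance_scale - 1, one entry per latent
    emb = pipe.get_guidance_scale_embedding(w, embedding_dim=256, dtype=torch.float32)
    assert emb.shape == (4, 256)     # fed to the UNet through `timestep_cond`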
766
+
767
+ @property
768
+ def guidance_scale(self):
769
+ return self._guidance_scale
770
+
771
+ @property
772
+ def guidance_rescale(self):
773
+ return self._guidance_rescale
774
+
775
+ @property
776
+ def clip_skip(self):
777
+ return self._clip_skip
778
+
779
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
780
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
781
+ # corresponds to doing no classifier free guidance.
782
+ @property
783
+ def do_classifier_free_guidance(self):
784
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
785
+
786
+ @property
787
+ def cross_attention_kwargs(self):
788
+ return self._cross_attention_kwargs
789
+
790
+ @property
791
+ def denoising_end(self):
792
+ return self._denoising_end
793
+
794
+ @property
795
+ def num_timesteps(self):
796
+ return self._num_timesteps
797
+
798
  @torch.no_grad()
799
  @replace_example_docstring(EXAMPLE_DOC_STRING)
800
  def __call__(
 
804
  height: Optional[int] = None,
805
  width: Optional[int] = None,
806
  num_inference_steps: int = 50,
807
+ timesteps: List[int] = None,
808
  denoising_end: Optional[float] = None,
809
  guidance_scale: float = 5.0,
810
  negative_prompt: Optional[Union[str, List[str]]] = None,
 
817
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
818
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
819
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
820
+ ip_adapter_image: Optional[PipelineImageInput] = None,
821
  output_type: Optional[str] = "pil",
822
  return_dict: bool = True,
 
 
823
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
824
  guidance_rescale: float = 0.0,
825
  original_size: Optional[Tuple[int, int]] = None,
826
  crops_coords_top_left: Tuple[int, int] = (0, 0),
827
  target_size: Optional[Tuple[int, int]] = None,
828
+ negative_original_size: Optional[Tuple[int, int]] = None,
829
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
830
+ negative_target_size: Optional[Tuple[int, int]] = None,
831
+ clip_skip: Optional[int] = None,
832
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
833
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
834
+ **kwargs,
835
  ):
836
  r"""
837
  Function invoked when calling the pipeline for generation.
 
844
  The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
845
  used in both text-encoders
846
  height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
847
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
848
+ Anything below 512 pixels won't work well for
849
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
850
+ and checkpoints that are not specifically fine-tuned on low resolutions.
851
  width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
852
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
853
+ Anything below 512 pixels won't work well for
854
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
855
+ and checkpoints that are not specifically fine-tuned on low resolutions.
856
  num_inference_steps (`int`, *optional*, defaults to 50):
857
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
858
  expense of slower inference.
859
+ timesteps (`List[int]`, *optional*):
860
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
861
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
862
+ passed will be used. Must be in descending order.
863
  denoising_end (`float`, *optional*):
864
  When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
865
  completed before it is intentionally prematurely terminated. As a result, the returned sample will
 
906
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
907
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
908
  input argument.
909
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
910
  output_type (`str`, *optional*, defaults to `"pil"`):
911
  The output format of the generate image. Choose between
912
  [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
913
  return_dict (`bool`, *optional*, defaults to `True`):
914
  Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
915
  of a plain tuple.
916
  cross_attention_kwargs (`dict`, *optional*):
917
  A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
918
  `self.processor` in
919
  [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
920
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
921
  Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
922
  Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
923
  [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
924
  Guidance rescale factor should fix overexposure when using zero terminal SNR.
925
  original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
926
  If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
927
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
928
  explained in section 2.2 of
929
  [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
930
  crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
 
934
  [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
935
  target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
936
  For most cases, `target_size` should be set to the desired height and width of the generated image. If
937
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
938
  section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
939
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
940
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
941
+ micro-conditioning as explained in section 2.2 of
942
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
943
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
944
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
945
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
946
+ micro-conditioning as explained in section 2.2 of
947
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
948
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
949
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
950
+ To negatively condition the generation process based on a target image resolution. It should be the same
951
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
952
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
953
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
954
+ callback_on_step_end (`Callable`, *optional*):
955
+ A function that is called at the end of each denoising step during inference. The function is called
956
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
957
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
958
+ `callback_on_step_end_tensor_inputs`.
959
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
960
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
961
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
962
+ `._callback_tensor_inputs` attribute of your pipeline class.
963
 
964
  Examples:
965
 
 
968
  [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
969
  `tuple`. When returning a tuple, the first element is a list with the generated images.
970
  """
971
+
972
+ callback = kwargs.pop("callback", None)
973
+ callback_steps = kwargs.pop("callback_steps", None)
974
+
975
+ if callback is not None:
976
+ deprecate(
977
+ "callback",
978
+ "1.0.0",
979
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
980
+ )
981
+ if callback_steps is not None:
982
+ deprecate(
983
+ "callback_steps",
984
+ "1.0.0",
985
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
986
+ )
987
+
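A minimal caller-side sketch of the `callback_on_step_end` API that replaces the deprecated `callback`/`callback_steps` arguments. It is illustrative only: it uses the stock `StableDiffusionXLPipeline` from diffusers rather than this file's pipeline, and the model id, prompt, and `log_latents` helper are placeholders.

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

def log_latents(pipeline, step, timestep, callback_kwargs):
    # Only tensors listed in callback_on_step_end_tensor_inputs are present here.
    print(step, int(timestep), callback_kwargs["latents"].shape)
    return callback_kwargs  # returned tensors overwrite the pipeline's local copies

image = pipe(
    "an astronaut riding a horse",
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]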
988
  # 0. Default height and width to unet
989
  height = height or self.default_sample_size * self.vae_scale_factor
990
  width = width or self.default_sample_size * self.vae_scale_factor
 
1005
  negative_prompt_embeds,
1006
  pooled_prompt_embeds,
1007
  negative_pooled_prompt_embeds,
1008
+ callback_on_step_end_tensor_inputs,
1009
  )
1010
 
1011
+ self._guidance_scale = guidance_scale
1012
+ self._guidance_rescale = guidance_rescale
1013
+ self._clip_skip = clip_skip
1014
+ self._cross_attention_kwargs = cross_attention_kwargs
1015
+ self._denoising_end = denoising_end
1016
+
1017
  # 2. Define call parameters
1018
  if prompt is not None and isinstance(prompt, str):
1019
  batch_size = 1
 
1024
 
1025
  device = self._execution_device
1026
 
 
 
 
 
 
1027
  # 3. Encode input prompt
1028
+ lora_scale = (
1029
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1030
  )
1031
+
1032
  (
1033
  prompt_embeds,
1034
  negative_prompt_embeds,
 
1039
  prompt_2=prompt_2,
1040
  device=device,
1041
  num_images_per_prompt=num_images_per_prompt,
1042
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1043
  negative_prompt=negative_prompt,
1044
  negative_prompt_2=negative_prompt_2,
1045
  prompt_embeds=prompt_embeds,
1046
  negative_prompt_embeds=negative_prompt_embeds,
1047
  pooled_prompt_embeds=pooled_prompt_embeds,
1048
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1049
+ lora_scale=lora_scale,
1050
+ clip_skip=self.clip_skip,
1051
  )
1052
 
1053
  # 4. Prepare timesteps
1054
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
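`retrieve_timesteps` also accepts an explicit timestep schedule in place of `num_inference_steps`. A rough sketch, assuming the configured scheduler's `set_timesteps` supports a `timesteps` argument (`pipe` and the schedule values below are placeholders):

custom_timesteps = [999, 749, 499, 249]  # arbitrary example schedule
timesteps, num_inference_steps = retrieve_timesteps(
    pipe.scheduler, num_inference_steps=None, device="cuda", timesteps=custom_timesteps
)
# num_inference_steps is then inferred from len(custom_timesteps)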
 
 
1055
 
1056
  # 5. Prepare latent variables
1057
  num_channels_latents = self.unet.config.in_channels
 
1071
 
1072
  # 7. Prepare added time ids & embeddings
1073
  add_text_embeds = pooled_prompt_embeds
1074
+ if self.text_encoder_2 is None:
1075
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1076
+ else:
1077
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1078
+
1079
  add_time_ids = self._get_add_time_ids(
1080
+ original_size,
1081
+ crops_coords_top_left,
1082
+ target_size,
1083
+ dtype=prompt_embeds.dtype,
1084
+ text_encoder_projection_dim=text_encoder_projection_dim,
1085
  )
1086
+ if negative_original_size is not None and negative_target_size is not None:
1087
+ negative_add_time_ids = self._get_add_time_ids(
1088
+ negative_original_size,
1089
+ negative_crops_coords_top_left,
1090
+ negative_target_size,
1091
+ dtype=prompt_embeds.dtype,
1092
+ text_encoder_projection_dim=text_encoder_projection_dim,
1093
+ )
1094
+ else:
1095
+ negative_add_time_ids = add_time_ids
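For reference, the micro-conditioning vector built by `_get_add_time_ids` is just the six size/crop integers packed into a tensor. Below is a simplified sketch of the upstream helper that omits the embedding-dimension sanity check; `make_add_time_ids` is a hypothetical name.

import torch

def make_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=torch.float16):
    # (orig_h, orig_w, crop_top, crop_left, target_h, target_w), later embedded by the
    # UNet's add_embedding module together with the pooled text embedding.
    ids = list(original_size) + list(crops_coords_top_left) + list(target_size)
    return torch.tensor([ids], dtype=dtype)

add_time_ids = make_add_time_ids((1024, 1024), (0, 0), (1024, 1024))          # positive branch
negative_add_time_ids = make_add_time_ids((512, 512), (0, 0), (1024, 1024))   # negative branch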
1096
 
1097
+ if self.do_classifier_free_guidance:
1098
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1099
  add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1100
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1101
 
1102
  prompt_embeds = prompt_embeds.to(device)
1103
  add_text_embeds = add_text_embeds.to(device)
1104
  add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1105
 
1106
+ if ip_adapter_image is not None:
1107
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
1108
+ if self.do_classifier_free_guidance:
1109
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
1110
+ image_embeds = image_embeds.to(device)
1111
+
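A sketch of supplying `ip_adapter_image` at call time, assuming an SDXL-compatible IP-Adapter checkpoint has already been loaded through `load_ip_adapter`; the repository id, weight name, and image path below are illustrative.

from diffusers.utils import load_image

pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
reference = load_image("reference.png")  # placeholder path
image = pipe("a cat in the style of the reference image", ip_adapter_image=reference).images[0]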
1112
  # 8. Denoising loop
1113
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1114
 
1115
+ # 8.1 Apply denoising_end
1116
+ if (
1117
+ self.denoising_end is not None
1118
+ and isinstance(self.denoising_end, float)
1119
+ and self.denoising_end > 0
1120
+ and self.denoising_end < 1
1121
+ ):
1122
  discrete_timestep_cutoff = int(
1123
  round(
1124
  self.scheduler.config.num_train_timesteps
1125
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1126
  )
1127
  )
1128
  num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1129
  timesteps = timesteps[:num_inference_steps]
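A worked example of the cutoff above: with the scheduler's default 1000 training timesteps and `denoising_end=0.8`, only timesteps at or above 200 are kept, i.e. the first 80% of the denoising trajectory (a refiner pipeline would typically handle the remainder).

num_train_timesteps = 1000
denoising_end = 0.8
discrete_timestep_cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))
print(discrete_timestep_cutoff)  # 200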
1130
 
1131
+ # 9. Optionally get Guidance Scale Embedding
1132
+ timestep_cond = None
1133
+ if self.unet.config.time_cond_proj_dim is not None:
1134
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1135
+ timestep_cond = self.get_guidance_scale_embedding(
1136
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1137
+ ).to(device=device, dtype=latents.dtype)
1138
+
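The `timestep_cond` path is only taken for guidance-distilled (e.g. LCM-style) UNets that expose `time_cond_proj_dim`. Below is a minimal sketch of the sinusoidal embedding computed by `get_guidance_scale_embedding`, mirroring the upstream diffusers helper; treat it as an approximation, not the exact method on this class.

import math
import torch

def guidance_scale_embedding(w, embedding_dim=256, dtype=torch.float32):
    # w: 1-D tensor of (guidance_scale - 1) values, one per sample in the batch.
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = math.log(10000.0) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))  # zero-pad to the requested dim
    return emb  # shape: (batch, embedding_dim)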
1139
+ self._num_timesteps = len(timesteps)
1140
  latents = latents.half()
1141
  prompt_embeds = prompt_embeds.half()
1142
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1143
  for i, t in enumerate(timesteps):
1144
+ latents = self.update_loss(latents, i, t, prompt_embeds, timestep_cond, add_text_embeds, add_time_ids)
1145
 
1146
  # expand the latents if we are doing classifier free guidance
1147
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1148
 
1149
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1150
 
1151
  # predict the noise residual
1152
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1153
+ if ip_adapter_image is not None:
1154
+ added_cond_kwargs["image_embeds"] = image_embeds
1155
  noise_pred = self.unet(
1156
  latent_model_input,
1157
  t,
1158
  encoder_hidden_states=prompt_embeds,
1159
+ timestep_cond=timestep_cond,
1160
+ cross_attention_kwargs=self.cross_attention_kwargs,
1161
  added_cond_kwargs=added_cond_kwargs,
1162
  return_dict=False,
1163
  )[0]
1164
 
1165
  # perform guidance
1166
+ if self.do_classifier_free_guidance:
1167
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1168
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1169
 
1170
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1171
  # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1172
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1173
 
1174
  # compute the previous noisy sample x_t -> x_t-1
1175
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1176
 
1177
+ if callback_on_step_end is not None:
1178
+ callback_kwargs = {}
1179
+ for k in callback_on_step_end_tensor_inputs:
1180
+ callback_kwargs[k] = locals()[k]
1181
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1182
+
1183
+ latents = callback_outputs.pop("latents", latents)
1184
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1185
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1186
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1187
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1188
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1189
+ )
1190
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1191
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1192
+
1193
  # call the callback, if provided
1194
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1195
  progress_bar.update()
1196
  if callback is not None and i % callback_steps == 0:
1197
+ step_idx = i // getattr(self.scheduler, "order", 1)
1198
+ callback(step_idx, t, latents)
1199
 
1200
+ if XLA_AVAILABLE:
1201
+ xm.mark_step()
 
 
1202
 
1203
  if not output_type == "latent":
1204
+ # make sure the VAE is in float32 mode, as it overflows in float16
1205
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1206
+
1207
+ if needs_upcasting:
1208
+ self.upcast_vae()
1209
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1210
+
1211
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1212
+
1213
+ # cast back to fp16 if needed
1214
+ if needs_upcasting:
1215
+ self.vae.to(dtype=torch.float16)
1216
  else:
1217
  image = latents
 
1218
 
1219
+ if not output_type == "latent":
1220
+ # apply watermark if available
1221
+ if self.watermark is not None:
1222
+ image = self.watermark.apply_watermark(image)
1223
 
1224
+ image = self.image_processor.postprocess(image, output_type=output_type)
1225
 
1226
+ # Offload all models
1227
+ self.maybe_free_model_hooks()
 
1228
 
1229
  if not return_dict:
1230
  return (image,)
1231
 
1232
  return StableDiffusionXLPipelineOutput(images=image)