import torch from diffusers import DiffusionPipeline, DDPMScheduler, StableDiffusionPipeline from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from diffusers.image_processor import VaeImageProcessor from huggingface_hub import PyTorchModelHubMixin from transformers import CLIPTextModel, CLIPTextModelWithProjection from diffusers.models.attention_processor import ( AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor, ) class CombinedStableDiffusionXL( DiffusionPipeline, PyTorchModelHubMixin ): """ A Stable Diffusion model wrapper that provides functionality for text-to-image synthesis, noise scheduling, latent space manipulation, and image decoding. """ def __init__( self, original_unet: torch.nn.Module, fine_tuned_unet: torch.nn.Module, scheduler: DDPMScheduler, vae: torch.nn.Module, tokenizer: CLIPTextModel, tokenizer_2: CLIPTextModel, text_encoder: CLIPTextModelWithProjection, text_encoder_2: CLIPTextModelWithProjection, ) -> None: super().__init__() self.register_modules( tokenizer=tokenizer, tokenizer_2=tokenizer_2, text_encoder=text_encoder, text_encoder_2=text_encoder_2, original_unet=original_unet, fine_tuned_unet=fine_tuned_unet, scheduler=scheduler, vae=vae, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor ) self.resolution = 1024 def _get_negative_prompts( self, batch_size: int ) -> tuple[torch.Tensor, torch.Tensor]: inputs_ids_1 = self.tokenizer( [""] * batch_size, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt", ).input_ids input_ids_2 = self.tokenizer_2( [""] * batch_size, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt", ).input_ids return inputs_ids_1, input_ids_2 def _get_encoder_hidden_states( self, tokenized_prompts_1: torch.Tensor, tokenized_prompts_2: torch.Tensor, do_classifier_free_guidance: bool = False ) -> torch.Tensor: text_input_ids_list = [ tokenized_prompts_1, tokenized_prompts_2 ] batch_size = text_input_ids_list[0].size(0) if do_classifier_free_guidance: negative_prompts = [ embed.to(text_input_ids_list[0].device) for embed in self._get_negative_prompts(batch_size) ] text_input_ids_list = [ torch.cat( [ negative_prompt, text_input, ] ) for text_input, negative_prompt in zip( text_input_ids_list, negative_prompts ) ] prompt_embeds_list = [] text_encoders = [self.text_encoder, self.text_encoder_2] for text_encoder, text_input_ids in zip(text_encoders, text_input_ids_list): prompt_embeds = text_encoder( text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False, ) pooled_prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds[-1][-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.cat(prompt_embeds_list, dim=-1) pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) return prompt_embeds, pooled_prompt_embeds def _get_unet_prediction( self, latent_model_input: torch.Tensor, timestep: int, encoder_hidden_states: torch.Tensor, ) -> torch.Tensor: """ Return unet noise prediction Args: latent_model_input (torch.Tensor): Unet latents input timestep (int): noise scheduler timestep encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): Text encoder hidden states Returns: torch.Tensor: noise prediction """ unet = self.original_unet if self._use_original_unet else self.fine_tuned_unet prompt_embeds, pooled_prompt_embeds = encoder_hidden_states target_size = torch.tensor( [ [self.resolution, self.resolution] for _ in range(latent_model_input.size(0)) ], device=latent_model_input.device, dtype=torch.float32, ) add_time_ids = torch.cat( [target_size, torch.zeros_like(target_size), target_size], dim=1 ) unet_added_conditions = { "time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds, } return unet( latent_model_input, timestep, encoder_hidden_states=prompt_embeds, added_cond_kwargs=unet_added_conditions, ).sample def get_noise_prediction( self, latents: torch.Tensor, timestep_index: int, encoder_hidden_states: torch.Tensor, do_classifier_free_guidance: bool = False, detach_main_path: bool = False, ): """ Return noise prediction Args: latents (torch.Tensor): Image latents timestep_index (int): noise scheduler timestep index encoder_hidden_states (torch.Tensor): Text encoder hidden states do_classifier_free_guidance (bool) Whether to do classifier free guidance detach_main_path (bool): Detach gradient Returns: torch.Tensor: noise prediction """ timestep = self.scheduler.timesteps[timestep_index] latent_model_input = self.scheduler.scale_model_input( sample=torch.cat([latents] * 2) if do_classifier_free_guidance else latents, timestep=timestep, ) noise_pred = self._get_unet_prediction( latent_model_input=latent_model_input, timestep=timestep, encoder_hidden_states=encoder_hidden_states, ) if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) if detach_main_path: noise_pred_text = noise_pred_text.detach() noise_pred = noise_pred_uncond + self.guidance_scale * ( noise_pred_text - noise_pred_uncond ) return noise_pred def sample_next_latents( self, latents: torch.Tensor, timestep_index: int, noise_pred: torch.Tensor, return_pred_original: bool = False, ) -> torch.Tensor: """ Return next latents prediction Args: latents (torch.Tensor): Image latents timestep_index (int): noise scheduler timestep index noise_pred (torch.Tensor): noise prediction return_pred_original (bool) Whether to sample original sample Returns: torch.Tensor: latent prediction """ timestep = self.scheduler.timesteps[timestep_index] sample = self.scheduler.step( model_output=noise_pred, timestep=timestep, sample=latents ) return ( sample.pred_original_sample if return_pred_original else sample.prev_sample ) def predict_next_latents( self, latents: torch.Tensor, timestep_index: int, encoder_hidden_states: torch.Tensor, return_pred_original: bool = False, do_classifier_free_guidance: bool = False, detach_main_path: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Predicts the next latent states during the diffusion process. Args: latents (torch.Tensor): Current latent states. timestep_index (int): Index of the current timestep. encoder_hidden_states (torch.Tensor): Encoder hidden states from the text encoder. return_pred_original (bool): Whether to return the predicted original sample. do_classifier_free_guidance (bool) Whether to do classifier free guidance detach_main_path (bool): Detach gradient Returns: tuple: Next latents and predicted noise tensor. """ noise_pred = self.get_noise_prediction( latents=latents, timestep_index=timestep_index, encoder_hidden_states=encoder_hidden_states, do_classifier_free_guidance=do_classifier_free_guidance, detach_main_path=detach_main_path, ) latents = self.sample_next_latents( latents=latents, noise_pred=noise_pred, timestep_index=timestep_index, return_pred_original=return_pred_original, ) return latents, noise_pred def get_latents(self, batch_size: int, device: torch.device) -> torch.Tensor: latent_resolution = int(self.resolution) // self.vae_scale_factor return torch.randn( ( batch_size, self.original_unet.config.in_channels, latent_resolution, latent_resolution, ), device=device, ) def do_k_diffusion_steps( self, start_timestep_index: int, end_timestep_index: int, latents: torch.Tensor, encoder_hidden_states: torch.Tensor, return_pred_original: bool = False, do_classifier_free_guidance: bool = False, detach_main_path: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Performs multiple diffusion steps between specified timesteps. Args: start_timestep_index (int): Starting timestep index. end_timestep_index (int): Ending timestep index. latents (torch.Tensor): Initial latents. encoder_hidden_states (torch.Tensor): Encoder hidden states. return_pred_original (bool): Whether to return the predicted original sample. do_classifier_free_guidance (bool) Whether to do classifier free guidance detach_main_path (bool): Detach gradient Returns: tuple: Resulting latents and encoder hidden states. """ assert start_timestep_index <= end_timestep_index for timestep_index in range(start_timestep_index, end_timestep_index - 1): latents, _ = self.predict_next_latents( latents=latents, timestep_index=timestep_index, encoder_hidden_states=encoder_hidden_states, return_pred_original=False, do_classifier_free_guidance=do_classifier_free_guidance, detach_main_path=detach_main_path, ) res, _ = self.predict_next_latents( latents=latents, timestep_index=end_timestep_index - 1, encoder_hidden_states=encoder_hidden_states, return_pred_original=return_pred_original, do_classifier_free_guidance=do_classifier_free_guidance, ) return res, encoder_hidden_states def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( self.vae.decoder.mid_block.attentions[0].processor, ( AttnProcessor2_0, XFormersAttnProcessor, FusedAttnProcessor2_0, ), ) if use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(dtype) self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) @torch.no_grad() def __call__( self, prompt: str | list[str], num_inference_steps=40, original_unet_steps=35, resolution=1024, guidance_scale=5, output_type: str = "pil", return_dict: bool = True, ): self.guidance_scale = guidance_scale self.resolution = resolution batch_size = 1 if isinstance(prompt, str) else len(prompt) tokenized_prompts_1 = self.tokenizer( prompt, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt", ).input_ids tokenized_prompts_2 = self.tokenizer_2( prompt, max_length=self.tokenizer_2.model_max_length, padding="max_length", truncation=True, return_tensors="pt", ).input_ids original_encoder_hidden_states = self._get_encoder_hidden_states( tokenized_prompts_1=tokenized_prompts_1, tokenized_prompts_2=tokenized_prompts_2, do_classifier_free_guidance=True ) fine_tuned_encoder_hidden_states = self._get_encoder_hidden_states( tokenized_prompts_1=tokenized_prompts_1, tokenized_prompts_2=tokenized_prompts_2, do_classifier_free_guidance=False ) latent_resolution = int(resolution) // self.vae_scale_factor latents = torch.randn( ( batch_size, self.original_unet.config.in_channels, latent_resolution, latent_resolution, ), device=self.device, ) self.scheduler.set_timesteps( num_inference_steps, device=self.device ) self._use_original_unet = True latents, _ = self.do_k_diffusion_steps( start_timestep_index=0, end_timestep_index=original_unet_steps, latents=latents, encoder_hidden_states=original_encoder_hidden_states, return_pred_original=False, do_classifier_free_guidance=True, ) self._use_original_unet = False latents, _ = self.do_k_diffusion_steps( start_timestep_index=original_unet_steps, end_timestep_index=num_inference_steps, latents=latents, encoder_hidden_states=fine_tuned_encoder_hidden_states, return_pred_original=False, do_classifier_free_guidance=False, ) if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents).sample # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents if not output_type == "latent": image = self.image_processor.postprocess( image, output_type=output_type, do_denormalize=[True] * image.shape[0] ) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image)