"""Flux-based try-on pipeline plus latent-packing and image-resizing helpers."""

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import FluxTransformer2DModel
from diffusers.pipelines import FluxInpaintPipeline
from diffusers.pipelines.flux.pipeline_flux_inpaint import calculate_shift, retrieve_latents, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor


class FluxTryonPipeline(FluxInpaintPipeline):
    """Flux inpainting pipeline that treats the input as side-by-side panels.

    The input image is split at ``target_width`` pixels: both parts are
    VAE-encoded separately and concatenated along the width axis, the
    positional ids mark which panel each latent token belongs to, and the
    decoded output is cropped to the first ``target_width`` pixels.
    """

    # Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype, target_width=-1, tryon=False):
        """Build the ``(height * width, 3)`` positional-id tensor for packed latents.

        Channel 0 is a panel marker, channel 1 the row index and channel 2 the
        column index.

        Args:
            batch_size: Unused; kept for signature compatibility.
            height: Number of rows of the packed latent grid.
            width: Number of columns of the packed latent grid.
            device: Device of the returned tensor.
            dtype: Dtype of the returned tensor.
            target_width: Column at which the first panel ends. ``-1`` disables
                panel marking (plain row/column ids, marker left at 0).
            tryon: When true (and ``target_width != -1``), columns from
                ``2 * target_width`` onwards get marker 2, and the column index
                restarts from 0 at ``target_width``.

        Returns:
            Tensor of shape ``(height * width, 3)`` on ``device`` with ``dtype``.
        """
        latent_image_ids = torch.zeros(height, width, 3)
        # The row index is identical in every mode.
        latent_image_ids[..., 1] = torch.arange(height)[:, None]

        if target_width == -1:
            latent_image_ids[..., 2] = torch.arange(width)[None, :]
        else:
            # Everything right of the first panel belongs to panel 1 ...
            latent_image_ids[:, target_width:, 0] = 1
            if tryon:
                # ... except columns past the second boundary, which are panel 2.
                latent_image_ids[:, target_width * 2:, 0] = 2
                # Left panel: columns counted from 0.
                latent_image_ids[:, :target_width, 2] = torch.arange(target_width)[None, :]
                # Remaining columns: counting restarts at 0.
                latent_image_ids[:, target_width:, 2] = torch.arange(width - target_width)[None, :]
            else:
                latent_image_ids[..., 2] = torch.arange(width)[None, :]

        latent_image_ids = latent_image_ids.reshape(height * width, 3)
        return latent_image_ids.to(device=device, dtype=dtype)

    def prepare_latents(
        self,
        image,
        timestep,
        batch_size,
        num_channels_latents,
        height,
        width,
        target_width,
        tryon,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """Encode the image panels and build the initial packed latents.

        Args:
            image: Preprocessed image tensor; it is split along its last axis
                at ``target_width``.
            timestep: Timestep passed to ``scheduler.scale_noise`` when
                ``latents`` is not supplied.
            batch_size: Effective batch size (prompts x images per prompt).
            num_channels_latents: Channel count of the unpacked latents.
            height: Image height in pixels.
            width: Image width in pixels.
            target_width: Width in pixels of the first panel.
            tryon: Forwarded to ``_prepare_latent_image_ids``.
            dtype: Dtype used for the image and the generated noise.
            device: Device used for the image and the generated noise.
            generator: A ``torch.Generator`` or a list with one per batch item.
            latents: Optional pre-made latents; when given they are used both
                as the starting latents and as the returned noise.

        Returns:
            Tuple ``(latents, noise, image_latents, latent_image_ids)``; the
            first three are packed with ``_pack_latents``.

        Raises:
            ValueError: If the generator list length differs from
                ``batch_size`` or the image batch cannot be tiled to it.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
        shape = (batch_size, num_channels_latents, height, width)

        # Panel boundary expressed in packed-latent columns.
        split_col = int(target_width) // (self.vae_scale_factor * 2)
        latent_image_ids = self._prepare_latent_image_ids(
            batch_size, height // 2, width // 2, device, dtype, split_col, tryon
        )

        image = image.to(device=device, dtype=dtype)
        # Encode each panel on its own, then stitch the latents back together along the width.
        img_parts = [image[:, :, :, :target_width], image[:, :, :, target_width:]]
        image_latents = torch.cat(
            [self._encode_vae_image(image=part, generator=generator) for part in img_parts],
            dim=-1,
        )

        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
            # expand init_latents for batch_size
            additional_image_per_prompt = batch_size // image_latents.shape[0]
            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            image_latents = torch.cat([image_latents], dim=0)

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
        else:
            noise = latents.to(device)
            latents = noise

        noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
        image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
        return latents, noise, image_latents, latent_image_ids

    def prepare_mask_latents(
        self,
        mask,
        masked_image,
        batch_size,
        num_channels_latents,
        num_images_per_prompt,
        height,
        width,
        dtype,
        device,
        generator,
    ):
        """Resize, tile and pack the mask and the masked-image latents.

        Args:
            mask: Mask tensor; resized to the latent resolution with nearest
                interpolation.
            masked_image: Masked image tensor. If its channel dimension is 16
                it is used as latents directly, otherwise it is VAE-encoded.
            batch_size: Number of prompts; multiplied by
                ``num_images_per_prompt`` internally.
            num_channels_latents: Channel count of the unpacked latents.
            num_images_per_prompt: Images generated per prompt.
            height: Image height in pixels.
            width: Image width in pixels.
            dtype: Target dtype.
            device: Target device.
            generator: Generator forwarded to ``retrieve_latents``.

        Returns:
            Tuple ``(mask, masked_image_latents)``, both packed.

        Raises:
            ValueError: If the mask or image batch cannot be tiled to the
                effective batch size.
        """
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = torch.nn.functional.interpolate(mask, size=(height, width), mode="nearest")
        mask = mask.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        masked_image = masked_image.to(device=device, dtype=dtype)

        if masked_image.shape[1] == 16:
            masked_image_latents = masked_image
        else:
            masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

        masked_image_latents = (
            masked_image_latents - self.vae.config.shift_factor
        ) * self.vae.config.scaling_factor

        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        if mask.shape[0] < batch_size:
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

        # aligning device to prevent device errors when concating it with the latent model input
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

        masked_image_latents = self._pack_latents(
            masked_image_latents,
            batch_size,
            num_channels_latents,
            height,
            width,
        )
        mask = self._pack_latents(
            mask.repeat(1, num_channels_latents, 1, 1),
            batch_size,
            num_channels_latents,
            height,
            width,
        )

        return mask, masked_image_latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        masked_image_latents: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        target_width: Optional[int] = None,
        tryon: bool = False,
        padding_mask_crop: Optional[int] = None,
        strength: float = 0.6,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 7.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        """Run the try-on inpainting loop and return the generated image(s).

        Only the parameters specific to this subclass are described in detail;
        the remaining ones are forwarded to the inherited helpers
        (``check_inputs``, ``encode_prompt``, ``get_timesteps`` ...) unchanged.

        Args:
            image: Input image containing the panels side by side.
            mask_image: Mask selecting the region to regenerate.
            masked_image_latents: Optional precomputed masked image; when
                ``None`` it is derived as ``image * (mask < 0.5)``.
            height: Output height in pixels; defaults to
                ``default_sample_size * vae_scale_factor``.
            width: Full (all panels) width in pixels; same default as height.
            target_width: Required. Width in pixels of the first panel, which
                is also the width of the decoded output. Must satisfy
                ``0 < target_width < width``.
            tryon: Selects the three-marker positional-id layout, see
                ``_prepare_latent_image_ids``.
            output_type: ``"latent"`` returns the packed, uncropped latents;
                anything else is decoded and post-processed.
            return_dict: When false a one-element tuple is returned.

        Returns:
            ``FluxPipelineOutput`` or ``(images,)`` when ``return_dict`` is false.

        Raises:
            ValueError: If ``target_width`` is missing or out of range, or if
                ``strength`` leaves fewer than one inference step.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # Fail early with a readable message: a missing or out-of-range split would otherwise surface
        # as an opaque TypeError / empty-tensor error deep inside the VAE encoding.
        if target_width is None:
            raise ValueError("`target_width` must be provided: it defines where the input image is split.")
        if not 0 < target_width < width:
            raise ValueError(
                f"`target_width` must be in the open interval (0, {width}) but is {target_width}."
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            image,
            mask_image,
            strength,
            height,
            width,
            output_type=output_type,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            padding_mask_crop=padding_mask_crop,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Preprocess mask and image
        if padding_mask_crop is not None:
            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
            resize_mode = "fill"
        else:
            crops_coords = None
            resize_mode = "default"

        init_image = self.image_processor.preprocess(
            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
        )
        init_image = init_image.to(dtype=torch.float32)

        # 3. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4

        latents, noise, image_latents, latent_image_ids = self.prepare_latents(
            init_image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            target_width,
            tryon,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        mask_condition = self.mask_processor.preprocess(
            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
        )

        if masked_image_latents is None:
            masked_image = init_image * (mask_condition < 0.5)
        else:
            masked_image = masked_image_latents

        mask, masked_image_latents = self.prepare_mask_latents(
            mask_condition,
            masked_image,
            batch_size,
            num_channels_latents,
            num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
        )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # for 64 channel transformer only.
                # NOTE: the unmasked region is reset to the *clean* image latents before every
                # transformer call (they are not re-noised to the current timestep).
                latents = (1 - mask) * image_latents + mask * latents

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)
                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug:
                        # https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    # The callback selects tensors by local-variable name, hence the locals() lookup.
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # advance the progress bar once per completed scheduler step
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            # Keep only the first panel before decoding.
            latents = latents[:, :, :, : target_width // self.vae_scale_factor]
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(device=self.vae.device, dtype=self.vae.dtype), return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)


# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def flux_pack_latents(latents, batch_size, num_channels_latents, height, width):
    """Pack ``(B, C, H, W)`` latents into ``(B, H/2 * W/2, C * 4)`` tokens of 2x2 patches."""
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

    return latents


# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def flux_unpack_latents(latents, height, width, vae_scale_factor):
    """Inverse of ``flux_pack_latents``.

    ``height`` and ``width`` are given in pixels and converted to latent
    resolution here; the result has shape ``(B, C // 4, H_latent, W_latent)``.
    """
    batch_size, num_patches, channels = latents.shape

    # VAE applies 8x compression on images but we must also account for packing which requires
    # latent height and width to be divisible by 2.
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))

    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)

    latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

    return latents


def _make_grid_ids(height, width, marker, device, dtype):
    """Return an unflattened ``(height, width, 3)`` id grid: ``[marker, row, column]``."""
    ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
    ids[..., 0] = marker
    ids[..., 1] = ids[..., 1] + torch.arange(height, device=device)[:, None]  # y coordinate
    ids[..., 2] = ids[..., 2] + torch.arange(width, device=device)[None, :]  # x coordinate
    return ids


# TODO: it is more reasonable to have target pe starting at 0
def prepare_latent_image_ids(height, width_tgt, height_spa, width_spa, height_sub, width_sub, device, dtype):
    """Build positional ids for a target grid plus optional condition grids.

    The target grid always comes first with marker 0. A "spa" condition grid
    (``width_spa > 0``) and a "sub" condition grid (``width_sub > 0``) are
    appended in that order; each present condition gets the next marker value
    (1, then 2). The "sub" grid's column indices are shifted by ``width_tgt``.

    Args:
        height: Rows of the target grid.
        width_tgt: Columns of the target grid.
        height_spa: Rows of the "spa" condition grid.
        width_spa: Columns of the "spa" condition grid; must be 0 or equal to
            ``width_tgt``.
        height_sub: Rows of the "sub" condition grid.
        width_sub: Columns of the "sub" condition grid; 0 disables it.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape ``(num_tokens, 3)`` with all grids concatenated.
    """
    assert width_spa == 0 or width_tgt == width_spa, "width_spa must be 0 or equal to width_tgt"

    id_blocks = [_make_grid_ids(height, width_tgt, 0, device, dtype).reshape(-1, 3)]

    cond_mark = 0
    if width_spa > 0:
        cond_mark += 1
        condspa_image_ids = _make_grid_ids(height_spa, width_spa, cond_mark, device, dtype)
        id_blocks.append(condspa_image_ids.reshape(-1, 3))
    if width_sub > 0:
        cond_mark += 1
        condsub_image_ids = _make_grid_ids(height_sub, width_sub, cond_mark, device, dtype)
        # Shift the columns so they continue to the right of the target grid.
        condsub_image_ids[..., 2] = condsub_image_ids[..., 2] + width_tgt
        id_blocks.append(condsub_image_ids.reshape(-1, 3))

    if len(id_blocks) == 1:
        return id_blocks[0]
    return torch.cat(id_blocks, dim=-2)


def crop_to_multiple_of_16(img):
    """Center-crop a PIL image so both sides are multiples of 16.

    A side shorter than 16 pixels is cropped to 0, yielding an empty image.
    """
    width, height = img.size

    # Largest multiples of 16 that fit inside the image
    new_width = width - (width % 16)
    new_height = height - (height % 16)

    # Center the crop box
    left = (width - new_width) // 2
    top = (height - new_height) // 2
    right = left + new_width
    bottom = top + new_height

    return img.crop((left, top, right, bottom))


def resize_and_pad_to_size(image, target_width, target_height):
    """Resize an image to fit inside the target size and pad it with white.

    The aspect ratio is preserved and the resized image is centered on a white
    RGB canvas of ``(target_width, target_height)``.

    Args:
        image: PIL image or numpy array (converted with ``Image.fromarray``).
        target_width: Canvas width in pixels.
        target_height: Canvas height in pixels.

    Returns:
        Tuple ``(padded_image, left, top, right, bottom)`` where the last four
        entries are the padding in pixels on each side.
    """
    # Convert numpy array to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    orig_width, orig_height = image.size

    target_ratio = target_width / target_height
    orig_ratio = orig_width / orig_height

    # Scale along the limiting dimension. The other dimension is clamped to at least 1 px because
    # truncation can yield 0 for extreme aspect ratios, which Image.resize rejects.
    if orig_ratio > target_ratio:
        # Image is wider than target ratio - scale by width
        new_width = target_width
        new_height = max(1, int(new_width / orig_ratio))
    else:
        # Image is taller than target ratio - scale by height
        new_height = target_height
        new_width = max(1, int(new_height * orig_ratio))

    resized_image = image.resize((new_width, new_height))

    # White canvas of the target size with the resized image centered on it
    padded_image = Image.new('RGB', (target_width, target_height), 'white')
    left_padding = (target_width - new_width) // 2
    top_padding = (target_height - new_height) // 2
    padded_image.paste(resized_image, (left_padding, top_padding))

    right_padding = target_width - new_width - left_padding
    bottom_padding = target_height - new_height - top_padding
    return padded_image, left_padding, top_padding, right_padding, bottom_padding


def resize_by_height(image, height):
    """Resize to the given height keeping the aspect ratio, then crop to multiples of 16."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # image is a PIL image
    image = image.resize((int(image.width * height / image.height), height))
    return crop_to_multiple_of_16(image)