import gc
import os
from typing import List
import contextlib
import torch.multiprocessing as mp
from dataclasses import dataclass, field
from collections import defaultdict
import random
import numpy as np
from PIL import Image, ImageOps
import json
import torch
from peft import PeftModel
import torch.nn.functional as F
import accelerate
import diffusers
from diffusers import FluxPipeline
from diffusers.utils.torch_utils import is_compiled_module
import transformers
from tqdm import tqdm
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from dreamfuse.models.dreamfuse_flux.transformer import (
    FluxTransformer2DModel,
    FluxTransformerBlock,
    FluxSingleTransformerBlock,
)
from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler,
)
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from dreamfuse.trains.utils.inference_utils import (
    compute_text_embeddings,
    prepare_latents,
    _unpack_latents,
    _pack_latents,
    _prepare_image_ids,
    encode_images_cond,
    get_mask_affine,
    warp_affine_tensor,
)


def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


@dataclass
class InferenceConfig:
    # Model paths
    flux_model_id: str = 'black-forest-labs/FLUX.1-dev'
    lora_id: str = ''
    model_choice: str = 'dev'

    # Model configs
    lora_rank: int = 16
    max_sequence_length: int = 256
    guidance_scale: float = 3.5
    num_inference_steps: int = 28
    mask_ids: int = 16
    mask_in_chans: int = 128
    mask_out_chans: int = 3072
    inference_scale: int = 1024

    # Training configs
    gradient_checkpointing: bool = False
    mix_attention_double: bool = True
    mix_attention_single: bool = True

    # Image processing
    image_ids_offset: List[int] = field(default_factory=lambda: [0, 0, 0])
    image_tags: List[int] = field(default_factory=lambda: [0, 1, 2])
    context_tags: List[int] = None

    # Runtime configs
    device: str = "cuda:0"  # if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.bfloat16
    seed: int = 1234
    debug: bool = True

    # I/O configs
    valid_output_dir: str = "./inference_output"
    valid_roots: List[str] = field(default_factory=lambda: [
        "./",
    ])
    valid_jsons: List[str] = field(default_factory=lambda: [
        "./examples/data_dreamfuse.json",
    ])
    ref_prompts: str = ""
    truecfg: bool = False
    text_strength: int = 5

    # Multi-GPU sharding
    sub_idx: int = 0
    total_num: int = 1


def adjust_fg_to_bg(image: Image.Image, mask: Image.Image, target_size: tuple) -> tuple[Image.Image, Image.Image]:
    """Fit the foreground image and its mask inside target_size, then pad to exactly that size
    (white for the image, black for the mask)."""
    width, height = image.size
    target_w, target_h = target_size

    scale = min(target_w / width, target_h / height)
    if scale < 1:
        new_w = int(width * scale)
        new_h = int(height * scale)
        image = image.resize((new_w, new_h))
        mask = mask.resize((new_w, new_h))
        width, height = new_w, new_h

    pad_w = target_w - width
    pad_h = target_h - height
    padding = (
        pad_w // 2,        # left
        pad_h // 2,        # top
        (pad_w + 1) // 2,  # right
        (pad_h + 1) // 2,  # bottom
    )
    image = ImageOps.expand(image, border=padding, fill=(255, 255, 255))
    mask = ImageOps.expand(mask, border=padding, fill=0)
    return image, mask
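
# Illustrative check (derived from the logic above, not part of the original script):
# a 500x300 foreground with a 1024x768 target is not downscaled (its fit scale exceeds 1)
# and is padded by (262, 234, 262, 234), yielding a 1024x768 image and mask.
# adjust_fg_to_bg(fg_image, fg_mask, (1024, 768))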


def find_nearest_bucket_size(input_width, input_height, mode="x64", bucket_size=1024):
    """
    Finds the nearest bucket size for the given input size.
    """
    buckets = {
        512: [[256, 768], [320, 768], [320, 704], [384, 640], [448, 576], [512, 512],
              [576, 448], [640, 384], [704, 320], [768, 320], [768, 256]],
        768: [[384, 1152], [480, 1152], [480, 1056], [576, 960], [672, 864], [768, 768],
              [864, 672], [960, 576], [1056, 480], [1152, 480], [1152, 384]],
        1024: [[512, 1536], [640, 1536], [640, 1408], [768, 1280], [896, 1152], [1024, 1024],
               [1152, 896], [1280, 768], [1408, 640], [1536, 640], [1536, 512]],
    }
    buckets = buckets[bucket_size]
    aspect_ratios = [w / h for (w, h) in buckets]

    assert mode in ["x64", "x8"]
    if mode == "x64":
        # Pick the bucket whose aspect ratio is closest to the input's.
        asp = input_width / input_height
        diff = [abs(ar - asp) for ar in aspect_ratios]
        bucket_id = int(np.argmin(diff))
        gen_width, gen_height = buckets[bucket_id]
    elif mode == "x8":
        # Rescale to roughly one megapixel and round both sides down to multiples of 8.
        max_pixels = 1024 * 1024
        ratio = (max_pixels / (input_width * input_height)) ** 0.5
        gen_width, gen_height = round(input_width * ratio), round(input_height * ratio)
        gen_width = gen_width - gen_width % 8
        gen_height = gen_height - gen_height % 8
    else:
        raise NotImplementedError

    return (gen_width, gen_height)
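
# Illustrative example (follows from the 1024 bucket table above): an 800x600 input has
# aspect ratio ~1.33, whose closest bucket ratio is 1152/896, so the call below returns (1152, 896).
# find_nearest_bucket_size(800, 600, mode="x64", bucket_size=1024)  # -> (1152, 896)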


def make_image_grid(images, rows, cols, size=None):
    assert len(images) == rows * cols
    if size is not None:
        images = [img.resize((size[0], size[1])) for img in images]
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(images):
        grid.paste(img.convert("RGB"), box=(i % cols * w, i // cols * h))
    return grid


class DreamFuseInference:
    def __init__(self, config: InferenceConfig):
        self.config = config
        self.device = torch.device(config.device)
        torch.backends.cuda.matmul.allow_tf32 = True
        seed_everything(config.seed)
        self._init_models()

    def _init_models(self):
        # Initialize tokenizers
        self.tokenizer_one = transformers.CLIPTokenizer.from_pretrained(
            self.config.flux_model_id, subfolder="tokenizer"
        )
        self.tokenizer_two = transformers.T5TokenizerFast.from_pretrained(
            self.config.flux_model_id, subfolder="tokenizer_2"
        )

        # Initialize text encoders
        self.text_encoder_one = transformers.CLIPTextModel.from_pretrained(
            self.config.flux_model_id, subfolder="text_encoder"
        ).to(device=self.device, dtype=self.config.dtype)
        self.text_encoder_two = transformers.T5EncoderModel.from_pretrained(
            self.config.flux_model_id, subfolder="text_encoder_2"
        ).to(device=self.device, dtype=self.config.dtype)

        # Initialize VAE
        self.vae = diffusers.AutoencoderKL.from_pretrained(
            self.config.flux_model_id, subfolder="vae"
        ).to(device=self.device, dtype=self.config.dtype)

        # Initialize denoising model
        self.denoise_model = FluxTransformer2DModel.from_pretrained(
            self.config.flux_model_id, subfolder="transformer"
        ).to(device=self.device, dtype=self.config.dtype)

        if self.config.image_tags is not None or self.config.context_tags is not None:
            num_image_tag_embeddings = max(self.config.image_tags) + 1 if self.config.image_tags is not None else 0
            num_context_tag_embeddings = max(self.config.context_tags) + 1 if self.config.context_tags is not None else 0
            self.denoise_model.set_tag_embeddings(
                num_image_tag_embeddings=num_image_tag_embeddings,
                num_context_tag_embeddings=num_context_tag_embeddings,
            )

        # Add LoRA
        self.denoise_model = PeftModel.from_pretrained(
            self.denoise_model,
            self.config.lora_id,
            adapter_weights=[1.0],
            device_map={"": self.device},
        )

        # Initialize scheduler
        self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            self.config.flux_model_id, subfolder="scheduler"
        )

        # Set models to eval mode and freeze them
        for model in [self.text_encoder_one, self.text_encoder_two, self.vae, self.denoise_model]:
            model.eval()
            model.requires_grad_(False)

    def _compute_text_embeddings(self, prompt):
        return compute_text_embeddings(
            self.config,
            prompt,
            [self.text_encoder_one, self.text_encoder_two],
            [self.tokenizer_one, self.tokenizer_two],
            self.device,
        )

    def resize_to_fit_within(self, reference_image, target_image):
        ref_width, ref_height = reference_image.size
        target_width, target_height = target_image.size

        scale_width = ref_width / target_width
        scale_height = ref_height / target_height
        scale = min(scale_width, scale_height)

        new_width = int(target_width * scale)
        new_height = int(target_height * scale)
        resized_image = target_image.resize((new_width, new_height), Image.LANCZOS)
        return resized_image

    def pad_or_crop(self, img, target_size, fill_color=(255, 255, 255)):
        iw, ih = img.size
        tw, th = target_size
        # Crop region: if the source is larger than the target, crop the center; otherwise keep everything.
        left = (iw - tw) // 2 if iw >= tw else 0
        top = (ih - th) // 2 if ih >= th else 0
        cropped = img.crop((left, top, left + min(iw, tw), top + min(ih, th)))
        # Create a new image at the target size and paste the cropped image centered.
        new_img = Image.new(img.mode, target_size, fill_color)
        offset = ((tw - cropped.width) // 2, (th - cropped.height) // 2)
        new_img.paste(cropped, offset)
        return new_img

    def transform_foreground_original(self, original_fg, original_bg, transformation_info, canvas_size=400):
        drag_left = float(transformation_info.get("drag_left", 0))
        drag_top = float(transformation_info.get("drag_top", 0))
        scale_ratio = float(transformation_info.get("scale_ratio", 1))
        data_orig_width = float(transformation_info.get("data_original_width", canvas_size))
        data_orig_height = float(transformation_info.get("data_original_height", canvas_size))
        drag_width = float(transformation_info.get("drag_width", 0))
        drag_height = float(transformation_info.get("drag_height", 0))

        scale_ori_fg = canvas_size / max(original_fg.width, original_fg.height)
        scale_ori_bg = canvas_size / max(original_bg.width, original_bg.height)

        # Default centered position of the foreground in the (unscaled) preview, i.e. before any dragging.
        default_left = (canvas_size - data_orig_width) / 2.0
        default_top = (canvas_size - data_orig_height) / 2.0

        # Offset produced by the drag, measured in preview pixels, then mapped back to original-image pixels.
        offset_preview_x = drag_left - default_left
        offset_preview_y = drag_top - default_top
        offset_ori_x = offset_preview_x / scale_ori_fg
        offset_ori_y = offset_preview_y / scale_ori_fg

        new_width = int(original_fg.width * scale_ratio)
        new_height = int(original_fg.height * scale_ratio)
        scale_fg = original_fg.resize((new_width, new_height))

        output = Image.new("RGBA", (original_fg.width, original_fg.height), (255, 255, 255, 0))
        output.paste(scale_fg, (int(offset_ori_x), int(offset_ori_y)))

        new_width_fgbg = original_fg.width * scale_ori_fg / scale_ori_bg
        new_height_fgbg = original_fg.height * scale_ori_fg / scale_ori_bg
        scale_fgbg = output.resize((int(new_width_fgbg), int(new_height_fgbg)))

        final_output = Image.new("RGBA", (original_bg.width, original_bg.height), (255, 255, 255, 0))
        scale_fgbg = self.pad_or_crop(scale_fgbg, (original_bg.width, original_bg.height), (255, 255, 255, 0))
        final_output.paste(scale_fgbg, (0, 0))

        fit_fg = self.resize_to_fit_within(original_bg, original_fg)
        fit_fg = self.pad_or_crop(fit_fg, original_bg.size, (255, 255, 255, 0))

        return final_output, fit_fg

    def gradio_generate(self, background_img, foreground_img, transformation_info, seed, prompt,
                        enable_gui, cfg=3.5, size_select="1024", text_strength=1, truecfg=False):
        try:
            trans = json.loads(transformation_info)
        except (json.JSONDecodeError, TypeError):
            trans = {}

        size_select = int(size_select)
        if size_select == 1024 and prompt != "":
            text_strength = 5
        if size_select == 768 and prompt != "":
            text_strength = 3

        r, g, b, ori_a = foreground_img.split()
        fg_img_scale, fg_img = self.transform_foreground_original(foreground_img, background_img, trans)

        new_r, new_g, new_b, new_a = fg_img_scale.split()
        foreground_img_scale = Image.merge("RGB", (new_r, new_g, new_b))

        r, g, b, ori_a = fg_img.split()
        foreground_img = Image.merge("RGB", (r, g, b))
        foreground_img_save = foreground_img.copy()

        ori_a = ori_a.convert("L")
        new_a = new_a.convert("L")
        # White out everything outside the foreground alpha mask.
        foreground_img.paste((255, 255, 255), mask=ImageOps.invert(ori_a))

        images = self.model_generate(
            foreground_img.copy(), background_img.copy(),
            ori_a, new_a,
            enable_mask_affine=enable_gui,
            prompt=prompt,
            offset_cond=[0, 1, 0] if not enable_gui else None,
            seed=seed,
            cfg=cfg,
            size_select=size_select,
            text_strength=text_strength,
            truecfg=truecfg,
        )
        images = Image.fromarray(images[0], "RGB")
        images = images.resize(background_img.size)
        # images.thumbnail((640, 640), Image.LANCZOS)
        return images

    def model_generate(self, fg_image, bg_image, ori_fg_mask, new_fg_mask, enable_mask_affine=True,
                       prompt="", offset_cond=None, seed=None, cfg=3.5, size_select=1024,
                       text_strength=1, truecfg=False):
        batch_size = 1

        # Prepare images: fit the foreground into the background's frame, then snap both to a bucket size.
        fg_image, ori_fg_mask = adjust_fg_to_bg(fg_image, ori_fg_mask, bg_image.size)
        bucket_size = find_nearest_bucket_size(bg_image.size[0], bg_image.size[1], bucket_size=size_select)
        fg_image = fg_image.resize(bucket_size)
        bg_image = bg_image.resize(bucket_size)

        mask_affine = None
        if enable_mask_affine:
            ori_fg_mask = ori_fg_mask.resize(bucket_size)
            new_fg_mask = new_fg_mask.resize(bucket_size)
            mask_affine = get_mask_affine(new_fg_mask, ori_fg_mask)

        # Get text embeddings
        prompt_embeds, pooled_prompt_embeds, text_ids = self._compute_text_embeddings(prompt)
        prompt_embeds = prompt_embeds.repeat(1, text_strength, 1)
        text_ids = text_ids.repeat(text_strength, 1)

        # Prepare guidance (only the guidance-distilled "dev" variant takes a guidance embedding)
        if self.config.model_choice == "dev":
            guidance = torch.full([1], cfg, device=self.device, dtype=torch.float32)
            guidance = guidance.expand(batch_size)
        else:
            guidance = None

        # Prepare generator
        if seed is None:
            seed = self.config.seed
        generator = torch.Generator(device=self.device).manual_seed(seed)

        # Prepare condition latents
        condition_image_latents = self._encode_images([fg_image, bg_image])

        if offset_cond is None:
            offset_cond = self.config.image_ids_offset
        offset_cond = offset_cond[1:]

        cond_latent_image_ids = []
        for offset_ in offset_cond:
            cond_latent_image_ids.append(
                self._prepare_image_ids(
                    condition_image_latents.shape[2] // 2,
                    condition_image_latents.shape[3] // 2,
                    offset_w=offset_ * condition_image_latents.shape[3] // 2,
                )
            )

        if mask_affine is not None:
            affine_H, affine_W = condition_image_latents.shape[2] // 2, condition_image_latents.shape[3] // 2
            scale_factor = 1 / 16
            cond_latent_image_ids_fg = cond_latent_image_ids[0].reshape(affine_H, affine_W, 3).clone()
            # Warp the foreground position ids with the mask affine so they align with the new placement.
            cond_latent_image_ids[0] = warp_affine_tensor(
                cond_latent_image_ids_fg, mask_affine, output_size=(affine_H, affine_W),
                scale_factor=scale_factor, device=self.device,
            )
        cond_latent_image_ids = torch.stack(cond_latent_image_ids)

        # Pack condition latents
        cond_image_latents = self._pack_latents(condition_image_latents)
        cond_input = {
            "image_latents": cond_image_latents,
            "image_ids": cond_latent_image_ids,
        }

        # Prepare initial latents
        width, height = bucket_size
        num_channels_latents = self.denoise_model.config.in_channels // 4
        latents, latent_image_ids = self._prepare_latents(
            batch_size, num_channels_latents, height, width, generator
        )

        # Setup timesteps
        sigmas = np.linspace(1.0, 1 / self.config.num_inference_steps, self.config.num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            self.config.num_inference_steps,
            self.device,
            sigmas=sigmas,
            mu=mu,
        )

        # Denoising loop
        for i, t in enumerate(timesteps):
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
            with torch.autocast(enabled=True, device_type="cuda", dtype=self.config.dtype):
                noise_pred = self.denoise_model(
                    hidden_states=latents,
                    cond_input=cond_input,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    data_num_per_group=batch_size,
                    image_tags=self.config.image_tags,
                    context_tags=self.config.context_tags,
                    max_sequence_length=self.config.max_sequence_length,
                    mix_attention_double=self.config.mix_attention_double,
                    mix_attention_single=self.config.mix_attention_single,
                    joint_attention_kwargs=None,
                    return_dict=False,
                )[0]

                if truecfg and i >= 1:
                    # True CFG: a second pass with a neutral guidance embedding provides the
                    # "negative" prediction that the guided prediction is extrapolated from.
                    guidance_neg = torch.full([1], 1, device=self.device, dtype=torch.float32)
                    guidance_neg = guidance_neg.expand(batch_size)
                    noise_pred_neg = self.denoise_model(
                        hidden_states=latents,
                        cond_input=cond_input,
                        timestep=timestep / 1000,
                        guidance=guidance_neg,
                        pooled_projections=pooled_prompt_embeds,
                        encoder_hidden_states=prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        data_num_per_group=batch_size,
                        image_tags=self.config.image_tags,
                        context_tags=self.config.context_tags,
                        max_sequence_length=self.config.max_sequence_length,
                        mix_attention_double=self.config.mix_attention_double,
                        mix_attention_single=self.config.mix_attention_single,
                        joint_attention_kwargs=None,
                        return_dict=False,
                    )[0]
                    # Hard-coded true-CFG scale of 5.
                    noise_pred = noise_pred_neg + 5 * (noise_pred - noise_pred_neg)

            # Compute previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        # Decode latents
        latents = self._unpack_latents(latents, height, width)
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        images = self.vae.decode(latents, return_dict=False)[0]

        # Post-process images
        images = images.add(1).mul(127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
        return images

    def _encode_images(self, images):
        return encode_images_cond(self.vae, [images], self.device)

    def _prepare_image_ids(self, h, w, offset_w=0):
        return _prepare_image_ids(h, w, offset_w=offset_w).to(self.device)

    def _pack_latents(self, latents):
        b, c, h, w = latents.shape
        return _pack_latents(latents, b, c, h, w)

    def _unpack_latents(self, latents, height, width):
        vae_scale = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return _unpack_latents(latents, height, width, vae_scale)

    def _prepare_latents(self, batch_size, num_channels_latents, height, width, generator):
        vae_scale = 2 ** (len(self.vae.config.block_out_channels) - 1)
        latents, latent_image_ids = prepare_latents(
            batch_size=batch_size,
            num_channels_latents=num_channels_latents,
            vae_downsample_factor=vae_scale,
            height=height,
            width=width,
            dtype=self.config.dtype,
            device=self.device,
            generator=generator,
            offset=None,
        )
        return latents, latent_image_ids
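

# --- Illustrative usage (a sketch, not part of the original script) ---
# A minimal driver showing how the pieces above fit together. The LoRA checkpoint and image
# paths are placeholders, and the keyword values simply echo the defaults in InferenceConfig.
if __name__ == "__main__":
    config = InferenceConfig()
    config.lora_id = "./checkpoints/dreamfuse_lora"  # placeholder LoRA path

    pipeline = DreamFuseInference(config)

    foreground = Image.open("./examples/fg.png").convert("RGBA")  # placeholder RGBA foreground
    background = Image.open("./examples/bg.png").convert("RGB")   # placeholder background

    # transformation_info mirrors the JSON the GUI would pass; "{}" keeps the default
    # placement (no drag, no rescale).
    fused = pipeline.gradio_generate(
        background_img=background,
        foreground_img=foreground,
        transformation_info="{}",
        seed=config.seed,
        prompt="",
        enable_gui=True,
        cfg=config.guidance_scale,
        size_select="1024",
    )

    os.makedirs(config.valid_output_dir, exist_ok=True)
    fused.save(os.path.join(config.valid_output_dir, "fused.png"))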