import warnings
warnings.filterwarnings("ignore")

import argparse
import math
import os
import pickle

import gradio as gr
import numpy as np
import requests
import spaces
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from diffusers import AutoencoderKL, DDIMInverseScheduler, DDIMScheduler, DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from huggingface_hub import hf_hub_download
from PIL import Image
from qwen_vl_utils import process_vision_info
from safetensors.torch import load_file
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from typing import Optional
import gc


def save_state_to_file(state):
    filename = "state.pkl"
    with open(filename, "wb") as f:
        pickle.dump(state, f)
    return filename


def load_state_from_file(filename):
    with open(filename, "rb") as f:
        state = pickle.load(f)
    return state


guidance_scale_value = 7.5
num_inference_steps = 10
weights = {}
res_list = []
foreground_mask = None
highest_resolution = -1
signal_value = 2.0
blur_value = None
allowed_res_max = 1.0


def load_pipeline():
    model_id = "runwayml/stable-diffusion-v1-5"
    vae_model_id = "runwayml/stable-diffusion-v1-5"
    vae_folder = "vae"
    guidance_scale_value = 7.5
    resadapter_model_name = "resadapter_v2_sd1.5"
    res_range_min = 128
    res_range_max = 1024
    torch_dtype = torch.float16

    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
    pipe.vae = AutoencoderKL.from_pretrained(vae_model_id, subfolder=vae_folder, torch_dtype=torch_dtype)

    # load lora weights
    pipe.load_lora_weights(
        hf_hub_download(
            repo_id="jiaxiangc/res-adapter",
            subfolder=resadapter_model_name,
            filename="pytorch_lora_weights.safetensors",
        ),
        adapter_name="res_adapter",
    )
    pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])

    # load norm weights
    pipe.unet.load_state_dict(
        load_file(
            hf_hub_download(
                repo_id="jiaxiangc/res-adapter",
                subfolder=resadapter_model_name,
                filename="diffusion_pytorch_model.safetensors",
            )
        ),
        strict=False,
    )

    inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
    scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

    return pipe, inverse_scheduler, scheduler


def load_qwen():
    # default: load the model on the available device(s)
    vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
    )

    # default processor
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    return vlm_model, processor


def weight_population(layer_type, resolution, depth, value):
    # Check if layer_type exists, if not, create it
    if layer_type not in weights:
        weights[layer_type] = {}

    # Check if resolution exists under layer_type, if not, create it
    if resolution not in weights[layer_type]:
        weights[layer_type][resolution] = {}

    global highest_resolution
    if resolution > highest_resolution:
        highest_resolution = resolution

    # Add/Modify the value at the specified depth (which can be a string)
    weights[layer_type][resolution][depth] = value
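# Illustrative sketch (not executed): after a population pass, `weights` is a nested dict
# keyed as weights[layer_type][resolution][layer_count], roughly of the form
#     {"down": {4096: {1: 1.0, 2: 1.0}}, "mid": {...}, "up": {...}}
# with every blend factor starting at 1.0; update_scale() later tapers these values.
# The exact resolution and layer-count keys depend on the UNet, so the numbers above
# are an assumption for illustration only.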
def resize_image_with_aspect(image, res_range_min=128, res_range_max=1024):
    # Get the original width and height of the image
    width, height = image.size

    # Determine the scaling factor to maintain the aspect ratio
    scaling_factor = 1
    if width < res_range_min or height < res_range_min:
        scaling_factor = max(res_range_min / width, res_range_min / height)
    elif width > res_range_max or height > res_range_max:
        scaling_factor = min(res_range_max / width, res_range_max / height)

    # Calculate the new dimensions
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    print(f'{new_width}-{new_height}')

    # Resize the image with the new dimensions while maintaining the aspect ratio
    resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    return resized_image


@spaces.GPU()
def reconstruct(input_img, caption):
    # Invert the input image into the diffusion latent trajectory (with CFG disabled),
    # then re-denoise while snapping each step back onto the stored trajectory. This
    # reconstructs the image and populates the per-layer attention weights.
    pipe, inverse_scheduler, scheduler = load_pipeline()
    pipe.to("cuda")

    global weights
    weights = {}

    prompt = caption

    img = input_img
    img = resize_image_with_aspect(img, res_range_min, res_range_max)

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])

    if torch_dtype == torch.float16:
        loaded_image = transform(img).half().to("cuda").unsqueeze(0)
    else:
        loaded_image = transform(img).to("cuda").unsqueeze(0)

    if loaded_image.shape[1] == 4:
        loaded_image = loaded_image[:, :3, :, :]

    with torch.no_grad():
        encoded_image = pipe.vae.encode(loaded_image * 2 - 1)
        real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()

    # notice we disabled the CFG here by setting guidance scale as 1
    guidance_scale = 1.0
    inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
    timesteps = inverse_scheduler.timesteps

    latents = real_image_latents
    inversed_latents = [latents]

    def store_latent(pipe, step, timestep, callback_kwargs):
        latents = callback_kwargs["latents"]
        with torch.no_grad():
            if step != num_inference_steps - 1:
                inversed_latents.append(latents)
        return callback_kwargs

    with torch.no_grad():
        replace_attention_processor(pipe.unet, True)

        pipe.scheduler = inverse_scheduler
        latents = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            output_type="latent",
            return_dict=False,
            num_inference_steps=num_inference_steps,
            latents=latents,
            callback_on_step_end=store_latent,
            callback_on_step_end_tensor_inputs=["latents"],
        )[0]

    # initial state
    real_image_initial_latents = latents

    guidance_scale = guidance_scale_value
    scheduler.set_timesteps(num_inference_steps, device="cuda")
    timesteps = scheduler.timesteps

    def adjust_latent(pipe, step, timestep, callback_kwargs):
        with torch.no_grad():
            callback_kwargs["latents"] = inversed_latents[len(timesteps) - 1 - step].detach()
        return callback_kwargs

    with torch.no_grad():
        replace_attention_processor(pipe.unet, True)

        intermediate_values = real_image_initial_latents.clone()

        pipe.scheduler = scheduler
        intermediate_values = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            output_type="latent",
            return_dict=False,
            num_inference_steps=num_inference_steps,
            latents=intermediate_values,
            callback_on_step_end=adjust_latent,
            callback_on_step_end_tensor_inputs=["latents"],
        )[0]

        image = pipe.vae.decode(intermediate_values / pipe.vae.config.scaling_factor, return_dict=False)[0]

    image_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
    image_np = (image_np / 2 + 0.5).clamp(0, 1).numpy()
    image_np = (image_np * 255).astype(np.uint8)

    update_scale(12)

    real_cpu = real_image_initial_latents.detach().cpu()
    inversed_cpu = [x.detach().cpu() for x in inversed_latents]

    return image_np, caption, 12, [caption, real_cpu, inversed_cpu, weights]
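# AttnReplaceProcessor below implements the cross-image attention blending. The batch
# holds the source branch at index 0 and the edited branch at index 1 (for both the
# unconditional and conditional halves of CFG); in self-attention layers the query and
# key of the edited branch are blended toward the source as
#     edited = w * source + (1 - w) * edited
# with w = weights[layer_type][resolution][layer_count]. With replace_all=False the
# processor only records a weight of 1.0 for its layer (population mode) and leaves
# the attention computation untouched.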
class AttnReplaceProcessor(AttnProcessor2_0):

    def __init__(self, replace_all, layer_type, layer_count, blur_sigma=None):
        super().__init__()
        self.replace_all = replace_all
        self.layer_type = layer_type
        self.layer_count = layer_count
        self.weight_populated = False
        self.blur_sigma = blur_sigma

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:

        dimension_squared = hidden_states.shape[1]

        is_cross = encoder_hidden_states is not None

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        height = width = math.isqrt(query.shape[2])

        if self.replace_all:
            weight_value = weights[self.layer_type][dimension_squared][self.layer_count]

            ucond_attn_scores, attn_scores = query.chunk(2)
            attn_scores[1].copy_(weight_value * attn_scores[0] + (1.0 - weight_value) * attn_scores[1])
            ucond_attn_scores[1].copy_(weight_value * ucond_attn_scores[0] + (1.0 - weight_value) * ucond_attn_scores[1])

            ucond_attn_scores, attn_scores = key.chunk(2)
            attn_scores[1].copy_(weight_value * attn_scores[0] + (1.0 - weight_value) * attn_scores[1])
            ucond_attn_scores[1].copy_(weight_value * ucond_attn_scores[0] + (1.0 - weight_value) * ucond_attn_scores[1])
        else:
            weight_population(self.layer_type, dimension_squared, self.layer_count, 1.0)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
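# replace_attention_processor() below walks every UNet module and swaps the processor on
# each self-attention block ("attn1" in the module name). layer_type is derived from the
# module path ("down", "mid" or "up") and attention_count is a running index used as the
# per-layer key into `weights`. clear=True installs the processor in population mode
# (replace_all=False) instead of blending mode.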
def replace_attention_processor(unet, clear=False, blur_sigma=None):
    attention_count = 0

    for name, module in unet.named_modules():
        if "attn1" in name and "to" not in name:
            layer_type = name.split(".")[0].split("_")[0]
            attention_count += 1

            if not clear:
                if layer_type == "down":
                    module.processor = AttnReplaceProcessor(True, layer_type, attention_count, blur_sigma=blur_sigma)
                elif layer_type == "mid":
                    module.processor = AttnReplaceProcessor(True, layer_type, attention_count, blur_sigma=blur_sigma)
                elif layer_type == "up":
                    module.processor = AttnReplaceProcessor(True, layer_type, attention_count, blur_sigma=blur_sigma)
            else:
                module.processor = AttnReplaceProcessor(False, layer_type, attention_count, blur_sigma=blur_sigma)


@spaces.GPU()
def apply_prompt(meta_data, new_prompt):
    pipe, _, scheduler = load_pipeline()
    pipe.to("cuda")

    caption, real_latents_cpu, inversed_latents_cpu, saved_weights = meta_data

    # overwrite the global so the attention processors can see it
    global weights
    weights = saved_weights

    # move everything onto CUDA (and match dtype if needed)
    device = next(pipe.unet.parameters()).device
    dtype = next(pipe.unet.parameters()).dtype
    real_latents = real_latents_cpu.to(device=device, dtype=dtype)
    inversed_latents = [x.to(device=device, dtype=dtype) for x in inversed_latents_cpu]

    # now all latents live on CUDA, so the callback won't mix devices
    initial_latents = torch.cat([real_latents] * 2)

    negative_prompt = ""

    inference_steps = len(inversed_latents)

    guidance_scale = guidance_scale_value
    scheduler.set_timesteps(inference_steps, device="cuda")
    timesteps = scheduler.timesteps

    def adjust_latent(pipe, step, timestep, callback_kwargs):
        replace_attention_processor(pipe.unet)

        with torch.no_grad():
            callback_kwargs["latents"][1] = callback_kwargs["latents"][1] + (
                inversed_latents[len(timesteps) - 1 - step].detach() - callback_kwargs["latents"][0]
            )
            callback_kwargs["latents"][0] = inversed_latents[len(timesteps) - 1 - step].detach()

        return callback_kwargs

    with torch.no_grad():
        replace_attention_processor(pipe.unet)

        pipe.scheduler = scheduler
        latents = pipe(
            prompt=[caption, new_prompt],
            negative_prompt=[negative_prompt, negative_prompt],
            guidance_scale=guidance_scale,
            output_type="latent",
            return_dict=False,
            num_inference_steps=inference_steps,  # match the length of the stored inversion trajectory
            latents=initial_latents,
            callback_on_step_end=adjust_latent,
            callback_on_step_end_tensor_inputs=["latents"],
        )[0]

        replace_attention_processor(pipe.unet, True)

        image = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]

    image_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
    image_np = (image_np / 2 + 0.5).clamp(0, 1).numpy()
    image_np = (image_np * 255).astype(np.uint8)

    return image_np
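# A rough picture of the editing pass in apply_prompt() above: latents[0] carries the source
# caption and latents[1] the new prompt. At every step the callback measures how far the
# source branch has drifted from the stored inversion trajectory, applies the same correction
# to the edited branch, and then pins the source branch back onto the trajectory, so the edit
# stays anchored to the original image while the blended attention transfers its structure.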
@spaces.GPU()
def choose_caption(input_image):
    vlm_model, processor = load_qwen()

    image = input_image.convert("RGB")

    # Wrap the image and the captioning instruction into a chat message
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {
                    "type": "text",
                    "text": (
                        "Describe this image in one sentence, starting with the main object, "
                        "then its key features, followed by the background elements. "
                        "Use a clear, concise style, e.g., 'a photo of a plastic bottle on some sand, beach background, sky background'."
                    ),
                },
            ],
        }
    ]

    # 1) Turn messages → vision_inputs + (maybe) video_inputs
    image_inputs, video_inputs = process_vision_info(messages)

    # 2) Build text prompt from messages
    text_prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # 3) Tokenize both text and vision
    inputs = processor(
        text=[text_prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    out_ids = vlm_model.generate(**inputs, max_new_tokens=32)
    generated_ids = [out_ids[i][len(inputs.input_ids[i]):] for i in range(len(out_ids))]
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return caption
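# The bundled examples ship with precomputed state pickles (assets/<name>-turbo.pkl) storing
# (caption, initial latents, inversion trajectory, weights); on_image_change() below loads
# them so the examples can skip the reconstruction pass. The pickle layout described here is
# what this code expects; the files themselves are assumed to exist in the assets/ folder.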
@spaces.GPU(duration=30)
def on_image_change(filepath):
    pipe, inverse_scheduler, scheduler = load_pipeline()

    filename = os.path.splitext(os.path.basename(filepath))[0]
    if filename not in ["example1", "example3", "example4"]:
        return filepath, None, None, None, None, None

    # 1) load the raw state (might contain CUDA tensors!)
    caption, real_latents, inversed_latents, saved_weights = load_state_from_file(
        f"assets/{filename}-turbo.pkl"
    )

    # 2) immediately move all tensors to CPU
    real_cpu = real_latents.detach().cpu()
    inversed_cpu = [x.detach().cpu() for x in inversed_latents]

    # 3) repack a truly CPU-only state
    cpu_meta_data = (caption, real_cpu, inversed_cpu, saved_weights)

    # 4) the existing logic can then run on GPU, but it will
    #    consume only CPU tensors and return only CPU tensors
    _, _, _, global_weights = cpu_meta_data
    global weights
    weights = global_weights

    num_inference_steps = 10

    if filename == "example1":
        scale_value, new_prompt = 8, "a photo of a tree, summer, colourful"
    elif filename == "example3":
        scale_value, new_prompt = 6, "a realistic photo of a female warrior, ..."
    else:
        scale_value, new_prompt = 13, "a photo of plastic bottle on some sand, ..."

    update_scale(scale_value)
    img = apply_prompt(cpu_meta_data, new_prompt)

    # 5) return only CPU objects (img is still a numpy array)
    return filepath, img, cpu_meta_data, num_inference_steps, scale_value, scale_value


def update_value(value, layer_type, resolution, depth):
    global weights
    weights[layer_type][resolution][depth] = value


def update_step(value):
    global num_inference_steps
    num_inference_steps = value


def adjust_ends(values, adjustment):
    # Forward loop to adjust the first valid element from the left
    for i in range(len(values)):
        if (adjustment > 0 and values[i + 1] == 1.0) or (adjustment < 0 and values[i] > 0.0):
            values[i] = values[i] + adjustment
            break

    # Backward loop to adjust the first valid element from the right
    for i in range(len(values) - 1, -1, -1):
        if (adjustment > 0 and values[i - 1] == 1.0) or (adjustment < 0 and values[i] > 0.0):
            values[i] = values[i] + adjustment
            break

    return values


max_scale_value = 16


def update_scale(scale):
    global weights

    # Count the total number of per-layer weights
    value_count = 0
    for outer_key, inner_dict in weights.items():
        for inner_key, depth_dict in inner_dict.items():
            for _ in depth_dict:
                value_count += 1

    list_values = [1.0] * value_count

    for _ in range(scale, max_scale_value):
        adjust_ends(list_values, -0.5)

    # Write the tapered values back, keyed by each layer's depth entry
    value_index = 0
    for outer_key, inner_dict in weights.items():
        for inner_key, depth_dict in inner_dict.items():
            for depth_key in depth_dict:
                weights[outer_key][inner_key][depth_key] = list_values[value_index]
                value_index += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", help="Enable sharing of the Gradio interface")
    args = parser.parse_args()

    num_inference_steps = 10
    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    vae_model_id = "madebyollin/sdxl-vae-fp16-fix"
    vae_folder = ""
    guidance_scale_value = 7.5
    resadapter_model_name = "resadapter_v2_sdxl"
    res_range_min = 256
    res_range_max = 1536
    torch_dtype = torch.float16

    with gr.Blocks(analytics_enabled=False) as demo:
        gr.HTML(
            """
            <div style="text-align: center;">
                <!-- logo image (original markup lost); alt text: "Logo" -->
                <h1>Out of Focus XL v1.0 Turbo</h1>
                <p>
                    Out of AI presents a flexible tool for manipulating your images.
                    This is the first version of our image-modification tool: the image is
                    reconstructed through a diffusion inversion process and then edited by
                    manipulating the prompt.
                </p>
            </div>
""" ) with gr.Row(): with gr.Column(): with gr.Row(): example_input = gr.Image(type="filepath", visible=False) image_input = gr.Image(type="pil", label="Upload Source Image") steps_slider = gr.Slider(minimum=5, maximum=50, step=5, value=num_inference_steps, label="Steps", info="Number of inference steps required to reconstruct and modify the image") prompt_input = gr.Textbox(label="Prompt", info="Give an initial prompt in details, describing the image") reconstruct_button = gr.Button("Reconstruct") with gr.Column(): with gr.Row(): reconstructed_image = gr.Image(type="pil", label="Reconstructed") invisible_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, visible=False) interpolate_slider = gr.Slider(minimum=0, maximum=max_scale_value, step=1, value=max_scale_value, label="Cross-Attention Influence", info="Scales the related influence the source image has on the target image") new_prompt_input = gr.Textbox(label="New Prompt", interactive=False, info="Manipulate the image by changing the prompt or adding words at the end; swap words instead of adding or removing them for better results") with gr.Row(): apply_button = gr.Button("Generate Vision", variant="primary", interactive=False) with gr.Row(): show_case = gr.Examples( examples=[ ["assets/example4.png", "a photo of plastic bottle on a rock, mountain background, sky background", "a photo of plastic bottle on some sand, beach background, sky background", 13], ["assets/example1.png", "a photo of a tree, spring, foggy", "a photo of a tree, summer, colourful", 8], [ "assets/example3.png", "a digital illustration of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds", "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds", 6 , ], ], inputs=[example_input, prompt_input, new_prompt_input, interpolate_slider], label=None, ) meta_data = gr.State() example_input.change(fn=on_image_change, inputs=example_input, outputs=[image_input, reconstructed_image, meta_data, steps_slider, invisible_slider, interpolate_slider]).then(lambda: gr.update(interactive=True), outputs=apply_button).then( lambda: gr.update(interactive=True), outputs=new_prompt_input ) image_input.upload(fn=choose_caption, inputs=image_input, outputs=[prompt_input]) steps_slider.release(update_step, inputs=steps_slider) interpolate_slider.release(update_scale, inputs=interpolate_slider) value_trigger = True def triggered(): global value_trigger value_trigger = not value_trigger return value_trigger reconstruct_button.click(reconstruct, inputs=[image_input, prompt_input], outputs=[reconstructed_image, new_prompt_input, interpolate_slider, meta_data]).then(lambda: gr.update(interactive=True), outputs=reconstruct_button).then(lambda: gr.update(interactive=True), outputs=new_prompt_input).then( lambda: gr.update(interactive=True), outputs=apply_button ) reconstruct_button.click(lambda: gr.update(interactive=False), outputs=reconstruct_button) reconstruct_button.click(lambda: gr.update(interactive=False), outputs=apply_button) apply_button.click(apply_prompt, inputs=[meta_data, new_prompt_input], outputs=reconstructed_image) demo.queue() demo.launch(share=args.share, inbrowser=True)