# outoffocus / app.py
# alexnasa's picture
# Update app.py
# 0d52ff1 verified
import warnings
warnings.filterwarnings("ignore")
from diffusers import DiffusionPipeline, DDIMInverseScheduler, DDIMScheduler, AutoencoderKL
import torch
from typing import Optional
from tqdm import tqdm
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import gc
import gradio as gr
import numpy as np
import os
import pickle
import argparse
from PIL import Image
import requests
import math
import torch
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline
import spaces
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
def save_state_to_file(state):
    """Pickle *state* into ``state.pkl`` in the working directory.

    Returns the name of the file that was written.
    """
    target = "state.pkl"
    with open(target, "wb") as handle:
        handle.write(pickle.dumps(state))
    return target
def load_state_from_file(filename):
    """Read *filename* and return the object unpickled from it.

    NOTE: unpickling can execute arbitrary code, so this must only be
    pointed at trusted files.
    """
    with open(filename, "rb") as handle:
        return pickle.load(handle)
# Module-level defaults and shared mutable state.
# Guidance scale handed to the pipeline by reconstruct() and apply_prompt().
guidance_scale_value = 7.5
# Default step count; update_step() rebinds it and reconstruct() reads it.
num_inference_steps = 10
# Nested mapping weights[layer_type][resolution][depth] -> blend weight,
# filled by weight_population() and read by AttnReplaceProcessor.
weights = {}
# The names below are not referenced in the visible part of this file.
res_list = []
foreground_mask = None
# Largest `resolution` key seen so far by weight_population()
# (the identifier keeps its original spelling because other code uses it).
heighest_resolution = -1
signal_value = 2.0
blur_value = None
allowed_res_max = 1.0
def load_pipeline():
    """Build the diffusion pipeline together with its two DDIM schedulers.

    The base model, the VAE and the res-adapter names are fixed locals of this
    function; same-named module globals (if any) are not consulted.

    Returns:
        A ``(pipe, inverse_scheduler, scheduler)`` tuple: the
        ``DiffusionPipeline`` with the res-adapter LoRA and normalisation
        weights loaded, a ``DDIMInverseScheduler`` and a ``DDIMScheduler``,
        both created from the base model's ``scheduler`` subfolder.
    """
    model_id = "runwayml/stable-diffusion-v1-5"
    vae_model_id = "runwayml/stable-diffusion-v1-5"
    vae_folder = "vae"
    resadapter_repo = "jiaxiangc/res-adapter"
    resadapter_model_name = "resadapter_v2_sd1.5"
    torch_dtype = torch.float16

    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
    pipe.vae = AutoencoderKL.from_pretrained(vae_model_id, subfolder=vae_folder, torch_dtype=torch_dtype)

    # load lora weights
    pipe.load_lora_weights(
        hf_hub_download(repo_id=resadapter_repo, subfolder=resadapter_model_name, filename="pytorch_lora_weights.safetensors"),
        adapter_name="res_adapter",
    )
    pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])

    # load norm weights; strict=False because the file only covers part of the UNet state dict
    pipe.unet.load_state_dict(
        load_file(hf_hub_download(repo_id=resadapter_repo, subfolder=resadapter_model_name, filename="diffusion_pytorch_model.safetensors")),
        strict=False,
    )

    inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
    scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    return pipe, inverse_scheduler, scheduler
def load_qwen():
    """Load the Qwen2.5-VL instruct checkpoint and its matching processor.

    Returns:
        ``(vlm_model, processor)``.
    """
    checkpoint = "Qwen/Qwen2.5-VL-7B-Instruct"
    # dtype and device placement are left to the library ("auto")
    vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype="auto",
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(checkpoint)
    return vlm_model, processor
def weight_population(layer_type, resolution, depth, value):
    """Store *value* at ``weights[layer_type][resolution][depth]``.

    Missing intermediate dictionaries are created on the way, and the global
    ``heighest_resolution`` is raised to *resolution* when that is larger.
    *depth* is only used as a dictionary key.
    """
    global heighest_resolution

    per_resolution = weights.setdefault(layer_type, {})
    per_depth = per_resolution.setdefault(resolution, {})

    # keep track of the largest resolution key recorded so far
    heighest_resolution = max(heighest_resolution, resolution)

    per_depth[depth] = value
def resize_image_with_aspect(image, res_range_min=128, res_range_max=1024):
    """Return *image* rescaled by a single factor so its aspect ratio is kept.

    If either side is below ``res_range_min`` the image is enlarged until both
    sides reach it; otherwise, if either side is above ``res_range_max``, it
    is shrunk until both sides fit. Images already in range keep their size
    (they are still passed through ``resize``). The new size is printed as
    ``"<width>-<height>"``.
    """
    width, height = image.size

    if width < res_range_min or height < res_range_min:
        # enlarge: the larger ratio brings the smaller side up to the minimum
        scaling_factor = max(res_range_min / width, res_range_min / height)
    elif width > res_range_max or height > res_range_max:
        # shrink: the smaller ratio brings the larger side down to the maximum
        scaling_factor = min(res_range_max / width, res_range_max / height)
    else:
        scaling_factor = 1

    new_width, new_height = int(width * scaling_factor), int(height * scaling_factor)
    print(f'{new_width}-{new_height}')

    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
@spaces.GPU()
def reconstruct(input_img, caption):
    """Invert an image with DDIM and rebuild it from the stored latents.

    Steps, in order: load the pipeline, reset the global ``weights``, resize
    and VAE-encode the image, run the pipeline with the inverse scheduler
    while collecting the latent after each step, then run it again with the
    forward scheduler while overwriting the latents at every step with the
    stored ones, decode the result, and call ``update_scale(12)``.

    Reads the module globals ``res_range_min``, ``res_range_max``,
    ``torch_dtype``, ``num_inference_steps`` and ``guidance_scale_value``.
    NOTE(review): in the visible source the first three are only assigned
    inside the ``if __name__ == "__main__":`` section - confirm they exist
    when this module is imported rather than run as a script.

    Args:
        input_img: image object accepted by ``resize_image_with_aspect`` and
            ``torchvision.transforms.ToTensor`` (the UI passes a PIL image).
        caption: prompt used for both pipeline passes.

    Returns:
        ``(image_np, caption, 12, state)`` where ``image_np`` is a uint8
        array in height x width x channel order and ``state`` is the list
        ``[caption, real_cpu, inversed_cpu, weights]`` with tensors on CPU.
    """
    pipe, inverse_scheduler, scheduler= load_pipeline()
    pipe.to("cuda")
    # start from an empty weight table; it is filled again by the attention
    # processors installed below (they call weight_population)
    global weights
    weights = {}
    prompt = caption
    img = input_img
    img = resize_image_with_aspect(img, res_range_min, res_range_max)
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    if torch_dtype == torch.float16:
        loaded_image = transform(img).half().to("cuda").unsqueeze(0)
    else:
        loaded_image = transform(img).to("cuda").unsqueeze(0)
    # keep only the first three channels when a fourth one is present
    if loaded_image.shape[1] == 4:
        loaded_image = loaded_image[:,:3,:,:]
    with torch.no_grad():
        # ToTensor output is rescaled as x*2 - 1 before encoding
        encoded_image = pipe.vae.encode(loaded_image*2 - 1)
        real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()
    # notice we disabled the CFG here by setting guidance scale as 1
    guidance_scale = 1.0
    inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
    timesteps = inverse_scheduler.timesteps
    latents = real_image_latents
    # trajectory of the inversion, starting with the encoded image itself
    inversed_latents = [latents]
    def store_latent(pipe, step, timestep, callback_kwargs):
        # step-end callback: remember every latent except the one from the
        # last step (that one is returned by the pipeline call itself)
        latents = callback_kwargs["latents"]
        with torch.no_grad():
            if step != num_inference_steps - 1:
                inversed_latents.append(latents)
        return callback_kwargs
    with torch.no_grad():
        # clear=True installs processors with replace_all=False
        replace_attention_processor(pipe.unet, True)
        pipe.scheduler = inverse_scheduler
        latents = pipe(prompt=prompt,
                       guidance_scale = guidance_scale,
                       output_type="latent",
                       return_dict=False,
                       num_inference_steps=num_inference_steps,
                       latents=latents,
                       callback_on_step_end=store_latent,
                       callback_on_step_end_tensor_inputs=["latents"],)[0]
    # initial state
    real_image_initial_latents = latents
    guidance_scale = guidance_scale_value
    scheduler.set_timesteps(num_inference_steps, device="cuda")
    timesteps = scheduler.timesteps
    def adjust_latent(pipe, step, timestep, callback_kwargs):
        # step-end callback: replace the pipeline's latent with the stored
        # one, walking the stored list from its end towards its start
        with torch.no_grad():
            callback_kwargs["latents"] = inversed_latents[len(timesteps) - 1 - step].detach()
        return callback_kwargs
    with torch.no_grad():
        replace_attention_processor(pipe.unet, True)
        intermediate_values = real_image_initial_latents.clone()
        pipe.scheduler = scheduler
        intermediate_values = pipe(prompt=prompt,
                                   guidance_scale = guidance_scale,
                                   output_type="latent",
                                   return_dict=False,
                                   num_inference_steps=num_inference_steps,
                                   latents=intermediate_values,
                                   callback_on_step_end=adjust_latent,
                                   callback_on_step_end_tensor_inputs=["latents"],)[0]
        image = pipe.vae.decode(intermediate_values / pipe.vae.config.scaling_factor, return_dict=False)[0]
        # drop the batch axis, move channels last, map to [0, 1] then to uint8
        image_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
        image_np = (image_np / 2 + 0.5).clamp(0, 1).numpy()
        image_np = (image_np * 255).astype(np.uint8)
    # rewrite the stored weights for scale 12; the same 12 is returned below
    update_scale(12)
    # hand back CPU copies only
    real_cpu = real_image_initial_latents.detach().cpu()
    inversed_cpu = [x.detach().cpu() for x in inversed_latents]
    return image_np, caption, 12, [
        caption,
        real_cpu,
        inversed_cpu,
        weights
    ]
class AttnReplaceProcessor(AttnProcessor2_0):
    """Attention processor that records or applies per-layer blend weights.

    With ``replace_all=False`` every call registers a weight of ``1.0`` for
    this layer through ``weight_population``. With ``replace_all=True`` the
    stored weight is looked up in the global ``weights`` table and the query
    and key of the second sample in each batch half are blended in place
    towards those of the first sample. The rest of ``__call__`` is scaled
    dot-product attention followed by the output projection.

    Attributes set in ``__init__``:
        replace_all: selects between the two modes above.
        layer_type: first-level key into ``weights``.
        layer_count: third-level key into ``weights``.
        weight_populated, blur_sigma: stored but not read by this class.
    """

    def __init__(self, replace_all, layer_type, layer_count, blur_sigma=None):
        super().__init__()
        self.replace_all = replace_all
        self.layer_type = layer_type
        self.layer_count = layer_count
        self.weight_populated = False
        self.blur_sigma = blur_sigma

    @staticmethod
    def _blend_second_towards_first(tensor, weight_value):
        """Blend sample 1 of each half of *tensor* towards sample 0, in place.

        ``chunk`` returns views, so ``copy_`` writes through to *tensor*.
        Assumes each half holds at least two samples - TODO confirm the batch
        layout against the pipeline call site.
        """
        first_half, second_half = tensor.chunk(2)
        second_half[1].copy_(weight_value * second_half[0] + (1.0 - weight_value) * second_half[1])
        first_half[1].copy_(weight_value * first_half[0] + (1.0 - weight_value) * first_half[1])

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run attention for *attn* and return the new hidden states.

        Args:
            attn: attention module providing the projections and norms.
            hidden_states: input activations, either 3-D or 4-D.
            encoder_hidden_states: key/value source; ``hidden_states`` itself
                is used when this is ``None``.
            attention_mask: optional mask, reshaped per head before use.
            temb: passed to ``attn.spatial_norm`` when that norm exists.

        Raises:
            KeyError: in replace mode, when no weight has been stored for
                this layer type / size / layer index.
        """
        # second-level key into `weights`: size of axis 1 of the raw input
        dimension_squared = hidden_states.shape[1]

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # flatten the two trailing axes; height/width are needed again at
            # the end to restore the 4-D layout and must not be overwritten
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if self.replace_all:
            weight_value = weights[self.layer_type][dimension_squared][self.layer_count]
            # a weight of 1.0 copies sample 0 onto sample 1, 0.0 leaves it as is
            self._blend_second_towards_first(query, weight_value)
            self._blend_second_towards_first(key, weight_value)
        else:
            weight_population(self.layer_type, dimension_squared, self.layer_count, 1.0)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
def replace_attention_processor(unet, clear=False, blur_sigma=None):
    """Install ``AttnReplaceProcessor`` on the ``attn1`` modules of *unet*.

    A module is selected when its qualified name contains ``"attn1"`` and
    does not contain ``"to"``. The layer type is the part of the first name
    component before its first underscore, and selected modules are numbered
    from 1 in traversal order.

    Args:
        unet: object exposing ``named_modules()``.
        clear: when true, every selected module gets a processor with
            ``replace_all=False``; otherwise only ``down``/``mid``/``up``
            layers get one, with ``replace_all=True``.
        blur_sigma: forwarded unchanged to each processor.
    """
    replace_all = not clear
    attention_count = 0

    for name, module in unet.named_modules():
        if "attn1" not in name or "to" in name:
            continue

        layer_type = name.split(".")[0].split("_")[0]
        attention_count += 1

        if clear or layer_type in ("down", "mid", "up"):
            module.processor = AttnReplaceProcessor(replace_all, layer_type, attention_count, blur_sigma=blur_sigma)
@spaces.GPU()
def apply_prompt(meta_data, new_prompt):
    """Generate an edited image from a stored inversion and a new prompt.

    The pipeline is run on a batch of two prompts, ``[caption, new_prompt]``.
    After every step the first latent is reset to the stored inversion latent
    for that step, and the second latent is shifted by the same correction.

    The number of denoising steps is taken from the stored trajectory, not
    from the global ``num_inference_steps``: the callback indexes the stored
    latents by step, so the two counts have to match.

    Args:
        meta_data: ``(caption, real_latents, inversed_latents, weights)`` as
            produced by ``reconstruct`` / ``on_image_change``.
        new_prompt: prompt for the edited image.

    Returns:
        The decoded second sample as a uint8 array in height x width x
        channel order.
    """
    pipe, _, scheduler = load_pipeline()
    pipe.to("cuda")

    caption, real_latents_cpu, inversed_latents_cpu, saved_weights = meta_data

    # overwrite the global so the attention processors can see it
    global weights
    weights = saved_weights

    # move everything onto the UNet's device and dtype so the callback
    # does not mix devices
    device = next(pipe.unet.parameters()).device
    dtype = next(pipe.unet.parameters()).dtype
    real_latents = real_latents_cpu.to(device=device, dtype=dtype)
    inversed_latents = [x.to(device=device, dtype=dtype) for x in inversed_latents_cpu]

    initial_latents = torch.cat([real_latents] * 2)
    negative_prompt = ""

    # one denoising step per stored latent
    inference_steps = len(inversed_latents)
    guidance_scale = guidance_scale_value
    scheduler.set_timesteps(inference_steps, device="cuda")
    timesteps = scheduler.timesteps

    def adjust_latent(pipe, step, timestep, callback_kwargs):
        # processors are re-installed after every step, as in the original flow
        replace_attention_processor(pipe.unet)
        with torch.no_grad():
            source_latent = inversed_latents[len(timesteps) - 1 - step].detach()
            latents = callback_kwargs["latents"]
            # shift the edited sample first: it needs the un-reset source sample
            latents[1] = latents[1] + (source_latent - latents[0])
            latents[0] = source_latent
        return callback_kwargs

    with torch.no_grad():
        replace_attention_processor(pipe.unet)
        pipe.scheduler = scheduler
        latents = pipe(prompt=[caption, new_prompt],
                       negative_prompt=[negative_prompt, negative_prompt],
                       guidance_scale=guidance_scale,
                       output_type="latent",
                       return_dict=False,
                       num_inference_steps=inference_steps,
                       latents=initial_latents,
                       callback_on_step_end=adjust_latent,
                       callback_on_step_end_tensor_inputs=["latents"],)[0]

        # back to the recording processors before decoding
        replace_attention_processor(pipe.unet, True)
        image = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]

    # drop the batch axis, move channels last, map to [0, 1] then to uint8
    image_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
    image_np = (image_np / 2 + 0.5).clamp(0, 1).numpy()
    image_np = (image_np * 255).astype(np.uint8)
    return image_np
@spaces.GPU()
def choose_caption(input_image):
    """Ask the vision-language model for a one-sentence caption of the image.

    Args:
        input_image: image object with a ``convert`` method (the UI passes a
            PIL image).

    Returns:
        The decoded text of the first generated sequence, without the prompt
        tokens and without special tokens.
    """
    vlm_model, processor = load_qwen()

    instruction = (
        "Describe this image in one sentence, starting with the main object, "
        "then its key features, followed by the background elements. "
        "Use a clear, concise style, e.g., 'a photo of a plastic bottle on some sand, beach background, sky background'."
    )
    image_part = {"type": "image", "image": input_image.convert("RGB")}
    text_part = {"type": "text", "text": instruction}
    messages = [{"role": "user", "content": [image_part, text_part]}]

    # split the chat messages into vision inputs and a text prompt
    image_inputs, video_inputs = process_vision_info(messages)
    text_prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # tokenize text and vision together and move the batch to the GPU
    inputs = processor(
        text=[text_prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    out_ids = vlm_model.generate(**inputs, max_new_tokens=32)

    # keep only the tokens generated after each prompt
    new_token_ids = [
        generated[len(prompt_ids):]
        for prompt_ids, generated in zip(inputs.input_ids, out_ids)
    ]
    return processor.batch_decode(new_token_ids, skip_special_tokens=True)[0]
@spaces.GPU(duration=30)
def on_image_change(filepath):
    """Load the precomputed state of a bundled example and render its edit.

    Only files whose base name is ``example1``, ``example3`` or ``example4``
    are handled; for anything else the path is echoed back with ``None`` for
    every other output.

    No pipeline is loaded here: ``apply_prompt`` builds its own.

    Args:
        filepath: path of the selected example image.

    Returns:
        ``(filepath, img, cpu_meta_data, steps, scale_value, scale_value)``
        where ``img`` is the array returned by ``apply_prompt`` and
        ``cpu_meta_data`` holds CPU tensors only.
    """
    global weights

    # per-example slider value and target prompt
    example_settings = {
        "example1": (8, "a photo of a tree, summer, colourful"),
        "example3": (6, "a realistic photo of a female warrior, ..."),
        "example4": (13, "a photo of plastic bottle on some sand, ..."),
    }

    filename = os.path.splitext(os.path.basename(filepath))[0]
    if filename not in example_settings:
        return filepath, None, None, None, None, None

    # The state files ship with the app; pickle must never be pointed at
    # untrusted input. The stored tensors might live on CUDA.
    caption, real_latents, inversed_latents, saved_weights = load_state_from_file(
        f"assets/{filename}-turbo.pkl"
    )

    # immediately move all tensors to CPU and repack a CPU-only state
    real_cpu = real_latents.detach().cpu()
    inversed_cpu = [x.detach().cpu() for x in inversed_latents]
    cpu_meta_data = (caption, real_cpu, inversed_cpu, saved_weights)

    # update_scale() edits the global table in place; it is the very same
    # object as the one inside cpu_meta_data, so apply_prompt sees the result
    weights = saved_weights

    scale_value, new_prompt = example_settings[filename]
    update_scale(scale_value)
    img = apply_prompt(cpu_meta_data, new_prompt)

    # value shown on the steps slider for the bundled examples
    example_steps = 10
    return filepath, img, cpu_meta_data, example_steps, scale_value, scale_value
def update_value(value, layer_type, resolution, depth):
    """Overwrite one entry of the global ``weights`` table.

    The ``layer_type`` and ``resolution`` levels must already exist;
    a missing one raises ``KeyError``.
    """
    per_depth = weights[layer_type][resolution]
    per_depth[depth] = value
def update_step(value):
    """Rebind the module-level ``num_inference_steps`` to *value*."""
    global num_inference_steps
    num_inference_steps = value
def adjust_ends(values, adjustment):
    """Nudge one entry near each end of *values* by *adjustment*, in place.

    The list is scanned once from the left and once from the right; each scan
    changes at most one entry and then stops.

    * ``adjustment < 0``: the first entry found that is still above ``0.0``
      is changed.
    * ``adjustment > 0``: the first entry found whose inward neighbour (the
      next one for the left scan, the previous one for the right scan) equals
      ``1.0`` is changed. Entries without such a neighbour are skipped, so
      the scan never reads outside the list or wraps around to the other end.

    Both scans may pick the same entry, which is then adjusted twice.

    Args:
        values: mutable sequence of numbers; may be empty.
        adjustment: amount added to the selected entries.

    Returns:
        The same *values* object.
    """
    size = len(values)

    # left-to-right scan
    for i in range(size):
        grows = adjustment > 0 and i + 1 < size and values[i + 1] == 1.0
        shrinks = adjustment < 0 and values[i] > 0.0
        if grows or shrinks:
            values[i] = values[i] + adjustment
            break

    # right-to-left scan
    for i in range(size - 1, -1, -1):
        grows = adjustment > 0 and i > 0 and values[i - 1] == 1.0
        shrinks = adjustment < 0 and values[i] > 0.0
        if grows or shrinks:
            values[i] = values[i] + adjustment
            break

    return values
# Upper bound of the influence slider; update_scale() runs one taper pass
# for every unit the chosen scale lies below this value.
max_scale_value = 16
def update_scale(scale):
    """Rewrite every entry of the global ``weights`` table for *scale*.

    All entries are laid out in dictionary iteration order as a flat profile
    of ``1.0`` values. ``adjust_ends(profile, -0.5)`` is then applied once
    for each integer from *scale* up to (not including) ``max_scale_value``,
    and the resulting profile is written back entry by entry.
    """
    global weights

    # flat list of (layer_type, resolution, depth) keys, in iteration order
    slots = [
        (layer_type, resolution, depth)
        for layer_type, per_resolution in weights.items()
        for resolution, per_depth in per_resolution.items()
        for depth in per_depth
    ]

    profile = [1.0] * len(slots)
    for _ in range(scale, max_scale_value):
        adjust_ends(profile, -0.5)

    for (layer_type, resolution, depth), new_weight in zip(slots, profile):
        weights[layer_type][resolution][depth] = new_weight
# Script entry point: parse CLI flags, set run-time globals, build the Gradio
# UI, wire its events and launch the app.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", help="Enable sharing of the Gradio interface")
    args = parser.parse_args()
    # Module globals for this run. reconstruct() reads res_range_min,
    # res_range_max and torch_dtype from here.
    # NOTE(review): load_pipeline() defines its own locals named model_id,
    # vae_model_id, vae_folder and resadapter_model_name, so the values
    # assigned below are not what it loads - confirm which set is intended.
    num_inference_steps = 10
    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    vae_model_id = "madebyollin/sdxl-vae-fp16-fix"
    vae_folder = ""
    guidance_scale_value = 7.5
    resadapter_model_name = "resadapter_v2_sdxl"
    res_range_min = 256
    res_range_max = 1536
    torch_dtype = torch.float16
    with gr.Blocks(analytics_enabled=False) as demo:
        # page header: logo, title, description and links
        gr.HTML(
            """
<div style="text-align: center;">
<div style="display: flex; justify-content: center;">
<img src="https://github.com/user-attachments/assets/9b92a2cd-4c1f-4de2-87f0-09053fe129ff" alt="Logo">
</div>
<h1>Out of Focus XL v1.0 Turbo</h1>
<p style="font-size:16px;">Out of AI presents a flexible tool to manipulate your images. This is our first version of Image modification tool through prompt manipulation by reconstruction through diffusion inversion process</p>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/OutofAi/OutofFocus">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a> &ensp;
<a href="https://twitter.com/alexandernasa" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=alexnasa"></a> &ensp;
<a href="https://twitter.com/OutofAi" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=OutofAi"></a>
</div>
"""
        )
        with gr.Row():
            # input side: source image, step count, prompt, reconstruct button
            with gr.Column():
                with gr.Row():
                    # hidden image component that only receives example file paths
                    example_input = gr.Image(type="filepath", visible=False)
                    image_input = gr.Image(type="pil", label="Upload Source Image")
                steps_slider = gr.Slider(minimum=5, maximum=50, step=5, value=num_inference_steps, label="Steps", info="Number of inference steps required to reconstruct and modify the image")
                prompt_input = gr.Textbox(label="Prompt", info="Give an initial prompt in details, describing the image")
                reconstruct_button = gr.Button("Reconstruct")
            # output side: result image, influence slider, new prompt, apply button
            with gr.Column():
                with gr.Row():
                    reconstructed_image = gr.Image(type="pil", label="Reconstructed")
                # hidden slider; only used as an output slot of on_image_change
                invisible_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, visible=False)
                interpolate_slider = gr.Slider(minimum=0, maximum=max_scale_value, step=1, value=max_scale_value, label="Cross-Attention Influence", info="Scales the related influence the source image has on the target image")
                new_prompt_input = gr.Textbox(label="New Prompt", interactive=False, info="Manipulate the image by changing the prompt or adding words at the end; swap words instead of adding or removing them for better results")
                with gr.Row():
                    apply_button = gr.Button("Generate Vision", variant="primary", interactive=False)
        with gr.Row():
            # each example row: image path, source prompt, new prompt, scale
            show_case = gr.Examples(
                examples=[
                    ["assets/example4.png", "a photo of plastic bottle on a rock, mountain background, sky background", "a photo of plastic bottle on some sand, beach background, sky background", 13],
                    ["assets/example1.png", "a photo of a tree, spring, foggy", "a photo of a tree, summer, colourful", 8],
                    [
                        "assets/example3.png",
                        "a digital illustration of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds",
                        "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds",
                        6 ,
                    ],
                ],
                inputs=[example_input, prompt_input, new_prompt_input, interpolate_slider],
                label=None,
            )
        # per-session state: the tuple/list produced by reconstruct or on_image_change
        meta_data = gr.State()
        # choosing an example fills the hidden image, which loads the stored
        # state and then enables the apply button and the new-prompt box
        example_input.change(fn=on_image_change, inputs=example_input, outputs=[image_input, reconstructed_image, meta_data, steps_slider, invisible_slider, interpolate_slider]).then(lambda: gr.update(interactive=True), outputs=apply_button).then(
            lambda: gr.update(interactive=True), outputs=new_prompt_input
        )
        # an uploaded image gets an automatic caption as its prompt
        image_input.upload(fn=choose_caption, inputs=image_input, outputs=[prompt_input])
        steps_slider.release(update_step, inputs=steps_slider)
        interpolate_slider.release(update_scale, inputs=interpolate_slider)
        # toggle flag and helper; not connected to any event in this section
        value_trigger = True
        def triggered():
            global value_trigger
            value_trigger = not value_trigger
            return value_trigger
        # reconstruct, then re-enable the three controls one after another
        reconstruct_button.click(reconstruct, inputs=[image_input, prompt_input], outputs=[reconstructed_image, new_prompt_input, interpolate_slider, meta_data]).then(lambda: gr.update(interactive=True), outputs=reconstruct_button).then(lambda: gr.update(interactive=True), outputs=new_prompt_input).then(
            lambda: gr.update(interactive=True), outputs=apply_button
        )
        # separate click listeners that disable both buttons on click
        reconstruct_button.click(lambda: gr.update(interactive=False), outputs=reconstruct_button)
        reconstruct_button.click(lambda: gr.update(interactive=False), outputs=apply_button)
        apply_button.click(apply_prompt, inputs=[meta_data, new_prompt_input], outputs=reconstructed_image)
    demo.queue()
    demo.launch(share=args.share, inbrowser=True)