# StableEnergy / app.py
import math
from typing import Optional

import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, StableDiffusionPipeline
from tqdm import tqdm
model_id = "stabilityai/stable-diffusion-2-1-base"
def gaussian_blur_2d(img, kernel_size, sigma):
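    """Blur ``img`` (N, C, H, W) with a depthwise Gaussian built from a separable 1D kernel."""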
    # Clamp the kernel to the feature-map size while keeping it odd.
    height = img.shape[-1]
    kernel_size = min(kernel_size, height - (height % 2 - 1))
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
    # Outer product of the 1D kernel with itself gives the 2D Gaussian, applied
    # as a depthwise convolution (one kernel per channel) with reflect padding.
    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
    img = F.pad(img, padding, mode="reflect")
    img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
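# Illustrative sanity check (shapes arbitrary; kept as a comment so nothing runs on import):
#     x = torch.randn(2, 4, 16, 16)
#     assert gaussian_blur_2d(x, kernel_size=5, sigma=1.0).shape == x.shape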
# Sentinel for "infinite blur": the perturbed query collapses to its spatial mean.
blur_inf = '∞'
def contextual_forward(self, blur_sigma=0.0):
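    """Return a replacement ``forward`` for a diffusers ``Attention`` module.

    For self-attention at the 8x8 and 16x16 resolutions, the query of the third
    (perturbed-conditional) batch chunk is Gaussian-blurred, or replaced by its
    spatial mean when ``blur_sigma`` is the ``blur_inf`` sentinel. This is the
    query smoothing used by Smoothed Energy Guidance (SEG).
    """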
self.blur_sigma = blur_sigma
def forward_modified(
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
        # Token count of the flattened spatial input, i.e. height * width.
        dimension_squared = hidden_states.shape[1]
        is_cross = encoder_hidden_states is not None
residual = hidden_states
if self.spatial_norm is not None:
hidden_states = self.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif self.norm_cross:
encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        # Side length of the (assumed square) feature map. Kept distinct from the
        # 4D-input height/width above so the final reshape is not clobbered.
        spatial = math.isqrt(query.shape[2])

        if not is_cross and (self.blur_sigma == blur_inf or self.blur_sigma > 0):
            # Based on empirical findings, only the 8x8 and 16x16 self-attention
            # resolutions are perturbed.
            allowed_res = [16, 8]
            if dimension_squared in (allowed_res[0] * allowed_res[0], allowed_res[1] * allowed_res[1]):
                # Batch layout is [uncond, cond, perturbed-cond]; only the last
                # chunk's queries are smoothed.
                query_uncond, query_org, query_ptb = query.chunk(3)
                query_ptb = query_ptb.permute(0, 1, 3, 2).view(batch_size // 3, self.heads * head_dim, spatial, spatial)

                if self.blur_sigma != blur_inf:
                    # Smallest odd kernel covering roughly +/- 3 sigma.
                    kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
                    query_ptb = gaussian_blur_2d(query_ptb, kernel_size, self.blur_sigma)
                else:
                    # Infinite blur: collapse each query map to its spatial mean.
                    query_ptb[:] = query_ptb.mean(dim=(-2, -1), keepdim=True)

                query_ptb = query_ptb.view(batch_size // 3, self.heads, head_dim, spatial * spatial).permute(0, 1, 3, 2)
                query = torch.cat((query_uncond, query_org, query_ptb), dim=0)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if self.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / self.rescale_output_factor
return hidden_states
return forward_modified
def apply_seg(unet, child=None, blur_sigma=1.0):
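    """Recursively patch every ``Attention`` module in ``unet`` with the SEG forward."""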
    if child is None:
        for _, module in unet.named_children():
            apply_seg(unet, module, blur_sigma=blur_sigma)
    elif child.__class__.__name__ == 'Attention':
        child.forward = contextual_forward(child, blur_sigma=blur_sigma)
    elif hasattr(child, 'children'):
        for sub_child in child.children():
            apply_seg(unet, sub_child, blur_sigma=blur_sigma)
@spaces.GPU
def sample(prompt, cfg=7.5, blur_sigma=blur_inf, seed=123, steps=50, signal_scale=1.0):
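    """Run DDIM sampling with classifier-free guidance plus Smoothed Energy Guidance.

    Every UNet call is a 3-way batch [uncond, cond, perturbed-cond]; the final
    noise estimate combines them as
        eps = eps_uncond + cfg * (eps_cond - eps_uncond)
                         + signal_scale * (eps_cond - eps_perturbed)
    """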
pipe = StableDiffusionPipeline.from_pretrained(model_id)
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = pipe.to("cuda")
guidance_scale = cfg
num_inference_steps = steps
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps
    shape = (1, pipe.unet.config.in_channels, pipe.unet.config.sample_size, pipe.unet.config.sample_size)
    latents = torch.randn(shape, generator=torch.Generator(device="cuda").manual_seed(seed), device="cuda")
prompt_cond = pipe.encode_prompt(prompt=prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
prompt_uncond = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
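    # Embedding order matters: [uncond, cond, cond]. The SEG-patched attention
    # perturbs only the third chunk, producing the "perturbed-cond" prediction.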
prompt_embeds_combined = torch.cat([prompt_uncond, prompt_cond, prompt_cond])
with torch.no_grad():
apply_seg(pipe.unet, blur_sigma=blur_sigma)
for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
latent_model_input = torch.cat([latents] * 3)
noise_pred = pipe.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds_combined,
cross_attention_kwargs=None,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_cond, noise_pred_cond_perturb = noise_pred.chunk(3)
            # Classifier-Free Guidance: eps_uncond + cfg * (eps_cond - eps_uncond)
            noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            # Smoothed Energy Guidance: push away from the blurred-query prediction
            noise_pred = noise_cfg + signal_scale * (noise_pred_cond - noise_pred_cond_perturb)
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    # Decode latents, then map the VAE output from [-1, 1] to [0, 1].
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    image_cpu = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
    image_cpu = (image_cpu / 2 + 0.5).clamp(0, 1)
return (image_cpu, cfg, blur_sigma)
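# Minimal sketch of calling sample() directly, outside the Gradio UI (assumes a
# CUDA device; kept as a comment so nothing runs on import):
#
#     image, cfg_used, sigma_used = sample(
#         "Astronaut in a jungle, cold color palette", cfg=3.0, blur_sigma=10.0
#     )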
def gradio_generate(prompt: str, cfg: float, blur_sigma: float):
    # Run sample() to get an HxWxC torch.Tensor in [0, 1], then convert it to a
    # uint8 numpy array for Gradio.
    image_tensor, _, _ = sample(prompt, cfg=cfg, blur_sigma=blur_sigma, seed=123, steps=50, signal_scale=1.0)
    image_np = (image_tensor.cpu().numpy() * 255).round().astype(np.uint8)
    return image_np
examples = [
"A realistic photo of a woman, cyberpunk outfit, neon lighting, wall background",
"a painting of a bouquet of flowers, likely pansies, arranged in a vase. The painting style appears to be impressionistic, characterized by visible brushstrokes and a focus on capturing the overall impression rather than fine detail",
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
]
css="""
#col-container {
margin: 0 auto;
max-width: 580px;
}
"""
with gr.Blocks(css=css) as demo:
gr.HTML(
"""
<div style="text-align: center;">
<h1>StableEnergy</h1>
<p style="font-size:16px;">Smoothed Energy Guidance in Stable Diffusion 2.1 </p>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/OutofAi/StableEnergy">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a> &ensp;
<a href="https://x.com/alexandernasa" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=alexnasa"></a> &ensp;
<a href="https://x.com/OutofAi" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=OutofAi"></a>
</div>
"""
)
with gr.Column(elem_id="col-container"):
prompt_in = gr.Textbox(
label="Prompt",
placeholder="Type your prompt here...",
lines=2
)
with gr.Row():
            image_out_1 = gr.Image(type="numpy", label="Blur σ = 0")
            image_out_2 = gr.Image(type="numpy", label="Blur σ = 10")
with gr.Row():
cfg_slider = gr.Slider(
minimum=1.0, maximum=7.5, value=3.0, step=0.1,
label="CFG Scale"
)
blur_slider = gr.Slider(
minimum=5.0, maximum=100.0, value=10.0, step=0.1,
label="Blur Οƒ"
)
generate_btn = gr.Button("Generate")
        gr.Examples(
            examples=examples,
            inputs=[prompt_in]
        )
    # Wire the button: baseline (sigma = 0) image first, then the SEG image, then relabel.
generate_btn.click(
fn=gradio_generate,
inputs=[prompt_in, cfg_slider, gr.State(0)],
outputs=image_out_1
).then(fn=gradio_generate,
inputs=[prompt_in, cfg_slider, blur_slider],
outputs=image_out_2
).then(
        lambda b: gr.update(label=f"Blur σ = {b}"),
inputs=[blur_slider],
outputs=[image_out_2]
)
demo.launch()