# Hugging Face Space: Running on Zero (ZeroGPU)
import functools
import math
import threading
import warnings
from typing import Optional

import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler
from diffusers import StableDiffusionPipeline
from tqdm import tqdm
# Hub repository id of the Stable Diffusion 2.1 base checkpoint; the pipeline
# and the DDIM scheduler config are both loaded from it.
model_id: str = "stabilityai/stable-diffusion-2-1-base"
def gaussian_blur_2d(img, kernel_size, sigma):
    """Blur every channel of ``img`` with a 2-D Gaussian kernel.

    Args:
        img: image-like tensor, expected (N, C, H, W); each of the C channels is
            blurred independently (depthwise convolution).
        kernel_size: requested kernel width; it is capped so reflect padding
            stays valid for the smaller spatial side.
        sigma: Gaussian standard deviation in pixels (must be > 0).

    Returns:
        Tensor with the same shape, dtype and device as ``img``.
    """
    height, width = img.shape[-2:]
    # Reflect padding needs pad < each spatial size, so cap the kernel at the
    # largest odd size the *smaller* side allows (side+1 for even sides).
    limit = min(height, width)
    kernel_size = min(kernel_size, limit - (limit % 2 - 1))

    half = (kernel_size - 1) * 0.5
    x = torch.linspace(-half, half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel_1d = (pdf / pdf.sum()).to(device=img.device, dtype=img.dtype)
    # Outer product of the normalized 1-D kernel gives the normalized 2-D kernel.
    kernel_2d = torch.mm(kernel_1d[:, None], kernel_1d[None, :])

    channels = img.shape[-3]
    # One copy of the kernel per channel for a grouped (depthwise) convolution.
    kernel_2d = kernel_2d.expand(channels, 1, kernel_size, kernel_size)
    pad = kernel_size // 2
    img = F.pad(img, [pad, pad, pad, pad], mode="reflect")
    return F.conv2d(img, kernel_2d, groups=channels)
# Sentinel meaning "infinite blur": the perturbed attention query is replaced by
# its spatial mean instead of being Gaussian-blurred. It is compared by equality
# only, so any non-numeric value works; '∞' restores the intended character
# (it had been mis-encoded as 'β').
blur_inf = '∞'
def contextual_forward(self, blur_sigma = 0.0):
    """Build a replacement ``forward`` for an attention module that applies SEG.

    The returned closure computes scaled-dot-product attention (the steps appear
    adapted from diffusers' ``AttnProcessor2_0``). For self-attention layers with
    16*16 or 8*8 tokens, the query of the last third of the batch (the
    "perturbed" branch) is smoothed over its spatial layout. A positive numeric
    ``blur_sigma`` applies a Gaussian blur. ``blur_sigma == blur_inf`` replaces
    every position with the spatial mean.

    Args:
        self: the attention module being patched. ``blur_sigma`` is stored on it
            and read on every call, so re-patching updates the strength.
        blur_sigma: Gaussian sigma, ``blur_inf``, or 0 to disable perturbation.

    Returns:
        A callable ``(hidden_states, encoder_hidden_states=None,
        attention_mask=None, temb=None)`` to assign to ``self.forward``.
    """
    self.blur_sigma = blur_sigma

    def forward_modified(
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        is_cross = encoder_hidden_states is not None
        residual = hidden_states

        if self.spatial_norm is not None:
            hidden_states = self.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten (B, C, H, W) into a (B, H*W, C) token sequence. height and
            # width are kept to restore the layout at the end.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif self.norm_cross:
            encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads

        # (B, tokens, heads*head_dim) -> (B, heads, tokens, head_dim)
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        if not is_cross and (self.blur_sigma == blur_inf or self.blur_sigma > 0):
            num_tokens = query.shape[2]
            # Based on empirical findings, only the 16x16 and 8x8 layers are perturbed.
            allowed_res = [16, 8]
            if num_tokens in (allowed_res[0] * allowed_res[0], allowed_res[1] * allowed_res[1]):
                # Spatial layout of the tokens: exact for 4-D inputs, assumed
                # square for 3-D ones. Dedicated names keep the (height, width)
                # used by the 4-D output reshape below from being overwritten.
                if input_ndim == 4:
                    ptb_height, ptb_width = height, width
                else:
                    ptb_height = ptb_width = math.isqrt(num_tokens)

                # Batch order is [uncond, cond, cond-perturbed]; only the last
                # chunk's query is smoothed.
                query_uncond, query_org, query_ptb = query.chunk(3)
                ptb_batch = query_ptb.shape[0]
                # (B', heads, tokens, head_dim) -> (B', heads*head_dim, H, W)
                query_ptb = query_ptb.permute(0, 1, 3, 2).view(
                    ptb_batch, self.heads * head_dim, ptb_height, ptb_width
                )
                if self.blur_sigma != blur_inf:
                    # Odd kernel width of roughly 6 * sigma (about +/- 3 sigma).
                    kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
                    query_ptb = gaussian_blur_2d(query_ptb, kernel_size, self.blur_sigma)
                else:
                    # Infinite blur: every position takes the spatial mean.
                    query_ptb[:] = query_ptb.mean(dim=(-2, -1), keepdim=True)
                query_ptb = query_ptb.view(
                    ptb_batch, self.heads, head_dim, ptb_height * ptb_width
                ).permute(0, 1, 3, 2)
                query = torch.cat((query_uncond, query_org, query_ptb), dim=0)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False,
        )
        # (B, heads, tokens, head_dim) -> (B, tokens, heads*head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        if input_ndim == 4:
            # Restore the original (B, C, H, W) layout.
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if self.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / self.rescale_output_factor
        return hidden_states

    return forward_modified
def apply_seg(unet, child=None, blur_sigma=1.0):
    """Install the SEG attention ``forward`` on every ``Attention`` module.

    Walks the module tree depth-first. The walk starts from ``unet``'s direct
    children, or from ``child`` when one is given. Each module whose class is
    named ``Attention`` gets its ``forward`` replaced by
    ``contextual_forward(module, blur_sigma=blur_sigma)``. The walk does not
    descend into a module it has just patched.
    """
    if child is None:
        pending = [module for _name, module in unet.named_children()]
    else:
        pending = [child]
    # Reversed so that popping from the end visits modules in their natural order.
    pending.reverse()
    while pending:
        module = pending.pop()
        if module.__class__.__name__ == 'Attention':
            module.forward = contextual_forward(module, blur_sigma=blur_sigma)
        elif hasattr(module, 'children'):
            pending.extend(reversed(list(module.children())))
# Serializes use of the shared cached pipeline. apply_seg stores blur_sigma on
# the UNet's attention modules, so two concurrent calls must not interleave.
_PIPELINE_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_pipeline():
    """Load the Stable Diffusion pipeline once and move it to CUDA.

    ``sample`` runs twice per Generate click. Caching avoids re-reading the
    weights and re-uploading them to the GPU every time. Re-applying
    ``apply_seg`` to the cached UNet is fine because it overwrites each
    attention module's ``forward`` and ``blur_sigma`` on every call.
    """
    return StableDiffusionPipeline.from_pretrained(model_id).to("cuda")


def sample(prompt, cfg = 7.5, blur_sigma = blur_inf, seed=123, steps=50, signal_scale = 1.0):
    """Generate one image with classifier-free guidance plus Smoothed Energy Guidance.

    Each DDIM step runs the UNet on three copies of the latents, conditioned on
    [empty prompt, prompt, prompt]. The third copy goes through the
    blurred-query attention installed by ``apply_seg``. The guided prediction is
    ``uncond + cfg * (cond - uncond) + signal_scale * (cond - cond_perturbed)``.

    Args:
        prompt: text prompt.
        cfg: classifier-free guidance scale.
        blur_sigma: Gaussian sigma for the query blur, ``blur_inf`` for
            infinite blur, or 0 to disable the perturbation.
        seed: seed for the initial latent noise.
        steps: number of DDIM inference steps.
        signal_scale: weight of the SEG term.

    Returns:
        ``(image, cfg, blur_sigma)``. ``image`` is a CPU float tensor laid out
        (H, W, C) with values clamped to [0, 1].
    """
    device = "cuda"
    scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    scheduler.set_timesteps(steps, device=device)

    # no_grad covers prompt encoding and VAE decoding too, not just the UNet loop.
    with _PIPELINE_LOCK, torch.no_grad():
        pipe = _load_pipeline()
        unet = pipe.unet
        apply_seg(unet, blur_sigma=blur_sigma)

        latent_size = unet.config.sample_size
        shape = (1, unet.config.in_channels, latent_size, latent_size)
        generator = torch.Generator(device=device).manual_seed(seed)
        latents = torch.randn(shape, generator=generator, device=device)

        prompt_cond = pipe.encode_prompt(prompt=prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
        prompt_uncond = pipe.encode_prompt(prompt="", device=device, num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
        # Order must match the chunk(3) split inside the patched attention.
        prompt_embeds_combined = torch.cat([prompt_uncond, prompt_cond, prompt_cond])

        for t in tqdm(scheduler.timesteps, desc="Inference steps"):
            noise_pred = unet(
                torch.cat([latents] * 3),
                t,
                encoder_hidden_states=prompt_embeds_combined,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]
            noise_pred_uncond, noise_pred_cond, noise_pred_cond_perturb = noise_pred.chunk(3)
            # Classifier-Free Guidance
            noise_cfg = noise_pred_uncond + cfg * (noise_pred_cond - noise_pred_uncond)
            # Smoothed Energy Guidance
            noise_pred = noise_cfg + signal_scale * (noise_pred_cond - noise_pred_cond_perturb)
            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]

    image_cpu = image.squeeze(0).float().permute(1, 2, 0).cpu()
    image_cpu = (image_cpu / 2 + 0.5).clamp(0, 1)
    return (image_cpu, cfg, blur_sigma)
def gradio_generate(prompt: str, cfg: float, blur_sigma: float):
    """Gradio callback that renders ``prompt`` at the given CFG scale and blur sigma.

    Uses a fixed seed (123), 50 steps and a SEG signal scale of 1.0. It turns
    the [0, 1] float image returned by ``sample`` into an H x W x C uint8
    array for ``gr.Image(type="numpy")``.
    """
    # NOTE(review): the file imports ``spaces`` (ZeroGPU) but never uses it. On
    # ZeroGPU, GPU work normally has to run inside a ``@spaces.GPU``-decorated
    # function. Confirm whether this callback needs the decorator.
    image_tensor, _cfg, _sigma = sample(
        prompt,
        cfg=cfg,
        blur_sigma=blur_sigma,
        seed=123,
        steps=50,
        signal_scale=1.0,
    )
    pixels = image_tensor.cpu().numpy() * 255
    return pixels.round().astype(np.uint8)
# Example prompts offered below the prompt box.
examples = [
    "A realistic photo of a woman, cyberpunk outfit, neon lighting, wall background",
    "a painting of a bouquet of flowers, likely pansies, arranged in a vase. The painting style appears to be impressionistic, characterized by visible brushstrokes and a focus on capturing the overall impression rather than fine detail",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
]

# Center the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 580px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
<div style="text-align: center;">
<h1>StableEnergy</h1>
<p style="font-size:16px;">Smoothed Energy Guidance in Stable Diffusion 2.1 </p>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/OutofAi/StableEnergy">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>   
<a href="https://x.com/alexandernasa" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=alexnasa"></a>   
<a href="https://x.com/OutofAi" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=OutofAi"></a>
</div>
"""
    )
    with gr.Column(elem_id="col-container"):
        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="Type your prompt here...",
            lines=2,
        )
        # First image: sigma = 0 (perturbation disabled). Second: the slider's sigma.
        with gr.Row():
            image_out_1 = gr.Image(type="numpy", label="Blur σ = 0")
            image_out_2 = gr.Image(type="numpy", label="Blur σ = 10")
        with gr.Row():
            cfg_slider = gr.Slider(
                minimum=1.0, maximum=7.5, value=3.0, step=0.1,
                label="CFG Scale",
            )
            blur_slider = gr.Slider(
                minimum=5.0, maximum=100.0, value=10.0, step=0.1,
                label="Blur σ",
            )
        generate_btn = gr.Button("Generate")
        gr.Examples(
            examples=examples,
            inputs=[prompt_in],
        )

    # wire up: render the sigma = 0 image, then the blurred one, then relabel
    # the second image with the current slider value.
    generate_btn.click(
        fn=gradio_generate,
        inputs=[prompt_in, cfg_slider, gr.State(0)],
        outputs=image_out_1,
    ).then(
        fn=gradio_generate,
        inputs=[prompt_in, cfg_slider, blur_slider],
        outputs=image_out_2,
    ).then(
        lambda b: gr.update(label=f"Blur σ = {b}"),
        inputs=[blur_slider],
        outputs=[image_out_2],
    )

demo.launch()