import warnings
import math
from typing import Optional

import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, StableDiffusionPipeline
from tqdm import tqdm

model_id = "stabilityai/stable-diffusion-2-1-base"


def gaussian_blur_2d(img, kernel_size, sigma):
    height = img.shape[-1]
    kernel_size = min(kernel_size, height - (height % 2 - 1))
    ksize_half = (kernel_size - 1) * 0.5
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
    img = F.pad(img, padding, mode="reflect")
    img = F.conv2d(img, kernel2d, groups=img.shape[-3])
    return img


blur_inf = '∞'
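# blur_sigma == blur_inf ('∞') is treated below as an "infinite" blur: each
# perturbed query map is replaced by its spatial mean instead of being blurred
# with a Gaussian kernel.
#
# Smoothed Energy Guidance (SEG) is wired in by monkey-patching the UNet's
# attention modules: every denoising step runs a batch of three latents
# (unconditional, conditional, conditional with blurred self-attention queries),
# and the blurred branch supplies an extra guidance term on top of CFG.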
def contextual_forward(self, blur_sigma=0.0):
    # Returns a replacement forward pass for a diffusers Attention module that
    # blurs the query maps of the perturbed batch entry in self-attention layers.
    self.blur_sigma = blur_sigma

    def forward_modified(
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        dimension_squared = hidden_states.shape[1]
        is_cross = encoder_hidden_states is not None

        residual = hidden_states
        if self.spatial_norm is not None:
            hidden_states = self.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif self.norm_cross:
            encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads

        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        height = width = math.isqrt(query.shape[2])

        if not is_cross and (self.blur_sigma == blur_inf or self.blur_sigma > 0):
            # Based on empirical findings, only the 8x8 and 16x16
            # self-attention resolutions are perturbed.
            allowed_res = [16, 8]
            if dimension_squared in (allowed_res[0] * allowed_res[0], allowed_res[1] * allowed_res[1]):
                query_uncond, query_org, query_ptb = query.chunk(3)
                query_ptb = query_ptb.permute(0, 1, 3, 2).view(batch_size // 3, self.heads * head_dim, height, width)

                if self.blur_sigma != blur_inf:
                    # Odd kernel size covering roughly +/- 3 sigma.
                    kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
                    query_ptb = gaussian_blur_2d(query_ptb, kernel_size, self.blur_sigma)
                else:
                    # "Infinite" blur: collapse each query map to its spatial mean.
                    query_ptb[:] = query_ptb.mean(dim=(-2, -1), keepdim=True)

                query_ptb = query_ptb.view(batch_size // 3, self.heads, head_dim, height * width).permute(0, 1, 3, 2)
                query = torch.cat((query_uncond, query_org, query_ptb), dim=0)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if self.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / self.rescale_output_factor

        return hidden_states

    return forward_modified


def apply_seg(unet, child=None, blur_sigma=1.0):
    # Recursively replace the forward pass of every Attention module in the UNet.
    if child is None:
        for _, module in unet.named_children():
            apply_seg(unet, module, blur_sigma=blur_sigma)
    else:
        if child.__class__.__name__ == 'Attention':
            child.forward = contextual_forward(child, blur_sigma=blur_sigma)
        elif hasattr(child, 'children'):
            for sub_child in child.children():
                apply_seg(unet, sub_child, blur_sigma=blur_sigma)


@spaces.GPU
def sample(prompt, cfg=7.5, blur_sigma=blur_inf, seed=123, steps=50, signal_scale=1.0):
    pipe = StableDiffusionPipeline.from_pretrained(model_id)
    scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    pipe = pipe.to("cuda")

    guidance_scale = cfg
    num_inference_steps = steps
    scheduler.set_timesteps(num_inference_steps, device="cuda")
    timesteps = scheduler.timesteps

    shape = (1, pipe.unet.config.in_channels, pipe.unet.config.sample_size, pipe.unet.config.sample_size)
    latents = torch.randn(shape, generator=torch.Generator(device="cuda").manual_seed(seed), device="cuda")

    prompt_cond = pipe.encode_prompt(prompt=prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    prompt_uncond = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    # Batch layout: [unconditional, conditional, conditional (perturbed in self-attention)]
    prompt_embeds_combined = torch.cat([prompt_uncond, prompt_cond, prompt_cond])

    with torch.no_grad():
        apply_seg(pipe.unet, blur_sigma=blur_sigma)

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
            latent_model_input = torch.cat([latents] * 3)

            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds_combined,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_cond, noise_pred_cond_perturb = noise_pred.chunk(3)

            # Classifier-Free Guidance
            noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            # Smoothed Energy Guidance
            noise_pred = noise_cfg + signal_scale * (noise_pred_cond - noise_pred_cond_perturb)

            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]

    image_cpu = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
    image_cpu = (image_cpu / 2 + 0.5).clamp(0, 1)
    return (image_cpu, cfg, blur_sigma)


def gradio_generate(prompt: str, cfg: float, blur_sigma: float):
    # Run sample() and get back a torch.Tensor in [0, 1].
    image_tensor, _, _ = sample(prompt, cfg=cfg, blur_sigma=blur_sigma, seed=123, steps=50, signal_scale=1.0)
    # Convert to an HxWxC uint8 array for Gradio.
    image_np = (image_tensor.cpu().numpy() * 255).round().astype(np.uint8)
    return image_np
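# Hypothetical direct call (outside the Gradio UI built below), e.g. as a quick
# smoke test; gradio_generate returns an HxWx3 uint8 numpy array:
#   image_np = gradio_generate("Astronaut in a jungle, cold color palette", cfg=7.5, blur_sigma=blur_inf)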
"a painting of a bouquet of flowers, likely pansies, arranged in a vase. The painting style appears to be impressionistic, characterized by visible brushstrokes and a focus on capturing the overall impression rather than fine detail", "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", ] css=""" #col-container { margin: 0 auto; max-width: 580px; } """ with gr.Blocks(css=css) as demo: gr.HTML( """
Smoothed Energy Guidance in Stable Diffusion 2.1