# app.py ────────────────────────────────────────────────────────────────────
import io, warnings
from typing import Dict, Optional

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import (
    StableDiffusionXLPipeline,
    DDIMScheduler,
    EulerDiscreteScheduler,
    DPMSolverMultistepScheduler,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# ---------------------------------------------------------------------------
# local modules
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS
from embedding_manager import get_bank  # ← NEW

warnings.filterwarnings("ignore")
# ───────────────────────────────────────────────────────────────────────────
# GLOBALS
# ───────────────────────────────────────────────────────────────────────────
dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bank = get_bank()  # shared singleton

_t5_tok: Optional[T5Tokenizer] = None
_t5_mod: Optional[T5EncoderModel] = None
_pipe: Optional[StableDiffusionXLPipeline] = None

SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}
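# NOTE: the scheduler is swapped per request in infer() via .from_config(),
# which rebuilds the chosen scheduler from the pipeline's existing timestep
# configuration rather than resetting it to library defaults.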
# easy access to adapter repo metadata
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
conf_l = T5_SHUNT_REPOS["clip_l"]["config"]
conf_g = T5_SHUNT_REPOS["clip_g"]["config"]
# ───────────────────────────────────────────────────────────────────────────
# HELPERS
# ───────────────────────────────────────────────────────────────────────────
def _init_t5():
    """Lazily load the T5 tokenizer/encoder on first use."""
    global _t5_tok, _t5_mod
    if _t5_tok is None:
        _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
        _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
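# flan-t5-base emits 768-dim hidden states, matching the (1,77,768) T5
# stream the shunt adapters consume below.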
def _init_pipe():
    """Lazily load the SDXL base pipeline on first use."""
    global _pipe
    if _pipe is None:
        _pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype,
            use_safetensors=True,
            variant="fp16",
        ).to(device)
        try:
            _pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass  # xformers is optional; fall back to default attention
def load_adapter(repo: str, filename: str, cfg: dict):
    """Load a TwoStreamShuntAdapter from a safetensors file on the HF Hub."""
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(cfg).eval()
    model.load_state_dict(load_file(path))
    return model.to(device)
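# hf_hub_download caches files locally, so repeated generations only pay for
# the state-dict load, not a fresh download.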
def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
    """Render a 2-D heatmap of `mat` and return it as an image array."""
    if isinstance(mat, torch.Tensor):
        mat = mat.detach().cpu().numpy()
    if mat.ndim == 1:
        mat = mat[None, :]
    elif mat.ndim >= 3:  # (B,T,D) → mean over B
        mat = mat.mean(axis=0)
    plt.figure(figsize=(8, 4), dpi=120)
    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
    plt.title(title)
    plt.colorbar(shrink=0.7)
    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close()
    buf.seek(0)
    return np.array(Image.open(buf))
def encode_prompt_sd_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
    """Return CLIP-L and CLIP-G embeddings (positive and negative) from the SDXL pipeline."""
    def _ids(tokenizer, text):
        return tokenizer(text, max_length=77, padding="max_length",
                         truncation=True, return_tensors="pt").input_ids.to(device)

    tok_l, ntok_l = _ids(pipe.tokenizer, prompt), _ids(pipe.tokenizer, negative)
    tok_g, ntok_g = _ids(pipe.tokenizer_2, prompt), _ids(pipe.tokenizer_2, negative)
    with torch.no_grad():
        clip_l = pipe.text_encoder(tok_l)[0]   # (1,77,768)
        nclip_l = pipe.text_encoder(ntok_l)[0]
        # text_encoder_2 (CLIPTextModelWithProjection) returns the pooled
        # projection at index 0 and the (1,77,1280) sequence at index 1
        out_g = pipe.text_encoder_2(tok_g, output_hidden_states=False)
        clip_g, pooled = out_g[1], out_g[0]
        nout_g = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
        nclip_g, npooled = nout_g[1], nout_g[0]
    return {"clip_l": clip_l, "clip_g": clip_g,
            "neg_l": nclip_l, "neg_g": nclip_g,
            "pooled": pooled, "neg_pooled": npooled}
def adapter_forward(adapter, t5_seq, clip_seq, cfg):
    with torch.no_grad():
        out = adapter(t5_seq.float(), clip_seq.float())
        # unify outputs: pad the returned tuple with None so adapters that
        # emit fewer tensors still unpack to (anchor, Δ, log σ, …, τ, g_pred, gate)
        anchor, delta, log_sigma, *_, tau, g_pred, gate = (out + (None,) * 8)[:8]
        delta = delta * cfg["delta_scale"]
        gate = torch.sigmoid(gate * g_pred * cfg["gpred_scale"]) * cfg["gate_prob"]
        final_delta = delta * cfg["strength"] * gate
        mod = clip_seq + final_delta.to(dtype)
        if cfg["sigma_scale"] > 0:
            sigma = torch.exp(log_sigma * cfg["sigma_scale"])
            mod += torch.randn_like(mod) * sigma.to(dtype)
        if cfg["use_anchor"]:
            # blend toward the predicted anchor where the gate is active
            mod = mod * (1 - gate) + anchor.to(dtype) * gate
        if cfg["noise"] > 0:
            mod += torch.randn_like(mod) * cfg["noise"]
        return mod, final_delta, gate, g_pred, tau
# ───────────────────────────────────────────────────────────────────────────
# MAIN INFERENCE
# ───────────────────────────────────────────────────────────────────────────
def infer(prompt, negative_prompt,
          adapter_l_file, adapter_g_file,
          strength, delta_scale, sigma_scale,
          gpred_scale, noise, gate_prob, use_anchor,
          steps, cfg_scale, scheduler_name,
          width, height, seed):
    torch.cuda.empty_cache()
    _init_t5(); _init_pipe()

    # scheduler
    if scheduler_name in SCHEDULERS:
        _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

    # RNG
    generator = None
    seed = int(seed)  # gr.Number delivers floats; manual_seed needs an int
    if seed != -1:
        generator = torch.Generator(device=device).manual_seed(seed)
        torch.manual_seed(seed); np.random.seed(seed)

    # T5 embeddings (semantic guidance)
    t5_ids = _t5_tok(prompt, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        t5_seq = _t5_mod(t5_ids).last_hidden_state  # (1,77,768)

    # CLIP embeddings from SDXL
    embeds = encode_prompt_sd_xl(_pipe, prompt, negative_prompt)
    # ------------------------------------------------------------------
    # LOAD adapters (if any)
    cfg_common = dict(
        strength=strength, delta_scale=delta_scale, sigma_scale=sigma_scale,
        gpred_scale=gpred_scale, noise=noise, gate_prob=gate_prob,
        use_anchor=use_anchor,
    )

    # CLIP-L
    if adapter_l_file and adapter_l_file != "None":
        cfg_l = conf_l.copy(); cfg_l.update(cfg_common)
        if "booru" in adapter_l_file: cfg_l["heads"] = 4
        # build from cfg_l so the "heads" override actually reaches the model
        adapter_l = load_adapter(repo_l, adapter_l_file, cfg_l)
        clip_l_mod, delta_l, gate_l, g_pred_l, tau_l = adapter_forward(
            adapter_l, t5_seq, embeds["clip_l"], cfg_l)
    else:
        clip_l_mod = embeds["clip_l"]; delta_l = torch.zeros_like(clip_l_mod)
        gate_l = torch.zeros_like(clip_l_mod[..., :1]); g_pred_l = tau_l = torch.tensor(0.)

    # CLIP-G
    if adapter_g_file and adapter_g_file != "None":
        cfg_g = conf_g.copy(); cfg_g.update(cfg_common)
        adapter_g = load_adapter(repo_g, adapter_g_file, cfg_g)
        clip_g_mod, delta_g, gate_g, g_pred_g, tau_g = adapter_forward(
            adapter_g, t5_seq, embeds["clip_g"], cfg_g)
    else:
        clip_g_mod = embeds["clip_g"]; delta_g = torch.zeros_like(clip_g_mod)
        gate_g = torch.zeros_like(clip_g_mod[..., :1]); g_pred_g = tau_g = torch.tensor(0.)
    # concatenate for SDXL: CLIP-L (1,77,768) + CLIP-G (1,77,1280) → (1,77,2048)
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
    # negatives pass through unmodified; only the positive stream is shunted
    neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)

    # SDXL generation (cast slider values, which may arrive as floats)
    image = _pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        pooled_prompt_embeds=embeds["pooled"],
        negative_pooled_prompt_embeds=embeds["neg_pooled"],
        num_inference_steps=int(steps), guidance_scale=cfg_scale,
        width=int(width), height=int(height), generator=generator,
    ).images[0]
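    # Diagnostics: heatmaps of the applied Δ (token × channel) and the mean
    # gate activation per token, plus scalar g_pred / τ readouts per stream.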
    delta_l_img = plot_heat(delta_l.squeeze(), "Δ CLIP-L")
    gate_l_img = plot_heat(gate_l.squeeze().mean(-1, keepdim=True), "Gate L")
    delta_g_img = plot_heat(delta_g.squeeze(), "Δ CLIP-G")
    gate_g_img = plot_heat(gate_g.squeeze().mean(-1, keepdim=True), "Gate G")

    stats_l = f"g_pred_L={g_pred_l.item():.3f} | τ_L={tau_l.item():.3f}"
    stats_g = f"g_pred_G={g_pred_g.item():.3f} | τ_G={tau_g.item():.3f}"
    return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
# ───────────────────────────────────────────────────────────────────────────
# GRADIO UI
# ───────────────────────────────────────────────────────────────────────────
def create_interface():
    with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Prompts")
                prompt = gr.Textbox(label="Prompt", lines=3,
                                    value="a futuristic control station with holographic displays")
                negative_prompt = gr.Textbox(label="Negative", lines=2,
                                             value="blurry, low quality, distorted")

                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(["None"] + clip_l_opts, value="t5-vit-l-14-dual_shunt_caption.safetensors",
                                        label="CLIP-L Adapter")
                adapter_g = gr.Dropdown(["None"] + clip_g_opts, value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
                                        label="CLIP-G Adapter")

                gr.Markdown("### Adapter Controls")
                strength = gr.Slider(0, 10, 4.0, 0.01, label="Strength")
                delta_scale = gr.Slider(-15, 15, 0.2, 0.1, label="Δ scale")
                sigma_scale = gr.Slider(0, 15, 0.1, 0.1, label="σ scale")
                gpred_scale = gr.Slider(0, 20, 2.0, 0.01, label="g_pred scale")
                noise = gr.Slider(0, 1, 0.55, 0.01, label="Extra noise")
                gate_prob = gr.Slider(0, 1, 0.27, 0.01, label="Gate prob")
                use_anchor = gr.Checkbox(True, label="Use anchor mix")
gr.Markdown("### Generation") | |
with gr.Row(): | |
steps = gr.Slider(1, 50, 20, 1, label="Steps") | |
cfg_scale = gr.Slider(1, 15, 7.5, 0.1, label="CFG") | |
scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler") | |
with gr.Row(): | |
width = gr.Slider(512, 1536, 1024, 64, label="Width") | |
height = gr.Slider(512, 1536, 1024, 64, label="Height") | |
seed = gr.Number(-1, label="Seed (-1=random)") | |
go_btn = gr.Button("π Generate", variant="primary") | |
with gr.Column(scale=1): | |
out_img = gr.Image(label="Result", height=400) | |
gr.Markdown("### Adapter Diagnostics") | |
delta_l_i = gr.Image(label="Ξ L", height=180) | |
gate_l_i = gr.Image(label="Gate L", height=180) | |
delta_g_i = gr.Image(label="Ξ G", height=180) | |
gate_g_i = gr.Image(label="Gate G", height=180) | |
stats_l = gr.Textbox(label="Stats L", interactive=False) | |
stats_g = gr.Textbox(label="Stats G", interactive=False) | |
        def _run(*args):
            # map "None" dropdown selections to real None before calling infer()
            pl, npl = args[0], args[1]
            al, ag = (None if v == "None" else v for v in args[2:4])
            return infer(pl, npl, al, ag, *args[4:])

        go_btn.click(
            _run,
            inputs=[prompt, negative_prompt, adapter_l, adapter_g,
                    strength, delta_scale, sigma_scale, gpred_scale,
                    noise, gate_prob, use_anchor, steps, cfg_scale,
                    scheduler, width, height, seed],
            outputs=[out_img, delta_l_i, gate_l_i, delta_g_i, gate_g_i,
                     stats_l, stats_g],
        )
    return demo
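# Headless usage sketch (hypothetical prompt and values; assumes a CUDA GPU
# and the default adapter checkpoints listed above):
#   image, *diagnostics = infer(
#       "a red fox in fresh snow", "blurry, low quality",
#       "t5-vit-l-14-dual_shunt_caption.safetensors",
#       "dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
#       4.0, 0.2, 0.1, 2.0, 0.55, 0.27, True,   # adapter knobs
#       20, 7.5, "DPM++ 2M", 1024, 1024, 42)    # sampling settings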
# ───────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    create_interface().launch()