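# (Hugging Face file-page header captured with this source; kept as comments so the module parses.)
# AbstractPhil's picture
# Update app.py
# e543e33 verified
# raw
# history blame
# 14.4 kB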
# app.py ────────────────────────────────────────────────────────────────────
import io, os, json, math, random, warnings, gc, functools, hashlib
from pathlib import Path
from typing import Dict, List, Optional
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn.functional as F
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import (
StableDiffusionXLPipeline,
DDIMScheduler,
EulerDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# -------------------------------------------------------------------------
# local modules
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS
from embedding_manager import get_bank # ← NEW
warnings.filterwarnings("ignore")

# ───────────────────────────────────────────────────────────────────────────
# GLOBALS
# ───────────────────────────────────────────────────────────────────────────
# Precision used for the SDXL pipeline and for tensors handed to it.
dtype = torch.float16
# Run on the GPU whenever torch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Shared singleton returned by embedding_manager.get_bank().
bank = get_bank()

# Heavy models; populated on first use by _init_t5() / _init_pipe().
_t5_tok: Optional[T5Tokenizer] = None
_t5_mod: Optional[T5EncoderModel] = None
_pipe: Optional[StableDiffusionXLPipeline] = None

# Dropdown label -> diffusers scheduler class (insertion order is the UI order).
SCHEDULERS = dict(
    [
        ("DPM++ 2M", DPMSolverMultistepScheduler),
        ("DDIM", DDIMScheduler),
        ("Euler", EulerDiscreteScheduler),
    ]
)

# Short aliases for the adapter-repo metadata in configs.T5_SHUNT_REPOS.
_shunt_meta_l = T5_SHUNT_REPOS["clip_l"]
_shunt_meta_g = T5_SHUNT_REPOS["clip_g"]
clip_l_opts = _shunt_meta_l["shunts_available"]["shunt_list"]
clip_g_opts = _shunt_meta_g["shunts_available"]["shunt_list"]
repo_l, repo_g = _shunt_meta_l["repo"], _shunt_meta_g["repo"]
conf_l, conf_g = _shunt_meta_l["config"], _shunt_meta_g["config"]
# ───────────────────────────────────────────────────────────────────────────
# HELPERS
# ───────────────────────────────────────────────────────────────────────────
def _init_t5():
    """Lazily load the FLAN-T5-base tokenizer and encoder into module globals.

    Each half is checked and loaded independently. If loading the encoder
    fails once, the next call retries it; it does not stay ``None`` behind an
    already-initialised tokenizer.
    """
    global _t5_tok, _t5_mod
    if _t5_tok is None:
        _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
    if _t5_mod is None:
        # eval() switches off dropout. No torch_dtype is passed, so the
        # encoder uses transformers' default dtype.
        _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
def _init_pipe():
    """Lazily build the SDXL base pipeline (fp16 weights) and store it in ``_pipe``.

    xformers memory-efficient attention is enabled on a best-effort basis.
    When it is unavailable, the pipeline keeps diffusers' default attention
    instead of failing the request. The global is assigned only after setup
    finishes, so every call sees a consistently configured pipeline.
    """
    global _pipe
    if _pipe is not None:
        return
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=dtype,
        use_safetensors=True,
        variant="fp16",
    ).to(device)
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError, RuntimeError) as err:
        # diffusers raises ModuleNotFoundError (an ImportError) when xformers is
        # missing and ValueError without CUDA; the xformers self-test can raise
        # RuntimeError. Default attention still works, just with more memory.
        # print() is used because this module filters out all warnings.
        print(f"[app] xformers attention unavailable, using default attention: {err}")
    _pipe = pipe
def load_adapter(repo: str, filename: str, cfg: dict,
                 target_device: Optional[torch.device] = None):
    """Load a TwoStreamShuntAdapter from a safetensors file on the HF Hub.

    Args:
        repo: Hub repository id that holds the adapter weights.
        filename: safetensors file inside ``repo``.
        cfg: architecture config passed to ``TwoStreamShuntAdapter``.
        target_device: device to move the adapter to. Defaults to the
            module-level ``device``. Because it is optional, both the
            3-argument and 4-argument call forms work.

    Returns:
        The adapter, in eval mode, on the target device.
    """
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(cfg).eval()
    # load_state_dict is strict by default, so any key or shape mismatch
    # between the checkpoint and ``cfg`` raises here.
    model.load_state_dict(load_file(path))
    return model.to(device if target_device is None else target_device)
def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
    """Render *mat* as a 2-D heat-map and return the PNG as an image array.

    Args:
        mat: tensor or array to visualise.
            - 0-D input becomes a 1x1 map.
            - 1-D input becomes a single row.
            - 3-D and higher input has every leading axis averaged away, so
              the last two axes are plotted, e.g. (B, T, D) -> (T, D).
        title: figure title.

    Returns:
        The rendered PNG, decoded by PIL, as a numpy array of shape (H, W, C).
    """
    # A pyplot-free Figure has no global "current figure" state to share
    # between concurrent Gradio requests, and nothing to leak if drawing fails.
    from matplotlib.figure import Figure

    if isinstance(mat, torch.Tensor):
        # Call .float() first: numpy has no bfloat16, and half precision is
        # upcast for plotting.
        mat = mat.detach().float().cpu().numpy()
    mat = np.asarray(mat)
    if mat.ndim < 2:
        mat = mat.reshape(1, -1)  # 0-D -> (1, 1); 1-D (N,) -> (1, N)
    elif mat.ndim > 2:
        # Average over all leading axes, not just the first, so imshow
        # always receives a 2-D array.
        mat = mat.reshape(-1, *mat.shape[-2:]).mean(axis=0)

    fig = Figure(figsize=(8, 4), dpi=120)
    ax = fig.subplots()
    im = ax.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, shrink=0.7)
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return np.array(Image.open(buf))
def encode_prompt_sd_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
    """Encode a prompt / negative-prompt pair with both SDXL text encoders.

    Both texts are padded or truncated to 77 tokens by each tokenizer.

    Returns a dict with these keys:
        "clip_l" / "neg_l": output[0] of ``pipe.text_encoder``.
        "clip_g" / "neg_g": output[1] of ``pipe.text_encoder_2``.
        "pooled" / "neg_pooled": output[0] of ``pipe.text_encoder_2``.

    NOTE(review): diffusers' own SDXL ``encode_prompt`` uses the penultimate
    hidden state (``hidden_states[-2]``) of both encoders. This function uses
    the encoder outputs directly, presumably to match how the shunt adapters
    were trained. Confirm before changing.
    """
    def _ids(tokenizer, text):
        # Fixed-length (77) token ids on the working device.
        batch = tokenizer(
            text,
            max_length=77,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return batch.input_ids.to(device)

    ids_l, ids_g = _ids(pipe.tokenizer, prompt), _ids(pipe.tokenizer_2, prompt)
    neg_ids_l, neg_ids_g = _ids(pipe.tokenizer, negative), _ids(pipe.tokenizer_2, negative)

    with torch.no_grad():
        hidden_l = pipe.text_encoder(ids_l)[0]  # original note: (1, 77, 768)
        neg_hidden_l = pipe.text_encoder(neg_ids_l)[0]
        enc_g = pipe.text_encoder_2(ids_g, output_hidden_states=False)
        neg_enc_g = pipe.text_encoder_2(neg_ids_g, output_hidden_states=False)

    return {
        "clip_l": hidden_l,
        "clip_g": enc_g[1],
        "neg_l": neg_hidden_l,
        "neg_g": neg_enc_g[1],
        "pooled": enc_g[0],
        "neg_pooled": neg_enc_g[0],
    }
def adapter_forward(adapter, t5_seq, clip_seq, cfg):
    """Run a shunt adapter and blend its prediction into a CLIP sequence.

    Args:
        adapter: callable ``adapter(t5_seq, clip_seq)`` that returns a tuple or
            list of at least 8 items laid out as
            ``(anchor, delta, log_sigma, _, _, tau, g_pred, gate)``.
        t5_seq: T5 hidden states; upcast to float32 before the adapter call.
        clip_seq: CLIP hidden states to modify. The returned ``mod`` keeps
            this tensor's dtype.
        cfg: dict providing the keys ``delta_scale``, ``gpred_scale``,
            ``gate_prob``, ``strength``, ``sigma_scale``, ``use_anchor`` and
            ``noise``.

    Returns:
        A tuple ``(mod, final_delta, gate, g_pred, tau)``, where ``gate`` is the
        rescaled gate ``sigmoid(gate * g_pred * gpred_scale) * gate_prob``.

    Raises:
        ValueError: if the adapter returns fewer than 8 outputs.
    """
    with torch.no_grad():
        out = tuple(adapter(t5_seq.float(), clip_seq.float()))
    if len(out) < 8:
        raise ValueError(
            f"adapter returned {len(out)} outputs; expected at least 8 "
            "(anchor, delta, log_sigma, _, _, tau, g_pred, gate)"
        )
    anchor, delta, log_sigma = out[0], out[1], out[2]
    tau, g_pred, gate = out[5], out[6], out[7]

    out_dtype = clip_seq.dtype
    delta = delta * cfg["delta_scale"]
    gate = torch.sigmoid(gate * g_pred * cfg["gpred_scale"]) * cfg["gate_prob"]
    final_delta = delta * cfg["strength"] * gate

    mod = clip_seq + final_delta.to(out_dtype)
    if cfg["sigma_scale"] > 0:
        # Learned per-element noise scaled by exp(log_sigma * sigma_scale).
        sigma = torch.exp(log_sigma * cfg["sigma_scale"])
        mod = mod + torch.randn_like(mod) * sigma.to(out_dtype)
    if cfg["use_anchor"]:
        # Cast the gate as well. Otherwise an fp16 `mod` times an fp32 gate
        # would silently promote the result to fp32.
        gate_c = gate.to(out_dtype)
        mod = mod * (1 - gate_c) + anchor.to(out_dtype) * gate_c
    if cfg["noise"] > 0:
        mod = mod + torch.randn_like(mod) * cfg["noise"]
    return mod, final_delta, gate, g_pred, tau
# ───────────────────────────────────────────────────────────────────────────
# MAIN INFERENCE
# ───────────────────────────────────────────────────────────────────────────
def _apply_shunt(adapter_file, repo, base_conf, t5_seq, clip_seq, cfg_common,
                 arch_overrides=None):
    """Apply one shunt adapter to ``clip_seq``, or pass it through unchanged.

    Returns ``(mod, delta, gate, g_pred, tau)`` like ``adapter_forward``. When no
    adapter is selected (falsy or "None"), ``mod`` is ``clip_seq`` itself and the
    diagnostics are zeros.
    """
    if not adapter_file or adapter_file == "None":
        zero = torch.tensor(0.)
        return (clip_seq, torch.zeros_like(clip_seq),
                torch.zeros_like(clip_seq[..., :1]), zero, zero)
    # The constructor receives the architecture config plus any per-file
    # overrides (e.g. a different head count). The runtime knobs are merged
    # in afterwards, only for adapter_forward.
    arch_cfg = base_conf.copy()
    if arch_overrides:
        arch_cfg.update(arch_overrides)
    adapter = load_adapter(repo, adapter_file, arch_cfg)
    run_cfg = arch_cfg.copy()
    run_cfg.update(cfg_common)
    return adapter_forward(adapter, t5_seq, clip_seq, run_cfg)


def _gate_profile(gate):
    """Return the mean gate per token, keeping the token axis (only a size-1 batch axis is dropped)."""
    return gate.squeeze(0).mean(-1, keepdim=True)


def _as_float(value):
    """Summarise a scalar-or-tensor adapter output as one Python float (its mean)."""
    return torch.as_tensor(value).float().mean().item()


def infer(prompt, negative_prompt,
          adapter_l_file, adapter_g_file,
          strength, delta_scale, sigma_scale,
          gpred_scale, noise, gate_prob, use_anchor,
          steps, cfg_scale, scheduler_name,
          width, height, seed):
    """Generate one SDXL image, optionally applying T5->CLIP shunt adapters.

    Args:
        prompt, negative_prompt: text prompts.
        adapter_l_file, adapter_g_file: safetensors filenames in the CLIP-L /
            CLIP-G adapter repos, or None / "None" to skip that adapter.
        strength, delta_scale, sigma_scale, gpred_scale, noise, gate_prob,
            use_anchor: runtime knobs forwarded to ``adapter_forward``.
        steps, cfg_scale: number of inference steps and guidance scale.
        scheduler_name: key into ``SCHEDULERS``; unknown names keep the
            current scheduler.
        width, height: output size in pixels.
        seed: RNG seed. Any negative value (e.g. -1) or None means unseeded.

    Returns:
        (image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g)
    """
    torch.cuda.empty_cache()
    _init_t5()
    _init_pipe()

    # Gradio Number/Slider values may arrive as floats (or None for a cleared
    # Number). Seeding and diffusers' steps/size arguments need ints.
    steps, width, height = int(steps), int(width), int(height)
    seed = -1 if seed is None else int(seed)

    # scheduler
    if scheduler_name in SCHEDULERS:
        _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

    # RNG: any negative seed means "unseeded".
    generator = None
    if seed >= 0:
        generator = torch.Generator(device=device).manual_seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed % 2**32)  # numpy only accepts seeds in [0, 2**32)

    # T5 embeddings (semantic guidance). This is inference only, so skip
    # autograd bookkeeping.
    t5_ids = _t5_tok(prompt, max_length=77, truncation=True, padding="max_length",
                     return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        t5_seq = _t5_mod(t5_ids).last_hidden_state  # (1,77,768)

    # CLIP embeddings from SDXL
    embeds = encode_prompt_sd_xl(_pipe, prompt, negative_prompt)

    cfg_common = dict(
        strength=strength, delta_scale=delta_scale, sigma_scale=sigma_scale,
        gpred_scale=gpred_scale, noise=noise, gate_prob=gate_prob,
        use_anchor=use_anchor,
    )

    # CLIP-L: adapter files with "booru" in the name are built with heads=4.
    l_overrides = {"heads": 4} if adapter_l_file and "booru" in adapter_l_file else None
    clip_l_mod, delta_l, gate_l, g_pred_l, tau_l = _apply_shunt(
        adapter_l_file, repo_l, conf_l, t5_seq, embeds["clip_l"], cfg_common, l_overrides)
    # CLIP-G
    clip_g_mod, delta_g, gate_g, g_pred_g, tau_g = _apply_shunt(
        adapter_g_file, repo_g, conf_g, t5_seq, embeds["clip_g"], cfg_common)

    # Concatenate for SDXL
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
    neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)

    # SDXL generation
    image = _pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        pooled_prompt_embeds=embeds["pooled"],
        negative_pooled_prompt_embeds=embeds["neg_pooled"],
        num_inference_steps=steps, guidance_scale=cfg_scale,
        width=width, height=height, generator=generator,
    ).images[0]

    # Diagnostics
    delta_l_img = plot_heat(delta_l.squeeze(), "Δ CLIP-L")
    gate_l_img = plot_heat(_gate_profile(gate_l), "Gate L")
    delta_g_img = plot_heat(delta_g.squeeze(), "Δ CLIP-G")
    gate_g_img = plot_heat(_gate_profile(gate_g), "Gate G")
    stats_l = f"g_pred_L={_as_float(g_pred_l):.3f} | τ_L={_as_float(tau_l):.3f}"
    stats_g = f"g_pred_G={_as_float(g_pred_g):.3f} | τ_G={_as_float(tau_g):.3f}"
    return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
# ───────────────────────────────────────────────────────────────────────────
# GRADIO UI
# ───────────────────────────────────────────────────────────────────────────
def create_interface():
    """Build the Gradio Blocks UI for the SDXL dual-shunt tester.

    Returns:
        The ``gr.Blocks`` demo, not yet launched. Its Generate button calls ``infer``.
    """
    def _choice(preferred, options):
        # Start on the preferred adapter only if the repo config lists it.
        # Otherwise fall back to "None" rather than a value outside the
        # dropdown's choices, which Gradio may reject.
        return preferred if preferred in options else "None"

    with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Prompts")
                prompt = gr.Textbox(label="Prompt", lines=3,
                                    value="a futuristic control station with holographic displays")
                negative_prompt = gr.Textbox(label="Negative", lines=2,
                                             value="blurry, low quality, distorted")
                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(
                    ["None"] + clip_l_opts,
                    value=_choice("t5-vit-l-14-dual_shunt_caption.safetensors", clip_l_opts),
                    label="CLIP-L Adapter")
                adapter_g = gr.Dropdown(
                    ["None"] + clip_g_opts,
                    value=_choice("dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
                                  clip_g_opts),
                    label="CLIP-G Adapter")
                gr.Markdown("### Adapter Controls")
                # `step` is passed by keyword: gr.Slider accepts only
                # (minimum, maximum, value) positionally.
                strength = gr.Slider(0, 10, 4.0, step=0.01, label="Strength")
                delta_scale = gr.Slider(-15, 15, 0.2, step=0.1, label="Δ scale")
                sigma_scale = gr.Slider(0, 15, 0.1, step=0.1, label="σ scale")
                gpred_scale = gr.Slider(0, 20, 2.0, step=0.01, label="g_pred scale")
                noise = gr.Slider(0, 1, 0.55, step=0.01, label="Extra noise")
                gate_prob = gr.Slider(0, 1, 0.27, step=0.01, label="Gate prob")
                use_anchor = gr.Checkbox(True, label="Use anchor mix")
                gr.Markdown("### Generation")
                with gr.Row():
                    steps = gr.Slider(1, 50, 20, step=1, label="Steps")
                    cfg_scale = gr.Slider(1, 15, 7.5, step=0.1, label="CFG")
                scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler")
                with gr.Row():
                    width = gr.Slider(512, 1536, 1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, 1024, step=64, label="Height")
                # precision=0 makes Gradio deliver the seed as an int.
                seed = gr.Number(-1, label="Seed (-1=random)", precision=0)
                go_btn = gr.Button("🚀 Generate", variant="primary")
            with gr.Column(scale=1):
                out_img = gr.Image(label="Result", height=400)
                gr.Markdown("### Adapter Diagnostics")
                delta_l_i = gr.Image(label="Δ L", height=180)
                gate_l_i = gr.Image(label="Gate L", height=180)
                delta_g_i = gr.Image(label="Δ G", height=180)
                gate_g_i = gr.Image(label="Gate G", height=180)
                stats_l = gr.Textbox(label="Stats L", interactive=False)
                stats_g = gr.Textbox(label="Stats G", interactive=False)

        def _run(*args):
            # Map the "None" dropdown entries to None before calling infer.
            pl, npl = args[0], args[1]
            al, ag = (None if v == "None" else v for v in args[2:4])
            return infer(pl, npl, al, ag, *args[4:])

        go_btn.click(
            _run,
            inputs=[prompt, negative_prompt, adapter_l, adapter_g,
                    strength, delta_scale, sigma_scale, gpred_scale,
                    noise, gate_prob, use_anchor, steps, cfg_scale,
                    scheduler, width, height, seed],
            outputs=[out_img, delta_l_i, gate_l_i, delta_g_i, gate_g_i,
                     stats_l, stats_g],
        )
    return demo
# ───────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Script entry point: build the UI, then start the Gradio server.
    app = create_interface()
    app.launch()