# app.py ────────────────────────────────────────────────────────────────
import io, warnings, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import gradio as gr
import torch, torch.nn.functional as F
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import (
StableDiffusionXLPipeline,
DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# local modules
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from conditioning_shifter import ConditioningShifter, ShiftConfig, AdapterOutput
from embedding_manager import get_bank
from configs import T5_SHUNT_REPOS
warnings.filterwarnings("ignore")
# ─── GLOBALS ────────────────────────────────────────────────────────────
dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_bank = get_bank() # singleton – optional caching
_t5_tok: Optional[T5Tokenizer] = None
_t5_mod: Optional[T5EncoderModel] = None
_pipe : Optional[StableDiffusionXLPipeline] = None
SCHEDULERS = {
"DPM++ 2M": DPMSolverMultistepScheduler,
"DDIM": DDIMScheduler,
"Euler": EulerDiscreteScheduler,
}
# adapter-meta from configs.py
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l, conf_l = T5_SHUNT_REPOS["clip_l"]["repo"], T5_SHUNT_REPOS["clip_l"]["config"]
repo_g, conf_g = T5_SHUNT_REPOS["clip_g"]["repo"], T5_SHUNT_REPOS["clip_g"]["config"]
# ─── INITIALISERS ────────────────────────────────────────────────────────
def _init_t5():
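    """Lazily initialise the FLAN-T5 tokenizer/encoder pair (no-op once loaded)."""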
global _t5_tok, _t5_mod
if _t5_tok is None:
_t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
_t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base") \
.to(device).eval()
def _init_pipe():
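    """Lazily initialise the SDXL base pipeline in fp16 (no-op once loaded)."""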
global _pipe
if _pipe is None:
_pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=dtype, variant="fp16", use_safetensors=True
).to(device)
        try:
            _pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass  # xformers unavailable – fall back to default attention
# ─── HELPERS ─────────────────────────────────────────────────────────────
def load_adapter(repo: str, filename: str, cfg: dict,
device: torch.device) -> TwoStreamShuntAdapter:
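    """Fetch an adapter checkpoint from the Hub and return it loaded on `device`."""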
path = hf_hub_download(repo_id=repo, filename=filename)
model = TwoStreamShuntAdapter(cfg).eval()
model.load_state_dict(load_file(path))
return model.to(device)
def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
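    """Render a 1-D/2-D tensor as an RdBu heat-map and return it as an RGB array."""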
if isinstance(mat, torch.Tensor):
mat = mat.detach().cpu().numpy()
if mat.ndim == 1:
mat = mat[None, :]
elif mat.ndim >= 3:
mat = mat.mean(axis=0)
plt.figure(figsize=(7, 3.3), dpi=110)
plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
plt.title(title, fontsize=10)
plt.colorbar(shrink=0.7)
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight")
plt.close(); buf.seek(0)
return np.array(Image.open(buf))
def encode_prompt_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
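    """Encode positive and negative prompts with both SDXL text encoders.

    Returns per-encoder sequence embeddings plus the pooled CLIP-G vectors
    that the SDXL UNet expects as added conditioning.
    """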
    def _tok(tokenizer, text: str) -> torch.Tensor:
        return tokenizer(text, max_length=77, truncation=True,
                         padding="max_length", return_tensors="pt").input_ids.to(device)
    tok_l,  tok_g  = _tok(pipe.tokenizer, prompt),   _tok(pipe.tokenizer_2, prompt)
    ntok_l, ntok_g = _tok(pipe.tokenizer, negative), _tok(pipe.tokenizer_2, negative)
with torch.no_grad():
clip_l = pipe.text_encoder(tok_l)[0]
neg_clip_l = pipe.text_encoder(ntok_l)[0]
g_out = pipe.text_encoder_2(tok_g, output_hidden_states=False)
clip_g, pl = g_out[1], g_out[0]
ng_out = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
neg_clip_g, npl = ng_out[1], ng_out[0]
return {"clip_l": clip_l, "clip_g": clip_g,
"neg_l": neg_clip_l, "neg_g": neg_clip_g,
"pooled": pl, "neg_pooled": npl}
# ─── INFERENCE ───────────────────────────────────────────────────────────
def infer(prompt: str, negative_prompt: str,
adapter_l_file: str, adapter_g_file: str,
strength: float, delta_scale: float, sigma_scale: float,
gpred_scale: float, noise: float, gate_prob: float, use_anchor: bool,
steps: int, cfg_scale: float, scheduler_name: str,
width: int, height: int, seed: int):
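    """Full request: encode T5 + CLIP, run the selected shunt adapters,
    shift the CLIP streams, sample SDXL, and return image + diagnostics."""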
torch.cuda.empty_cache()
_init_t5(); _init_pipe()
if scheduler_name in SCHEDULERS:
_pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)
generator = (torch.Generator(device=device).manual_seed(seed)
if seed != -1 else None)
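    # seed == -1 → leave the generator unset so each run draws a fresh seed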
# build ShiftConfig (one per request)
cfg_shift = ShiftConfig(
prompt = prompt,
seed = seed,
strength = strength,
delta_scale = delta_scale,
sigma_scale = sigma_scale,
gate_probability = gate_prob,
noise_injection = noise,
use_anchor = use_anchor,
guidance_scale = gpred_scale,
)
# encoder (T5) embeddings
t5_seq = ConditioningShifter.extract_encoder_embeddings(
{"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
device, cfg_shift
)
# CLIP embeddings
embeds = encode_prompt_xl(_pipe, prompt, negative_prompt)
# run adapters --------------------------------------------------------
outputs: List[AdapterOutput] = []
if adapter_l_file and adapter_l_file != "None":
ada_l = load_adapter(repo_l, adapter_l_file, conf_l, device)
outputs.append(ConditioningShifter.run_adapter(
ada_l, t5_seq, embeds["clip_l"],
cfg_shift.guidance_scale, "clip_l", (0, 768)))
if adapter_g_file and adapter_g_file != "None":
ada_g = load_adapter(repo_g, adapter_g_file, conf_g, device)
outputs.append(ConditioningShifter.run_adapter(
ada_g, t5_seq, embeds["clip_g"],
cfg_shift.guidance_scale, "clip_g", (768, 2048)))
# apply modifications -------------------------------------------------
clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
delta_viz = {"clip_l": torch.zeros_like(clip_l_mod),
"clip_g": torch.zeros_like(clip_g_mod)}
gate_viz = {"clip_l": torch.zeros_like(clip_l_mod[..., :1]),
"clip_g": torch.zeros_like(clip_g_mod[..., :1])}
for out in outputs:
target = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
mod = ConditioningShifter.apply_modifications(target, [out], cfg_shift)
if out.adapter_type == "clip_l":
clip_l_mod = mod
else:
clip_g_mod = mod
delta_viz[out.adapter_type] = out.delta.detach()
gate_viz [out.adapter_type] = out.gate.detach()
# prepare for SDXL ----------------------------------------------------
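    # SDXL expects the two encoder streams concatenated on the channel axis:
    # CLIP-L (768) + CLIP-G (1280) → 2048-dim prompt embeddings.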
prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)
image = _pipe(
prompt_embeds = prompt_embeds,
negative_prompt_embeds = neg_embeds,
pooled_prompt_embeds = embeds["pooled"],
negative_pooled_prompt_embeds = embeds["neg_pooled"],
num_inference_steps = steps,
guidance_scale = cfg_scale,
width = width, height = height, generator = generator
).images[0]
# diagnostics ---------------------------------------------------------
    delta_l_img = plot_heat(delta_viz["clip_l"].squeeze(), "Δ CLIP-L")
    gate_l_img  = plot_heat(gate_viz["clip_l"].squeeze().mean(-1, keepdim=True), "Gate L")
    delta_g_img = plot_heat(delta_viz["clip_g"].squeeze(), "Δ CLIP-G")
    gate_g_img  = plot_heat(gate_viz["clip_g"].squeeze().mean(-1, keepdim=True), "Gate G")
    stats_l = next((f"τ̄_L = {o.tau.mean().item():.3f}"
                    for o in outputs if o.adapter_type == "clip_l"), "-")
    stats_g = next((f"τ̄_G = {o.tau.mean().item():.3f}"
                    for o in outputs if o.adapter_type == "clip_g"), "-")
return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
# ─── GRADIO UI ────────────────────────────────────────────────────────────
def create_interface():
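    """Build the Gradio Blocks UI and wire the Generate button to infer()."""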
with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Prompts")
prompt = gr.Textbox(label="Prompt", lines=3,
value="a futuristic control station with holographic displays")
negative = gr.Textbox(label="Negative", lines=2,
value="blurry, low quality, distorted")
gr.Markdown("### Adapters")
adapter_l = gr.Dropdown(["None"] + clip_l_opts,
value="t5-vit-l-14-dual_shunt_caption.safetensors",
label="CLIP-L Adapter")
adapter_g = gr.Dropdown(["None"] + clip_g_opts,
value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
label="CLIP-G Adapter")
gr.Markdown("### Adapter Controls")
strength = gr.Slider(0, 10, 4.0, 0.05, label="Strength")
                delta_scale = gr.Slider(-15, 15, 0.2, 0.1, label="Δ scale")
                sigma_scale = gr.Slider(0, 15, 0.1, 0.1, label="σ scale")
gpred_scale = gr.Slider(0, 20, 2.0, 0.05, label="Guidance scale")
noise = gr.Slider(0, 1, 0.55, 0.01, label="Extra noise")
gate_prob = gr.Slider(0, 1, 0.27, 0.01, label="Gate prob")
use_anchor = gr.Checkbox(True, label="Use anchor mix")
gr.Markdown("### Generation")
with gr.Row():
steps = gr.Slider(1, 50, 20, 1, label="Steps")
cfg_scale = gr.Slider(1, 15, 7.5, 0.1, label="CFG")
scheduler = gr.Dropdown(list(SCHEDULERS.keys()),
value="DPM++ 2M", label="Scheduler")
with gr.Row():
width = gr.Slider(512, 1536, 1024, 64, label="Width")
height = gr.Slider(512, 1536, 1024, 64, label="Height")
                seed = gr.Number(-1, label="Seed (-1 → random)", precision=0)
                run_btn = gr.Button("🚀 Generate", variant="primary")
with gr.Column(scale=1):
out_img = gr.Image(label="Result", height=400)
gr.Markdown("### Diagnostics")
                delta_l = gr.Image(label="Δ L", height=180)
                gate_l = gr.Image(label="Gate L", height=180)
                delta_g = gr.Image(label="Δ G", height=180)
gate_g = gr.Image(label="Gate G", height=180)
stats_l = gr.Textbox(label="Stats L", interactive=False)
stats_g = gr.Textbox(label="Stats G", interactive=False)
def _run(*args):
pl, npl = args[0], args[1]
al, ag = (None if v == "None" else v for v in args[2:4])
return infer(pl, npl, al, ag, *args[4:])
run_btn.click(
fn=_run,
inputs=[prompt, negative, adapter_l, adapter_g, strength, delta_scale,
sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps,
cfg_scale, scheduler, width, height, seed],
outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
)
return demo
if __name__ == "__main__":
create_interface().launch()
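    # For local testing outside Spaces, launch() also accepts the usual Gradio
    # options, e.g. create_interface().launch(share=True).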