# app.py ────────────────────────────────────────────────────────────────────
import io
import warnings
from typing import Dict, Optional

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import (
    StableDiffusionXLPipeline,
    DDIMScheduler,
    EulerDiscreteScheduler,
    DPMSolverMultistepScheduler,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# -------------------------------------------------------------------------
# local modules
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from configs import T5_SHUNT_REPOS
from embedding_manager import get_bank  # ← NEW

warnings.filterwarnings("ignore")

# ───────────────────────────────────────────────────────────────────────────
# GLOBALS
# ───────────────────────────────────────────────────────────────────────────
dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bank = get_bank()  # shared singleton

_t5_tok: Optional[T5Tokenizer] = None
_t5_mod: Optional[T5EncoderModel] = None
_pipe: Optional[StableDiffusionXLPipeline] = None

SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}

# easy access to adapter repo metadata
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
conf_l = T5_SHUNT_REPOS["clip_l"]["config"]
conf_g = T5_SHUNT_REPOS["clip_g"]["config"]


# ───────────────────────────────────────────────────────────────────────────
# HELPERS
# ───────────────────────────────────────────────────────────────────────────
def _init_t5():
    """Lazily load the FLAN-T5 tokenizer/encoder used for semantic guidance."""
    global _t5_tok, _t5_mod
    if _t5_tok is None:
        _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
        _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()


def _init_pipe():
    """Lazily load the SDXL base pipeline."""
    global _pipe
    if _pipe is None:
        _pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype,
            use_safetensors=True,
            variant="fp16",
        ).to(device)
        try:
            _pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass  # xformers is optional; fall back to default attention


def load_adapter(repo: str, filename: str, cfg: dict) -> TwoStreamShuntAdapter:
    """Load a TwoStreamShuntAdapter from HF Hub safetensors."""
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(cfg).eval()
    model.load_state_dict(load_file(path))
    return model.to(device)


def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
    """Render `mat` as a heatmap and return it as a numpy image array."""
    if isinstance(mat, torch.Tensor):
        mat = mat.detach().cpu().float().numpy()
    if mat.ndim == 1:
        mat = mat[None, :]
    elif mat.ndim >= 3:  # (B, T, D) → mean over B
        mat = mat.mean(axis=0)
    plt.figure(figsize=(8, 4), dpi=120)
    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
    plt.title(title)
    plt.colorbar(shrink=0.7)
    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close()
    buf.seek(0)
    return np.array(Image.open(buf))
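
# NOTE: `pipe.text_encoder_2` is a CLIPTextModelWithProjection, whose output
# tuple is ordered `(text_embeds, last_hidden_state, ...)`: index 0 is the
# pooled, projected embedding and index 1 holds the per-token hidden states.
# The indexing in `encode_prompt_sd_xl` below relies on that ordering. A
# minimal shape check (assuming the SDXL base checkpoint, 77-token padding):
#
#   out = pipe.text_encoder_2(tok_g, output_hidden_states=False)
#   assert out[0].shape == (1, 1280)      # pooled → `pooled_prompt_embeds`
#   assert out[1].shape == (1, 77, 1280)  # tokens → concatenated with CLIP-L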
def encode_prompt_sd_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
    """Return CLIP-L and CLIP-G embeddings (positive and negative) from the SDXL pipeline."""
    tok_l = pipe.tokenizer(prompt, max_length=77, padding="max_length",
                           truncation=True, return_tensors="pt").input_ids.to(device)
    tok_g = pipe.tokenizer_2(prompt, max_length=77, padding="max_length",
                             truncation=True, return_tensors="pt").input_ids.to(device)
    ntok_l = pipe.tokenizer(negative, max_length=77, padding="max_length",
                            truncation=True, return_tensors="pt").input_ids.to(device)
    ntok_g = pipe.tokenizer_2(negative, max_length=77, padding="max_length",
                              truncation=True, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        clip_l = pipe.text_encoder(tok_l)[0]   # (1, 77, 768)
        nclip_l = pipe.text_encoder(ntok_l)[0]
        out_g = pipe.text_encoder_2(tok_g, output_hidden_states=False)
        clip_g, pooled = out_g[1], out_g[0]    # (1, 77, 1280), (1, 1280)
        nout_g = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
        nclip_g, npooled = nout_g[1], nout_g[0]
    return {"clip_l": clip_l, "clip_g": clip_g,
            "neg_l": nclip_l, "neg_g": nclip_g,
            "pooled": pooled, "neg_pooled": npooled}


def adapter_forward(adapter, t5_seq, clip_seq, cfg):
    """Run the shunt adapter and blend its delta/gate outputs into `clip_seq`."""
    with torch.no_grad():
        out = adapter(t5_seq.float(), clip_seq.float())

    # Unify outputs: the adapter is assumed to return up to eight tensors in
    # the order (anchor, delta, log_sigma, *extras, tau, g_pred, gate);
    # shorter tuples are padded with None before unpacking.
    anchor, delta, log_sigma, *_, tau, g_pred, gate = (out + (None,) * 8)[:8]

    delta = delta * cfg["delta_scale"]
    gate = torch.sigmoid(gate * g_pred * cfg["gpred_scale"]) * cfg["gate_prob"]
    final_delta = delta * cfg["strength"] * gate

    mod = clip_seq + final_delta.to(dtype)
    if cfg["sigma_scale"] > 0:
        sigma = torch.exp(log_sigma * cfg["sigma_scale"])
        mod += torch.randn_like(mod) * sigma.to(dtype)
    if cfg["use_anchor"]:
        g = gate.to(dtype)  # keep the blend in fp16 so `mod` is not promoted
        mod = mod * (1 - g) + anchor.to(dtype) * g
    if cfg["noise"] > 0:
        mod += torch.randn_like(mod) * cfg["noise"]
    return mod, final_delta, gate, g_pred, tau
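
# The blend above, written out (a sketch with illustrative numbers, not tuned
# values): given a per-token delta Δ and gate g ∈ (0, 1),
#
#   mod = clip + strength · g · (delta_scale · Δ)   # gated semantic shift
#   mod = (1 − g) · mod + g · anchor                # optional anchor mix
#
# e.g. with dummy tensors:
#
#   clip  = torch.randn(1, 77, 768, dtype=torch.float16)
#   delta = torch.randn(1, 77, 768)
#   g     = torch.sigmoid(torch.randn(1, 77, 1))              # per-token gate
#   mod   = clip + (4.0 * g * 0.2 * delta).to(torch.float16)  # strength=4, Δ scale=0.2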

# ───────────────────────────────────────────────────────────────────────────
# MAIN INFERENCE
# ───────────────────────────────────────────────────────────────────────────
def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file,
          strength, delta_scale, sigma_scale, gpred_scale,
          noise, gate_prob, use_anchor,
          steps, cfg_scale, scheduler_name,
          width, height, seed):

    torch.cuda.empty_cache()
    _init_t5()
    _init_pipe()

    # scheduler
    if scheduler_name in SCHEDULERS:
        _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

    # RNG
    generator = None
    if seed != -1:
        seed = int(seed)
        generator = torch.Generator(device=device).manual_seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)

    # T5 embeddings (semantic guidance)
    t5_ids = _t5_tok(prompt, max_length=77, truncation=True,
                     padding="max_length", return_tensors="pt").input_ids.to(device)
    t5_seq = _t5_mod(t5_ids).last_hidden_state  # (1, 77, 768)

    # CLIP embeddings from SDXL
    embeds = encode_prompt_sd_xl(_pipe, prompt, negative_prompt)

    # ------------------------------------------------------------------
    # LOAD adapters (if any); runtime knobs stay separate from model config
    cfg_common = dict(
        strength=strength, delta_scale=delta_scale, sigma_scale=sigma_scale,
        gpred_scale=gpred_scale, noise=noise, gate_prob=gate_prob,
        use_anchor=use_anchor,
    )

    # CLIP-L
    if adapter_l_file and adapter_l_file != "None":
        cfg_l = conf_l.copy()
        if "booru" in adapter_l_file:
            cfg_l["heads"] = 4  # the booru variant uses a 4-head config
        adapter_l = load_adapter(repo_l, adapter_l_file, cfg_l)
        clip_l_mod, delta_l, gate_l, g_pred_l, tau_l = adapter_forward(
            adapter_l, t5_seq, embeds["clip_l"], cfg_common)
    else:
        clip_l_mod = embeds["clip_l"]
        delta_l = torch.zeros_like(clip_l_mod)
        gate_l = torch.zeros_like(clip_l_mod[..., :1])
        g_pred_l = tau_l = torch.tensor(0.)

    # CLIP-G
    if adapter_g_file and adapter_g_file != "None":
        cfg_g = conf_g.copy()
        adapter_g = load_adapter(repo_g, adapter_g_file, cfg_g)
        clip_g_mod, delta_g, gate_g, g_pred_g, tau_g = adapter_forward(
            adapter_g, t5_seq, embeds["clip_g"], cfg_common)
    else:
        clip_g_mod = embeds["clip_g"]
        delta_g = torch.zeros_like(clip_g_mod)
        gate_g = torch.zeros_like(clip_g_mod[..., :1])
        g_pred_g = tau_g = torch.tensor(0.)

    # concatenate for SDXL: (1, 77, 768) ⊕ (1, 77, 1280) → (1, 77, 2048)
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
    neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)

    # SDXL generation
    image = _pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        pooled_prompt_embeds=embeds["pooled"],
        negative_pooled_prompt_embeds=embeds["neg_pooled"],
        num_inference_steps=int(steps),
        guidance_scale=cfg_scale,
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]

    # viz
    delta_l_img = plot_heat(delta_l.squeeze(), "Δ CLIP-L")
    gate_l_img = plot_heat(gate_l.squeeze(0).mean(-1, keepdim=True), "Gate L")
    delta_g_img = plot_heat(delta_g.squeeze(), "Δ CLIP-G")
    gate_g_img = plot_heat(gate_g.squeeze(0).mean(-1, keepdim=True), "Gate G")

    stats_l = f"g_pred_L={g_pred_l.item():.3f} | τ_L={tau_l.item():.3f}"
    stats_g = f"g_pred_G={g_pred_g.item():.3f} | τ_G={tau_g.item():.3f}"

    return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g


# ───────────────────────────────────────────────────────────────────────────
# GRADIO UI
# ───────────────────────────────────────────────────────────────────────────
def create_interface():
    with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Prompts")
                prompt = gr.Textbox(label="Prompt", lines=3,
                                    value="a futuristic control station with holographic displays")
                negative_prompt = gr.Textbox(label="Negative", lines=2,
                                             value="blurry, low quality, distorted")

                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(["None"] + clip_l_opts,
                                        value="t5-vit-l-14-dual_shunt_caption.safetensors",
                                        label="CLIP-L Adapter")
                adapter_g = gr.Dropdown(["None"] + clip_g_opts,
                                        value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
                                        label="CLIP-G Adapter")

                gr.Markdown("### Adapter Controls")
                strength = gr.Slider(0, 10, 4.0, 0.01, label="Strength")
                delta_scale = gr.Slider(-15, 15, 0.2, 0.1, label="Δ scale")
                sigma_scale = gr.Slider(0, 15, 0.1, 0.1, label="σ scale")
                gpred_scale = gr.Slider(0, 20, 2.0, 0.01, label="g_pred scale")
                noise = gr.Slider(0, 1, 0.55, 0.01, label="Extra noise")
                gate_prob = gr.Slider(0, 1, 0.27, 0.01, label="Gate prob")
                use_anchor = gr.Checkbox(True, label="Use anchor mix")

                gr.Markdown("### Generation")
                with gr.Row():
                    steps = gr.Slider(1, 50, 20, 1, label="Steps")
                    cfg_scale = gr.Slider(1, 15, 7.5, 0.1, label="CFG")
                scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler")
                with gr.Row():
                    width = gr.Slider(512, 1536, 1024, 64, label="Width")
                    height = gr.Slider(512, 1536, 1024, 64, label="Height")
                seed = gr.Number(-1, label="Seed (-1=random)")
                go_btn = gr.Button("🚀 Generate", variant="primary")

            with gr.Column(scale=1):
                out_img = gr.Image(label="Result", height=400)
                gr.Markdown("### Adapter Diagnostics")
                delta_l_i = gr.Image(label="Δ L", height=180)
                gate_l_i = gr.Image(label="Gate L", height=180)
                delta_g_i = gr.Image(label="Δ G", height=180)
                gate_g_i = gr.Image(label="Gate G", height=180)
                stats_l = gr.Textbox(label="Stats L", interactive=False)
                stats_g = gr.Textbox(label="Stats G", interactive=False)
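
        # `_run` is a thin shim between the Gradio inputs and `infer`: it only
        # maps the dropdowns' literal "None" strings to Python None, then
        # forwards the remaining widget values positionally.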
        def _run(*args):
            pl, npl = args[0], args[1]
            al, ag = (None if v == "None" else v for v in args[2:4])
            return infer(pl, npl, al, ag, *args[4:])

        go_btn.click(
            _run,
            inputs=[prompt, negative_prompt, adapter_l, adapter_g,
                    strength, delta_scale, sigma_scale, gpred_scale,
                    noise, gate_prob, use_anchor,
                    steps, cfg_scale, scheduler,
                    width, height, seed],
            outputs=[out_img, delta_l_i, gate_l_i, delta_g_i, gate_g_i,
                     stats_l, stats_g],
        )

    return demo


# ───────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    create_interface().launch()
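
# Usage: `python app.py` serves the UI locally at the URL Gradio prints;
# `create_interface().launch(share=True)` would instead expose a temporary
# public link (a standard Gradio option, not something this app configures).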