# app.py ────────────────────────────────────────────────────────────────
import io, warnings, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch, torch.nn.functional as F
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import (
    StableDiffusionXLPipeline,
    DDIMScheduler,
    EulerDiscreteScheduler,
    DPMSolverMultistepScheduler,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# local modules
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from conditioning_shifter import ConditioningShifter, ShiftConfig, AdapterOutput
from embedding_manager import get_bank
from configs import T5_SHUNT_REPOS

warnings.filterwarnings("ignore")

# ─── GLOBALS ────────────────────────────────────────────────────────────
dtype  = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_bank = get_bank()                                   # singleton – optional caching
_t5_tok: Optional[T5Tokenizer]               = None
_t5_mod: Optional[T5EncoderModel]            = None
_pipe:   Optional[StableDiffusionXLPipeline] = None

SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM":     DDIMScheduler,
    "Euler":    EulerDiscreteScheduler,
}

# adapter-meta from configs.py
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l, conf_l = T5_SHUNT_REPOS["clip_l"]["repo"], T5_SHUNT_REPOS["clip_l"]["config"]
repo_g, conf_g = T5_SHUNT_REPOS["clip_g"]["repo"], T5_SHUNT_REPOS["clip_g"]["config"]

# ─── INITIALISERS ────────────────────────────────────────────────────────
def _init_t5():
    global _t5_tok, _t5_mod
    if _t5_tok is None:
        _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
        _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base") \
                                .to(device).eval()


def _init_pipe():
    global _pipe
    if _pipe is None:
        _pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype, variant="fp16", use_safetensors=True
        ).to(device)
        _pipe.enable_xformers_memory_efficient_attention()


# ─── HELPERS ─────────────────────────────────────────────────────────────
def load_adapter(repo: str, filename: str, cfg: dict,
                 device: torch.device) -> TwoStreamShuntAdapter:
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(cfg).eval()
    model.load_state_dict(load_file(path))
    return model.to(device)


def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
    if isinstance(mat, torch.Tensor):
        mat = mat.detach().cpu().numpy()
    if mat.ndim == 1:
        mat = mat[None, :]
    elif mat.ndim >= 3:
        mat = mat.mean(axis=0)

    plt.figure(figsize=(7, 3.3), dpi=110)
    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
    plt.title(title, fontsize=10)
    plt.colorbar(shrink=0.7)
    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close()
    buf.seek(0)
    return np.array(Image.open(buf))


def encode_prompt_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
    tok_l  = pipe.tokenizer  (prompt,   max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)
    tok_g  = pipe.tokenizer_2(prompt,   max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)
    ntok_l = pipe.tokenizer  (negative, max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)
    ntok_g = pipe.tokenizer_2(negative, max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        clip_l     = pipe.text_encoder(tok_l)[0]
        neg_clip_l = pipe.text_encoder(ntok_l)[0]

        g_out = pipe.text_encoder_2(tok_g, output_hidden_states=False)
        clip_g, pl = g_out[1], g_out[0]          # [1] = per-token states, [0] = pooled text_embeds
        ng_out = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
        neg_clip_g, npl = ng_out[1], ng_out[0]

    return {"clip_l": clip_l, "clip_g": clip_g,
            "neg_l": neg_clip_l, "neg_g": neg_clip_g,
            "pooled": pl, "neg_pooled": npl}


# ─── INFERENCE ───────────────────────────────────────────────────────────
def infer(prompt: str, negative_prompt: str,
          adapter_l_file: str, adapter_g_file: str,
          strength: float, delta_scale: float, sigma_scale: float,
          gpred_scale: float, noise: float, gate_prob: float, use_anchor: bool,
          steps: int, cfg_scale: float, scheduler_name: str,
          width: int, height: int, seed: int):

    torch.cuda.empty_cache()
    _init_t5()
    _init_pipe()

    if scheduler_name in SCHEDULERS:
        _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

    generator = (torch.Generator(device=device).manual_seed(seed)
                 if seed != -1 else None)

    # build ShiftConfig (one per request)
    cfg_shift = ShiftConfig(
        prompt           = prompt,
        seed             = seed,
        strength         = strength,
        delta_scale      = delta_scale,
        sigma_scale      = sigma_scale,
        gate_probability = gate_prob,
        noise_injection  = noise,
        use_anchor       = use_anchor,
        guidance_scale   = gpred_scale,
    )

    # encoder (T5) embeddings
    t5_seq = ConditioningShifter.extract_encoder_embeddings(
        {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
        device, cfg_shift
    )

    # CLIP embeddings
    embeds = encode_prompt_xl(_pipe, prompt, negative_prompt)

    # run adapters --------------------------------------------------------
    outputs: List[AdapterOutput] = []
    if adapter_l_file and adapter_l_file != "None":
        ada_l = load_adapter(repo_l, adapter_l_file, conf_l, device)
        outputs.append(ConditioningShifter.run_adapter(
            ada_l, t5_seq, embeds["clip_l"],
            cfg_shift.guidance_scale, "clip_l", (0, 768)))

    if adapter_g_file and adapter_g_file != "None":
        ada_g = load_adapter(repo_g, adapter_g_file, conf_g, device)
        outputs.append(ConditioningShifter.run_adapter(
            ada_g, t5_seq, embeds["clip_g"],
            cfg_shift.guidance_scale, "clip_g", (768, 2048)))

    # apply modifications -------------------------------------------------
    clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
    delta_viz = {"clip_l": torch.zeros_like(clip_l_mod),
                 "clip_g": torch.zeros_like(clip_g_mod)}
    gate_viz  = {"clip_l": torch.zeros_like(clip_l_mod[..., :1]),
                 "clip_g": torch.zeros_like(clip_g_mod[..., :1])}

    for out in outputs:
        target = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
        mod = ConditioningShifter.apply_modifications(target, [out], cfg_shift)
        if out.adapter_type == "clip_l":
            clip_l_mod = mod
        else:
            clip_g_mod = mod
        delta_viz[out.adapter_type] = out.delta.detach()
        gate_viz[out.adapter_type] = out.gate.detach()

    # prepare for SDXL ----------------------------------------------------
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
    neg_embeds    = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)

    image = _pipe(
        prompt_embeds                 = prompt_embeds,
        negative_prompt_embeds        = neg_embeds,
        pooled_prompt_embeds          = embeds["pooled"],
        negative_pooled_prompt_embeds = embeds["neg_pooled"],
        num_inference_steps           = steps,
        guidance_scale                = cfg_scale,
        width                         = width,
        height                        = height,
        generator                     = generator
    ).images[0]

    # diagnostics ---------------------------------------------------------
    delta_l_img = plot_heat(delta_viz["clip_l"].squeeze(), "Δ CLIP-L")
    gate_l_img  = plot_heat(gate_viz["clip_l"].squeeze().mean(-1, keepdim=True), "Gate L")
    delta_g_img = plot_heat(delta_viz["clip_g"].squeeze(), "Δ CLIP-G")
    gate_g_img  = plot_heat(gate_viz["clip_g"].squeeze().mean(-1, keepdim=True), "Gate G")

    # τ̄ statistics for whichever adapters actually ran
    stats_l = (f"τ̄_L = {outputs[0].tau.mean().item():.3f}"
               if outputs and outputs[0].adapter_type == "clip_l" else "-")
    stats_g = (f"τ̄_G = {outputs[-1].tau.mean().item():.3f}"
               if outputs and outputs[-1].adapter_type == "clip_g" else "-")

    return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g


# ─── GRADIO UI ────────────────────────────────────────────────────────────
def create_interface():
    with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Prompts")
                prompt   = gr.Textbox(label="Prompt", lines=3,
                                      value="a futuristic control station with holographic displays")
                negative = gr.Textbox(label="Negative", lines=2,
                                      value="blurry, low quality, distorted")

                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(["None"] + clip_l_opts,
                                        value="t5-vit-l-14-dual_shunt_caption.safetensors",
                                        label="CLIP-L Adapter")
                adapter_g = gr.Dropdown(["None"] + clip_g_opts,
                                        value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
                                        label="CLIP-G Adapter")

                gr.Markdown("### Adapter Controls")
                strength    = gr.Slider(0, 10,   value=4.0,  step=0.05, label="Strength")
                delta_scale = gr.Slider(-15, 15, value=0.2,  step=0.1,  label="Δ scale")
                sigma_scale = gr.Slider(0, 15,   value=0.1,  step=0.1,  label="σ scale")
                gpred_scale = gr.Slider(0, 20,   value=2.0,  step=0.05, label="Guidance scale")
                noise       = gr.Slider(0, 1,    value=0.55, step=0.01, label="Extra noise")
                gate_prob   = gr.Slider(0, 1,    value=0.27, step=0.01, label="Gate prob")
                use_anchor  = gr.Checkbox(True, label="Use anchor mix")

                gr.Markdown("### Generation")
                with gr.Row():
                    steps     = gr.Slider(1, 50, value=20,  step=1,   label="Steps")
                    cfg_scale = gr.Slider(1, 15, value=7.5, step=0.1, label="CFG")
                scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler")
                with gr.Row():
                    width  = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
                seed = gr.Number(-1, label="Seed (-1 → random)", precision=0)

                run_btn = gr.Button("🚀 Generate", variant="primary")

            with gr.Column(scale=1):
                out_img = gr.Image(label="Result", height=400)
                gr.Markdown("### Diagnostics")
                delta_l = gr.Image(label="Δ L", height=180)
                gate_l  = gr.Image(label="Gate L", height=180)
                delta_g = gr.Image(label="Δ G", height=180)
                gate_g  = gr.Image(label="Gate G", height=180)
                stats_l = gr.Textbox(label="Stats L", interactive=False)
                stats_g = gr.Textbox(label="Stats G", interactive=False)

        def _run(*args):
            pl, npl = args[0], args[1]
            al, ag  = (None if v == "None" else v for v in args[2:4])
            return infer(pl, npl, al, ag, *args[4:])

        run_btn.click(
            fn=_run,
            inputs=[prompt, negative, adapter_l, adapter_g,
                    strength, delta_scale, sigma_scale, gpred_scale,
                    noise, gate_prob, use_anchor,
                    steps, cfg_scale, scheduler, width, height, seed],
            outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
        )

    return demo


if __name__ == "__main__":
    create_interface().launch()