# app.py ────────────────────────────────────────────────────────────────
import io
import warnings
from pathlib import Path
from typing import Dict, List, Optional

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
)
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer

from conditioning_shifter import AdapterOutput, ConditioningShifter, ShiftConfig
from configs import ShuntUtil
from two_stream_shunt_adapter import TwoStreamShuntAdapter

warnings.filterwarnings("ignore")

# ─── GLOBALS ─────────────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only used when a CUDA device is present; on CPU we fall
# back to float32 so the pipeline does not run in fp16 there.
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Lazily-initialised singletons; populated by _init_t5() / _init_pipe().
_t5_tok: Optional[T5Tokenizer] = None
_t5_mod: Optional[T5EncoderModel] = None
_pipe: Optional[StableDiffusionXLPipeline] = None

# UI label -> diffusers scheduler class.
SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}

# Adapter choices offered in the UI; "None" disables the adapter.
clip_l_shunts = ShuntUtil.get_shunts_by_clip_type("clip_l")
clip_g_shunts = ShuntUtil.get_shunts_by_clip_type("clip_g")
clip_l_opts = ["None"] + [s.name for s in clip_l_shunts]
clip_g_opts = ["None"] + [s.name for s in clip_g_shunts]


# ─── INIT ───────────────────────────────────────────────────────────────
def _init_t5():
    """Load the flan-t5-base tokenizer and encoder once and cache them globally.

    Both objects are loaded into locals first and only then published to the
    module globals, so a failure while loading the encoder cannot leave the
    tokenizer set with the model still ``None``.
    """
    global _t5_tok, _t5_mod
    if _t5_tok is not None and _t5_mod is not None:
        return
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    encoder = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
    _t5_tok, _t5_mod = tokenizer, encoder


def _init_pipe():
    """Load the SDXL base pipeline once and cache it in ``_pipe``.

    xformers memory-efficient attention is treated as an optional
    optimisation: it is only attempted on CUDA, and any failure to enable it
    is reported and ignored instead of aborting start-up.
    """
    global _pipe
    if _pipe is not None:
        return
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=dtype,
        variant="fp16",
        use_safetensors=True,
    ).to(device)
    if device.type == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as exc:  # optional optimisation; never fatal
            # print() rather than warnings.warn(): warnings are globally
            # silenced by warnings.filterwarnings("ignore") in this module.
            print(f"[init] xformers attention unavailable ({exc}); using default attention.")
    _pipe = pipe
# ─── UTILITY ────────────────────────────────────────────────────────────
def load_adapter_by_name(name: str, device: torch.device) -> "TwoStreamShuntAdapter":
    """Download and instantiate the shunt adapter registered under *name*.

    Args:
        name: Shunt name as known to ``ShuntUtil.get_shunt_by_name``.
        device: Device the loaded adapter is moved to.

    Returns:
        The adapter in eval mode with its weights loaded, on *device*.

    Raises:
        ValueError: If no shunt is registered under *name*.
    """
    shunt = ShuntUtil.get_shunt_by_name(name)
    if not shunt:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(f"Shunt '{name}' not found.")
    path = hf_hub_download(repo_id=shunt.repo, filename=shunt.file)
    model = TwoStreamShuntAdapter(shunt.config).eval()
    model.load_state_dict(load_file(path))
    return model.to(device)


def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
    """Render *mat* as a heat-map and return the PNG-decoded image array.

    Args:
        mat: Tensor or array to visualise. Scalars become a 1x1 map, 1-D input
            becomes a single row, and input with three or more dimensions is
            averaged over its leading axis until it is 2-D.
        title: Title drawn above the plot.

    Returns:
        The rendered figure as a NumPy array (as decoded by PIL from the PNG).
    """
    if isinstance(mat, torch.Tensor):
        # Cast to float32 so half-precision tensors plot like any other input.
        mat = mat.detach().float().cpu().numpy()
    mat = np.asarray(mat)
    if mat.ndim == 0:
        mat = mat.reshape(1, 1)
    elif mat.ndim == 1:
        mat = mat[None, :]
    else:
        while mat.ndim >= 3:
            mat = mat.mean(axis=0)

    # Use an explicit figure instead of pyplot's implicit "current figure",
    # and always close it so a failure while rendering cannot leak figures.
    fig, ax = plt.subplots(figsize=(7, 3.3), dpi=110)
    try:
        image = ax.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
        ax.set_title(title, fontsize=10)
        fig.colorbar(image, ax=ax, shrink=0.7)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)
    buf.seek(0)
    return np.array(Image.open(buf))


def encode_prompt_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
    """Encode the positive and negative prompt with both SDXL text encoders.

    Args:
        pipe: Pipeline exposing ``tokenizer``, ``tokenizer_2``, ``text_encoder``
            and ``text_encoder_2``.
        prompt: Positive prompt; ``None`` is treated as an empty string.
        negative: Negative prompt; ``None`` is treated as an empty string.

    Returns:
        Dict with keys ``clip_l``, ``clip_g``, ``neg_l``, ``neg_g``, ``pooled``
        and ``neg_pooled``. ``clip_l``/``neg_l`` are element 0 of the first
        encoder's output; ``clip_g``/``neg_g`` are element 1 and
        ``pooled``/``neg_pooled`` element 0 of the second encoder's output.
    """
    prompt = prompt or ""
    negative = negative or ""

    def _token_ids(tokenizer, text: str) -> torch.Tensor:
        # Fixed 77-token window, padded/truncated, moved to the module device.
        return tokenizer(
            text,
            max_length=77,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        ).input_ids.to(device)

    with torch.no_grad():
        clip_l = pipe.text_encoder(_token_ids(pipe.tokenizer, prompt))[0]
        neg_clip_l = pipe.text_encoder(_token_ids(pipe.tokenizer, negative))[0]

        # NOTE(review): with output_hidden_states=False these indices select the
        # encoder's final outputs rather than an intermediate layer — confirm
        # this matches what the shunt adapters were trained against.
        g_out = pipe.text_encoder_2(
            _token_ids(pipe.tokenizer_2, prompt), output_hidden_states=False
        )
        clip_g, pooled = g_out[1], g_out[0]
        neg_g_out = pipe.text_encoder_2(
            _token_ids(pipe.tokenizer_2, negative), output_hidden_states=False
        )
        neg_clip_g, neg_pooled = neg_g_out[1], neg_g_out[0]

    return {
        "clip_l": clip_l,
        "clip_g": clip_g,
        "neg_l": neg_clip_l,
        "neg_g": neg_clip_g,
        "pooled": pooled,
        "neg_pooled": neg_pooled,
    }
# ─── INFERENCE ───────────────────────────────────────────────────────────
def _tau_stat(outputs, adapter_type: str, label: str) -> str:
    """Return the formatted mean tau of the output for *adapter_type*, or "-"."""
    for out in outputs:
        if out.adapter_type == adapter_type:
            return f"τ̄_{label} = {out.tau.mean().item():.3f}"
    return "-"


def _gate_heat(gate: torch.Tensor, title: str):
    """Plot a gate tensor averaged over its last axis."""
    # Only the leading axis is squeezed so that a gate with a trailing
    # singleton axis is not collapsed into a single cell by the mean below.
    # assumes dim 0 is a size-1 batch axis — TODO confirm against AdapterOutput
    return plot_heat(gate.squeeze(0).mean(-1, keepdim=True), title)


def infer(prompt: str, negative_prompt: str,
          adapter_l_name: str, adapter_g_name: str,
          strength: float, delta_scale: float, sigma_scale: float,
          gpred_scale: float, noise: float, gate_prob: float,
          use_anchor: bool, steps: int, cfg_scale: float,
          scheduler_name: str, width: int, height: int, seed: int):
    """Generate one SDXL image with optional CLIP-L / CLIP-G shunt adapters.

    Args:
        prompt: Positive prompt.
        negative_prompt: Negative prompt.
        adapter_l_name: CLIP-L adapter name; empty or "None" disables it.
        adapter_g_name: CLIP-G adapter name; empty or "None" disables it.
        strength, delta_scale, sigma_scale, noise, gate_prob, use_anchor:
            Forwarded unchanged into ``ShiftConfig``.
        gpred_scale: Passed to ``ShiftConfig`` as ``guidance_scale`` and to
            ``ConditioningShifter.run_adapter``.
        steps: Number of inference steps for the pipeline.
        cfg_scale: Classifier-free guidance scale for the pipeline.
        scheduler_name: Key of ``SCHEDULERS``; unknown names keep the
            pipeline's current scheduler.
        width, height: Output image size.
        seed: RNG seed; -1 (or an empty field) means no fixed generator.

    Returns:
        Tuple ``(image, delta_l_img, gate_l_img, delta_g_img, gate_g_img,
        stats_l, stats_g)`` matching the Gradio outputs wired up in
        ``create_interface``.
    """
    torch.cuda.empty_cache()
    _init_t5()
    _init_pipe()

    # UI widgets may deliver floats (or None for an emptied number field);
    # normalise them to ints before they reach torch / the pipeline.
    seed = -1 if seed is None else int(seed)
    steps, width, height = int(steps), int(width), int(height)

    if scheduler_name in SCHEDULERS:
        _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

    generator = torch.Generator(device=device).manual_seed(seed) if seed != -1 else None

    cfg_shift = ShiftConfig(
        prompt=prompt,
        seed=seed,
        strength=strength,
        delta_scale=delta_scale,
        sigma_scale=sigma_scale,
        gate_probability=gate_prob,
        noise_injection=noise,
        use_anchor=use_anchor,
        guidance_scale=gpred_scale,
    )

    t5_seq = ConditioningShifter.extract_encoder_embeddings(
        {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
        device,
        cfg_shift,
    )
    embeds = encode_prompt_xl(_pipe, prompt, negative_prompt)

    # (selected adapter name, embedding key / adapter type, slice bounds)
    adapter_specs = (
        (adapter_l_name, "clip_l", (0, 768)),
        (adapter_g_name, "clip_g", (768, 2048)),
    )
    outputs: List["AdapterOutput"] = []
    for adapter_name, clip_type, dims in adapter_specs:
        if not adapter_name or adapter_name == "None":
            continue
        adapter = load_adapter_by_name(adapter_name, device)
        outputs.append(ConditioningShifter.run_adapter(
            adapter, t5_seq, embeds[clip_type],
            cfg_shift.guidance_scale, clip_type, dims))

    clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
    # Zero placeholders so the diagnostics render even when an adapter is off.
    delta_viz = {"clip_l": torch.zeros_like(clip_l_mod),
                 "clip_g": torch.zeros_like(clip_g_mod)}
    gate_viz = {"clip_l": torch.zeros_like(clip_l_mod[..., :1]),
                "clip_g": torch.zeros_like(clip_g_mod[..., :1])}

    for out in outputs:
        is_l = out.adapter_type == "clip_l"
        target = clip_l_mod if is_l else clip_g_mod
        mod = ConditioningShifter.apply_modifications(target, [out], cfg_shift)
        if is_l:
            clip_l_mod = mod
        else:
            clip_g_mod = mod
        delta_viz[out.adapter_type] = out.delta.detach()
        gate_viz[out.adapter_type] = out.gate.detach()

    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
    neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)

    image = _pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        pooled_prompt_embeds=embeds["pooled"],
        negative_pooled_prompt_embeds=embeds["neg_pooled"],
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    delta_l_img = plot_heat(delta_viz["clip_l"].squeeze(), "Δ CLIP-L")
    gate_l_img = _gate_heat(gate_viz["clip_l"], "Gate L")
    delta_g_img = plot_heat(delta_viz["clip_g"].squeeze(), "Δ CLIP-G")
    gate_g_img = _gate_heat(gate_viz["clip_g"], "Gate G")

    # Look stats up by adapter type rather than by list position, so the
    # CLIP-G statistic is also reported when CLIP-G is the only adapter run.
    stats_l = _tau_stat(outputs, "clip_l", "L")
    stats_g = _tau_stat(outputs, "clip_g", "G")
    return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g


# ─── GRADIO UI ───────────────────────────────────────────────────────────
def _default_choice(options):
    """Pick the first real adapter if one exists, otherwise the "None" entry."""
    return options[1] if len(options) > 1 else options[0]


def create_interface():
    """Build the Gradio Blocks UI and wire the Generate button to ``infer``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Prompts")
                prompt = gr.Textbox(label="Prompt", lines=3)
                negative = gr.Textbox(label="Negative", lines=2)

                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(clip_l_opts, value=_default_choice(clip_l_opts),
                                        label="CLIP-L Adapter")
                adapter_g = gr.Dropdown(clip_g_opts, value=_default_choice(clip_g_opts),
                                        label="CLIP-G Adapter")

                gr.Markdown("### Adapter Controls")
                # ``step`` is passed by keyword: it is keyword-only in
                # Gradio's Slider signature.
                strength = gr.Slider(minimum=0, maximum=10, value=4.0, step=0.05,
                                     label="Strength")
                delta_scale = gr.Slider(minimum=-15, maximum=15, value=0.2, step=0.1,
                                        label="Δ scale")
                sigma_scale = gr.Slider(minimum=0, maximum=15, value=0.1, step=0.1,
                                        label="σ scale")
                gpred_scale = gr.Slider(minimum=0, maximum=20, value=2.0, step=0.05,
                                        label="Guidance scale")
                noise = gr.Slider(minimum=0, maximum=1, value=0.55, step=0.01,
                                  label="Extra noise")
                gate_prob = gr.Slider(minimum=0, maximum=1, value=0.27, step=0.01,
                                      label="Gate prob")
                use_anchor = gr.Checkbox(True, label="Use anchor mix")

                gr.Markdown("### Generation")
                with gr.Row():
                    steps = gr.Slider(minimum=1, maximum=50, value=20, step=1,
                                      label="Steps")
                    cfg_scale = gr.Slider(minimum=1, maximum=15, value=7.5, step=0.1,
                                          label="CFG")
                scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M",
                                        label="Scheduler")
                with gr.Row():
                    width = gr.Slider(minimum=512, maximum=1536, value=1024, step=64,
                                      label="Width")
                    height = gr.Slider(minimum=512, maximum=1536, value=1024, step=64,
                                       label="Height")
                seed = gr.Number(-1, label="Seed (-1 → random)", precision=0)
                run_btn = gr.Button("🚀 Generate", variant="primary")

            with gr.Column(scale=1):
                out_img = gr.Image(label="Result", height=400)
                gr.Markdown("### Diagnostics")
                delta_l = gr.Image(label="Δ L", height=180)
                gate_l = gr.Image(label="Gate L", height=180)
                delta_g = gr.Image(label="Δ G", height=180)
                gate_g = gr.Image(label="Gate G", height=180)
                stats_l = gr.Textbox(label="Stats L", interactive=False)
                stats_g = gr.Textbox(label="Stats G", interactive=False)

        # Input order must match the positional parameters of infer().
        run_btn.click(
            fn=infer,
            inputs=[prompt, negative, adapter_l, adapter_g,
                    strength, delta_scale, sigma_scale, gpred_scale,
                    noise, gate_prob, use_anchor,
                    steps, cfg_scale, scheduler, width, height, seed],
            outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g],
        )
    return demo


if __name__ == "__main__":
    create_interface().launch()