# Author: AbstractPhil
# Commit d3479d5 (11.5 kB): local project created to properly edit and debug.
# app.py ────────────────────────────────────────────────────────────────
import io, warnings, numpy as np, matplotlib.pyplot as plt
from typing import Dict, List, Optional
from PIL import Image
from pathlib import Path
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from conditioning_shifter import ConditioningShifter, ShiftConfig, AdapterOutput
from configs import ShuntUtil
# Silences every Python warning process-wide (including library deprecations).
warnings.filterwarnings("ignore")
# ─── GLOBALS ─────────────────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision only on GPU: SDXL in float16 on CPU is unsupported by many
# kernels (or extremely slow), so fall back to float32 there. The fp16 weight
# variant can still be downloaded; diffusers upcasts it to `torch_dtype`.
dtype = torch.float16 if device.type == "cuda" else torch.float32
# Lazily-initialised singletons, filled in by _init_t5() / _init_pipe().
_t5_tok: Optional[T5Tokenizer] = None
_t5_mod: Optional[T5EncoderModel] = None
_pipe: Optional[StableDiffusionXLPipeline] = None
# UI label -> diffusers scheduler class; infer() rebuilds the chosen one from
# the pipeline's current scheduler config.
SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}
# Registered adapters ("shunts") split by the CLIP encoder they target.
# "None" is prepended so the UI dropdowns can disable an adapter.
clip_l_shunts = ShuntUtil.get_shunts_by_clip_type("clip_l")
clip_g_shunts = ShuntUtil.get_shunts_by_clip_type("clip_g")
clip_l_opts = ["None"] + [s.name for s in clip_l_shunts]
clip_g_opts = ["None"] + [s.name for s in clip_g_shunts]
# ─── INIT ───────────────────────────────────────────────────────────────
def _init_t5():
    """Lazily load the Flan-T5-base tokenizer and encoder into module globals.

    The encoder is moved to the module ``device`` and put in eval mode.
    Each half is loaded only if it is still missing, so repeated calls are
    cheap and a partially failed earlier call is completed on the next one.
    """
    global _t5_tok, _t5_mod
    # Check both globals independently: guarding only on the tokenizer would
    # leave `_t5_mod` as None forever if the encoder load had failed once.
    if _t5_tok is None:
        _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
    if _t5_mod is None:
        _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
def _init_pipe():
    """Lazily load the SDXL base pipeline into the module-level ``_pipe``.

    The pipeline is loaded once (fp16 weight variant, safetensors, module
    ``dtype``) and moved to ``device``; later calls return immediately.
    xformers attention is enabled when available, otherwise the pipeline's
    default attention implementation is kept.
    """
    global _pipe
    if _pipe is not None:
        return
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=dtype, variant="fp16", use_safetensors=True,
    ).to(device)
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError, RuntimeError) as err:
        # xformers is optional: diffusers raises ModuleNotFoundError when it
        # is not installed and ValueError when CUDA is unavailable. Fall back
        # instead of failing the request.
        print(f"[init] xformers attention unavailable ({err}); using default attention.")
    # Publish only once setup has finished, so a failed load is retried on
    # the next call instead of leaving a half-initialised pipeline behind.
    _pipe = pipe
# ─── UTILITY ────────────────────────────────────────────────────────────
def load_adapter_by_name(name: str, device: torch.device) -> TwoStreamShuntAdapter:
    """Download, build and load the shunt adapter registered as *name*.

    Args:
        name: Adapter name as registered in ``ShuntUtil``.
        device: Device the loaded adapter is moved to.

    Returns:
        The adapter in eval mode, with weights read from the safetensors file
        fetched via ``hf_hub_download(shunt.repo, shunt.file)``.

    Raises:
        ValueError: if no shunt with that name is registered.
    """
    shunt = ShuntUtil.get_shunt_by_name(name)
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would turn a bad name into AttributeError on `.repo`.
    if not shunt:
        raise ValueError(f"Shunt '{name}' not found.")
    path = hf_hub_download(repo_id=shunt.repo, filename=shunt.file)
    model = TwoStreamShuntAdapter(shunt.config).eval()
    model.load_state_dict(load_file(path))
    return model.to(device)
def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
    """Render *mat* as a diverging heat map and return the image as an array.

    Args:
        mat: Tensor or array to plot. 0-D and 1-D inputs are drawn as a single
            row; inputs with more than two axes are averaged over every leading
            axis so that only the last two axes form the image.
        title: Plot title.

    Returns:
        The rendered PNG decoded by PIL into a NumPy array.
    """
    # Build the figure without pyplot: pyplot keeps a global "current figure"
    # that concurrent Gradio handlers would otherwise share.
    from matplotlib.figure import Figure

    if isinstance(mat, torch.Tensor):
        # Upcast first: bfloat16 tensors cannot be converted with .numpy().
        mat = mat.detach().float().cpu().numpy()
    mat = np.asarray(mat)
    if mat.ndim < 2:
        mat = mat.reshape(1, -1)
    elif mat.ndim > 2:
        # Collapse all leading axes (not just axis 0) so imshow always gets
        # a 2-D array instead of misreading a 3-D one as RGB(A).
        mat = mat.reshape(-1, *mat.shape[-2:]).mean(axis=0)

    fig = Figure(figsize=(7, 3.3), dpi=110)
    ax = fig.subplots()
    image = ax.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
    ax.set_title(title, fontsize=10)
    fig.colorbar(image, ax=ax, shrink=0.7)
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    return np.array(Image.open(buf))
def encode_prompt_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
    """Encode the prompt and the negative prompt with both SDXL text encoders.

    Each string is tokenized by ``pipe.tokenizer`` (CLIP-L) and
    ``pipe.tokenizer_2`` (CLIP-G), padded/truncated to 77 tokens, moved to
    the module ``device`` and encoded without gradient tracking.

    Returns:
        A dict with
        ``clip_l`` / ``neg_l``: element 0 of ``pipe.text_encoder``'s output,
        ``clip_g`` / ``neg_g``: element 1 of ``pipe.text_encoder_2``'s output,
        ``pooled`` / ``neg_pooled``: element 0 of ``pipe.text_encoder_2``'s output.

    NOTE(review): with transformers' CLIP output classes these look like the
    final-layer hidden states, while diffusers' own SDXL ``encode_prompt``
    conditions on the penultimate hidden state — confirm which one the shunt
    adapters were trained against.
    """
    def to_ids(tokenizer, text):
        encoded = tokenizer(
            text,
            max_length=77,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        return encoded.input_ids.to(device)

    ids = {
        "l": to_ids(pipe.tokenizer, prompt),
        "g": to_ids(pipe.tokenizer_2, prompt),
        "neg_l": to_ids(pipe.tokenizer, negative),
        "neg_g": to_ids(pipe.tokenizer_2, negative),
    }

    with torch.no_grad():
        hidden_l = pipe.text_encoder(ids["l"])[0]
        neg_hidden_l = pipe.text_encoder(ids["neg_l"])[0]
        pos_out_g = pipe.text_encoder_2(ids["g"], output_hidden_states=False)
        neg_out_g = pipe.text_encoder_2(ids["neg_g"], output_hidden_states=False)

    return {
        "clip_l": hidden_l,
        "clip_g": pos_out_g[1],
        "neg_l": neg_hidden_l,
        "neg_g": neg_out_g[1],
        "pooled": pos_out_g[0],
        "neg_pooled": neg_out_g[0],
    }
# ─── INFERENCE ───────────────────────────────────────────────────────────
def _gate_profile(gate: torch.Tensor) -> torch.Tensor:
    """Reduce a gate tensor to a (positions, 1) column for :func:`plot_heat`.

    Leading axes are dropped by taking index 0 until at most two remain, and
    a remaining trailing axis is averaged. The result is a column, so
    positions run along the vertical axis as in the Δ heat maps.
    NOTE(review): assumes gates are laid out like the ``clip[..., :1]``
    placeholders built in ``infer`` (batch, tokens, channels) — confirm
    against ``AdapterOutput.gate``.
    """
    g = gate.detach().float()
    while g.dim() > 2:
        g = g[0]
    if g.dim() == 2:
        g = g.mean(-1)
    return g.reshape(-1, 1)


def _tau_stats(outputs: "List[AdapterOutput]", adapter_type: str, label: str) -> str:
    """Format the mean ``tau`` of the *adapter_type* output, or "-" if none ran."""
    for out in outputs:
        if out.adapter_type == adapter_type:
            return f"τ̄_{label} = {out.tau.mean().item():.3f}"
    return "-"


def infer(prompt: str, negative_prompt: str,
          adapter_l_name: str, adapter_g_name: str,
          strength: float, delta_scale: float, sigma_scale: float,
          gpred_scale: float, noise: float, gate_prob: float, use_anchor: bool,
          steps: int, cfg_scale: float, scheduler_name: str,
          width: int, height: int, seed: int):
    """Generate one SDXL image, optionally shifting CLIP-L / CLIP-G conditioning.

    The T5 sequence and the adapter controls are fed through
    ``ConditioningShifter`` for each selected adapter ("None" disables one).
    The modified CLIP embeddings are then concatenated and passed to the
    SDXL pipeline.

    Returns:
        ``(image, delta_l_img, gate_l_img, delta_g_img, gate_g_img,
        stats_l, stats_g)``. The ``*_img`` entries are arrays rendered by
        :func:`plot_heat`; the stats are display strings, or "-" when that
        adapter did not run.
    """
    torch.cuda.empty_cache()
    _init_t5()
    _init_pipe()

    if scheduler_name in SCHEDULERS:
        _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

    # A cleared gr.Number field arrives as None: treat it like -1 (random)
    # instead of crashing inside Generator.manual_seed.
    seed = -1 if seed is None else int(seed)
    generator = torch.Generator(device=device).manual_seed(seed) if seed != -1 else None

    cfg_shift = ShiftConfig(
        prompt=prompt,
        seed=seed,
        strength=strength,
        delta_scale=delta_scale,
        sigma_scale=sigma_scale,
        gate_probability=gate_prob,
        noise_injection=noise,
        use_anchor=use_anchor,
        guidance_scale=gpred_scale,
    )

    # Pure inference: no autograd graph is needed for the adapter passes.
    with torch.no_grad():
        t5_seq = ConditioningShifter.extract_encoder_embeddings(
            {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
            device, cfg_shift,
        )
        embeds = encode_prompt_xl(_pipe, prompt, negative_prompt)

        # NOTE(review): (0, 768) / (768, 2048) look like each encoder's slice
        # of the concatenated prompt embedding — confirm in run_adapter.
        outputs: List[AdapterOutput] = []
        if adapter_l_name and adapter_l_name != "None":
            ada_l = load_adapter_by_name(adapter_l_name, device)
            outputs.append(ConditioningShifter.run_adapter(
                ada_l, t5_seq, embeds["clip_l"],
                cfg_shift.guidance_scale, "clip_l", (0, 768)))
        if adapter_g_name and adapter_g_name != "None":
            ada_g = load_adapter_by_name(adapter_g_name, device)
            outputs.append(ConditioningShifter.run_adapter(
                ada_g, t5_seq, embeds["clip_g"],
                cfg_shift.guidance_scale, "clip_g", (768, 2048)))

        clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
        # Zero placeholders keep the diagnostics renderable when an adapter is off.
        delta_viz = {"clip_l": torch.zeros_like(clip_l_mod), "clip_g": torch.zeros_like(clip_g_mod)}
        gate_viz = {"clip_l": torch.zeros_like(clip_l_mod[..., :1]), "clip_g": torch.zeros_like(clip_g_mod[..., :1])}
        for out in outputs:
            target = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
            mod = ConditioningShifter.apply_modifications(target, [out], cfg_shift)
            if out.adapter_type == "clip_l":
                clip_l_mod = mod
            else:
                clip_g_mod = mod
            delta_viz[out.adapter_type] = out.delta.detach()
            gate_viz[out.adapter_type] = out.gate.detach()

    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
    neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)
    image = _pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        pooled_prompt_embeds=embeds["pooled"],
        negative_pooled_prompt_embeds=embeds["neg_pooled"],
        num_inference_steps=int(steps),
        guidance_scale=cfg_scale,
        width=int(width), height=int(height), generator=generator,
    ).images[0]

    delta_l_img = plot_heat(delta_viz["clip_l"].squeeze(), "Δ CLIP-L")
    gate_l_img = plot_heat(_gate_profile(gate_viz["clip_l"]), "Gate L")
    delta_g_img = plot_heat(delta_viz["clip_g"].squeeze(), "Δ CLIP-G")
    gate_g_img = plot_heat(_gate_profile(gate_viz["clip_g"]), "Gate G")

    # Look outputs up by adapter type: the old positional checks hid the G
    # stats whenever only the CLIP-G adapter was selected.
    stats_l = _tau_stats(outputs, "clip_l", "L")
    stats_g = _tau_stats(outputs, "clip_g", "G")
    return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
# ─── GRADIO UI ───────────────────────────────────────────────────────────
def create_interface():
    """Build the Gradio Blocks UI whose Generate button calls :func:`infer`.

    Returns:
        The unlaunched ``gr.Blocks`` demo.
    """
    # Pre-select the first registered adapter when there is one; fall back to
    # "None" so an empty shunt registry does not raise IndexError at startup.
    default_l = clip_l_opts[1] if len(clip_l_opts) > 1 else "None"
    default_g = clip_g_opts[1] if len(clip_g_opts) > 1 else "None"

    with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Prompts")
                prompt = gr.Textbox(label="Prompt", lines=3)
                negative = gr.Textbox(label="Negative", lines=2)

                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(clip_l_opts, value=default_l, label="CLIP-L Adapter")
                adapter_g = gr.Dropdown(clip_g_opts, value=default_g, label="CLIP-G Adapter")

                gr.Markdown("### Adapter Controls")
                # `step` must be passed by keyword: Slider parameters after
                # `value` are keyword-only in current Gradio releases.
                strength = gr.Slider(0, 10, 4.0, step=0.05, label="Strength")
                delta_scale = gr.Slider(-15, 15, 0.2, step=0.1, label="Δ scale")
                sigma_scale = gr.Slider(0, 15, 0.1, step=0.1, label="σ scale")
                gpred_scale = gr.Slider(0, 20, 2.0, step=0.05, label="Guidance scale")
                noise = gr.Slider(0, 1, 0.55, step=0.01, label="Extra noise")
                gate_prob = gr.Slider(0, 1, 0.27, step=0.01, label="Gate prob")
                use_anchor = gr.Checkbox(True, label="Use anchor mix")

                gr.Markdown("### Generation")
                with gr.Row():
                    steps = gr.Slider(1, 50, 20, step=1, label="Steps")
                    cfg_scale = gr.Slider(1, 15, 7.5, step=0.1, label="CFG")
                scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler")
                with gr.Row():
                    width = gr.Slider(512, 1536, 1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, 1024, step=64, label="Height")
                seed = gr.Number(-1, label="Seed (-1 → random)", precision=0)
                run_btn = gr.Button("🚀 Generate", variant="primary")

            with gr.Column(scale=1):
                out_img = gr.Image(label="Result", height=400)
                gr.Markdown("### Diagnostics")
                delta_l = gr.Image(label="Δ L", height=180)
                gate_l = gr.Image(label="Gate L", height=180)
                delta_g = gr.Image(label="Δ G", height=180)
                gate_g = gr.Image(label="Gate G", height=180)
                stats_l = gr.Textbox(label="Stats L", interactive=False)
                stats_g = gr.Textbox(label="Stats G", interactive=False)

        # Input order must match infer()'s positional parameters exactly.
        run_btn.click(
            fn=infer,
            inputs=[prompt, negative, adapter_l, adapter_g, strength, delta_scale,
                    sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps,
                    cfg_scale, scheduler, width, height, seed],
            outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    demo = create_interface()
    demo.launch()