# app.py ────────────────────────────────────────────────────────────────
import io, warnings, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch, torch.nn.functional as F
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import (
    StableDiffusionXLPipeline,
    DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# local modules
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from conditioning_shifter import ConditioningShifter, ShiftConfig, AdapterOutput
from embedding_manager import get_bank
from configs import T5_SHUNT_REPOS

warnings.filterwarnings("ignore")
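# Pipeline overview: a frozen T5 encoder reads the prompt, the two-stream
# shunt adapters translate that T5 sequence into per-token deltas/gates for
# SDXL's CLIP-L and CLIP-G streams, and the shifted embeddings are passed
# straight into the SDXL pipeline as prompt_embeds.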

# ─── GLOBALS ────────────────────────────────────────────────────────────
dtype  = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_bank = get_bank()                      # singleton (optional caching)
_t5_tok: Optional[T5Tokenizer]    = None
_t5_mod: Optional[T5EncoderModel] = None
_pipe:   Optional[StableDiffusionXLPipeline] = None

SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM":     DDIMScheduler,
    "Euler":    EulerDiscreteScheduler,
}
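# The selected scheduler class is re-instantiated per request in infer()
# via SCHEDULERS[name].from_config(_pipe.scheduler.config), so switching
# schedulers never requires reloading the pipeline.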

# adapter metadata from configs.py
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l, conf_l = T5_SHUNT_REPOS["clip_l"]["repo"], T5_SHUNT_REPOS["clip_l"]["config"]
repo_g, conf_g = T5_SHUNT_REPOS["clip_g"]["repo"], T5_SHUNT_REPOS["clip_g"]["config"]

# ─── INITIALISERS ───────────────────────────────────────────────────────
def _init_t5():
    global _t5_tok, _t5_mod
    if _t5_tok is None:
        _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
        _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base") \
                                .to(device).eval()

def _init_pipe():
    global _pipe
    if _pipe is None:
        _pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype, variant="fp16", use_safetensors=True
        ).to(device)
        try:
            _pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass  # xformers is optional; fall back to default attention
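# Both initialisers are lazy singletons: the first request pays the model
# download/load cost once, and every later request reuses the cached objects.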

# ─── HELPERS ────────────────────────────────────────────────────────────
def load_adapter(repo: str, filename: str, cfg: dict,
                 device: torch.device) -> TwoStreamShuntAdapter:
    path  = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(cfg).eval()
    model.load_state_dict(load_file(path))
    return model.to(device)
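# hf_hub_download caches checkpoints locally (under the Hugging Face cache
# directory), so a given adapter file is fetched from the Hub only once;
# load_file reads the safetensors weights onto CPU before .to(device).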

def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
    if isinstance(mat, torch.Tensor):
        mat = mat.detach().cpu().numpy()
    if mat.ndim == 1:
        mat = mat[None, :]
    elif mat.ndim >= 3:
        mat = mat.mean(axis=0)
    plt.figure(figsize=(7, 3.3), dpi=110)
    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
    plt.title(title, fontsize=10)
    plt.colorbar(shrink=0.7)
    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close(); buf.seek(0)
    return np.array(Image.open(buf))
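# Renders any 1-D/2-D/3-D tensor as a diverging red-blue heatmap (leading
# batch dims are averaged away) and round-trips it through an in-memory PNG
# so gr.Image can display the result as a plain numpy array.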

def encode_prompt_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
    tok_l  = pipe.tokenizer  (prompt,   max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)
    tok_g  = pipe.tokenizer_2(prompt,   max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)
    ntok_l = pipe.tokenizer  (negative, max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)
    ntok_g = pipe.tokenizer_2(negative, max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        clip_l     = pipe.text_encoder(tok_l)[0]
        neg_clip_l = pipe.text_encoder(ntok_l)[0]
        g_out      = pipe.text_encoder_2(tok_g, output_hidden_states=False)
        clip_g, pl = g_out[1], g_out[0]
        ng_out     = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
        neg_clip_g, npl = ng_out[1], ng_out[0]

    return {"clip_l": clip_l, "clip_g": clip_g,
            "neg_l": neg_clip_l, "neg_g": neg_clip_g,
            "pooled": pl, "neg_pooled": npl}

# ─── INFERENCE ──────────────────────────────────────────────────────────
def infer(prompt: str, negative_prompt: str,
          adapter_l_file: str, adapter_g_file: str,
          strength: float, delta_scale: float, sigma_scale: float,
          gpred_scale: float, noise: float, gate_prob: float, use_anchor: bool,
          steps: int, cfg_scale: float, scheduler_name: str,
          width: int, height: int, seed: int):

    torch.cuda.empty_cache()
    _init_t5(); _init_pipe()

    if scheduler_name in SCHEDULERS:
        _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

    generator = (torch.Generator(device=device).manual_seed(seed)
                 if seed != -1 else None)

    # build ShiftConfig (one per request)
    cfg_shift = ShiftConfig(
        prompt           = prompt,
        seed             = seed,
        strength         = strength,
        delta_scale      = delta_scale,
        sigma_scale      = sigma_scale,
        gate_probability = gate_prob,
        noise_injection  = noise,
        use_anchor       = use_anchor,
        guidance_scale   = gpred_scale,
    )

    # encoder (T5) embeddings
    t5_seq = ConditioningShifter.extract_encoder_embeddings(
        {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
        device, cfg_shift
    )
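    # extract_encoder_embeddings runs the prompt carried in cfg_shift through
    # the frozen T5 encoder and returns its hidden-state sequence; this one
    # T5 sequence feeds both adapters below.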

    # CLIP embeddings
    embeds = encode_prompt_xl(_pipe, prompt, negative_prompt)

    # run adapters ----------------------------------------------------------
    outputs: List[AdapterOutput] = []
    if adapter_l_file and adapter_l_file != "None":
        ada_l = load_adapter(repo_l, adapter_l_file, conf_l, device)
        outputs.append(ConditioningShifter.run_adapter(
            ada_l, t5_seq, embeds["clip_l"],
            cfg_shift.guidance_scale, "clip_l", (0, 768)))

    if adapter_g_file and adapter_g_file != "None":
        ada_g = load_adapter(repo_g, adapter_g_file, conf_g, device)
        outputs.append(ConditioningShifter.run_adapter(
            ada_g, t5_seq, embeds["clip_g"],
            cfg_shift.guidance_scale, "clip_g", (768, 2048)))

    # apply modifications -----------------------------------------------------
    clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
    delta_viz = {"clip_l": torch.zeros_like(clip_l_mod),
                 "clip_g": torch.zeros_like(clip_g_mod)}
    gate_viz  = {"clip_l": torch.zeros_like(clip_l_mod[..., :1]),
                 "clip_g": torch.zeros_like(clip_g_mod[..., :1])}

    for out in outputs:
        target = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
        mod = ConditioningShifter.apply_modifications(target, [out], cfg_shift)
        if out.adapter_type == "clip_l":
            clip_l_mod = mod
        else:
            clip_g_mod = mod
        delta_viz[out.adapter_type] = out.delta.detach()
        gate_viz [out.adapter_type] = out.gate.detach()
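    # apply_modifications blends each adapter's delta into its CLIP stream
    # under the ShiftConfig scales; the raw delta and gate tensors are kept
    # (detached) purely for the diagnostic heatmaps below.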

    # prepare for SDXL --------------------------------------------------------
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
    neg_embeds    = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)

    image = _pipe(
        prompt_embeds                 = prompt_embeds,
        negative_prompt_embeds        = neg_embeds,
        pooled_prompt_embeds          = embeds["pooled"],
        negative_pooled_prompt_embeds = embeds["neg_pooled"],
        num_inference_steps           = steps,
        guidance_scale                = cfg_scale,
        width=width, height=height, generator=generator,
    ).images[0]
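    # Note that only the positive prompt is shunted: the negative embeddings
    # and both pooled vectors pass through exactly as CLIP produced them.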

    # diagnostics -------------------------------------------------------------
    delta_l_img = plot_heat(delta_viz["clip_l"].squeeze(), "Δ CLIP-L")
    gate_l_img  = plot_heat(gate_viz ["clip_l"].squeeze().mean(-1, keepdim=True), "Gate L")
    delta_g_img = plot_heat(delta_viz["clip_g"].squeeze(), "Δ CLIP-G")
    gate_g_img  = plot_heat(gate_viz ["clip_g"].squeeze().mean(-1, keepdim=True), "Gate G")

    stats_l = next((f"τ̄_L = {o.tau.mean().item():.3f}" for o in outputs
                    if o.adapter_type == "clip_l"), "-")
    stats_g = next((f"τ̄_G = {o.tau.mean().item():.3f}" for o in outputs
                    if o.adapter_type == "clip_g"), "-")

    return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
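# Example (sketch, bypassing the UI): the adapter filename below is just the
# default offered in the CLIP-L dropdown; any entry from clip_l_opts /
# clip_g_opts works, and None skips that stream entirely.
#
#   image, *_ = infer(
#       "a futuristic control station", "blurry, low quality",
#       "t5-vit-l-14-dual_shunt_caption.safetensors", None,
#       strength=4.0, delta_scale=0.2, sigma_scale=0.1, gpred_scale=2.0,
#       noise=0.55, gate_prob=0.27, use_anchor=True,
#       steps=20, cfg_scale=7.5, scheduler_name="DPM++ 2M",
#       width=1024, height=1024, seed=42)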

# ─── GRADIO UI ──────────────────────────────────────────────────────────
def create_interface():
    with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Prompts")
                prompt   = gr.Textbox(label="Prompt", lines=3,
                                      value="a futuristic control station with holographic displays")
                negative = gr.Textbox(label="Negative", lines=2,
                                      value="blurry, low quality, distorted")

                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(["None"] + clip_l_opts,
                                        value="t5-vit-l-14-dual_shunt_caption.safetensors",
                                        label="CLIP-L Adapter")
                adapter_g = gr.Dropdown(["None"] + clip_g_opts,
                                        value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
                                        label="CLIP-G Adapter")

                gr.Markdown("### Adapter Controls")
                strength    = gr.Slider(0, 10, value=4.0, step=0.05, label="Strength")
                delta_scale = gr.Slider(-15, 15, value=0.2, step=0.1, label="Δ scale")
                sigma_scale = gr.Slider(0, 15, value=0.1, step=0.1, label="σ scale")
                gpred_scale = gr.Slider(0, 20, value=2.0, step=0.05, label="Guidance scale")
                noise       = gr.Slider(0, 1, value=0.55, step=0.01, label="Extra noise")
                gate_prob   = gr.Slider(0, 1, value=0.27, step=0.01, label="Gate prob")
                use_anchor  = gr.Checkbox(True, label="Use anchor mix")

                gr.Markdown("### Generation")
                with gr.Row():
                    steps     = gr.Slider(1, 50, value=20, step=1, label="Steps")
                    cfg_scale = gr.Slider(1, 15, value=7.5, step=0.1, label="CFG")
                scheduler = gr.Dropdown(list(SCHEDULERS.keys()),
                                        value="DPM++ 2M", label="Scheduler")
                with gr.Row():
                    width  = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
                seed = gr.Number(-1, label="Seed (-1 → random)", precision=0)

                run_btn = gr.Button("🚀 Generate", variant="primary")

            with gr.Column(scale=1):
                out_img = gr.Image(label="Result", height=400)
                gr.Markdown("### Diagnostics")
                delta_l = gr.Image(label="Δ L", height=180)
                gate_l  = gr.Image(label="Gate L", height=180)
                delta_g = gr.Image(label="Δ G", height=180)
                gate_g  = gr.Image(label="Gate G", height=180)
                stats_l = gr.Textbox(label="Stats L", interactive=False)
                stats_g = gr.Textbox(label="Stats G", interactive=False)
        def _run(*args):
            pl, npl = args[0], args[1]
            al, ag  = (None if v == "None" else v for v in args[2:4])
            return infer(pl, npl, al, ag, *args[4:])

        run_btn.click(
            fn=_run,
            inputs=[prompt, negative, adapter_l, adapter_g, strength, delta_scale,
                    sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps,
                    cfg_scale, scheduler, width, height, seed],
            outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
        )
    return demo


if __name__ == "__main__":
    create_interface().launch()