# app.py ────────────────────────────────────────────────────────────────
import io, warnings, numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so plotting works inside the server process
import matplotlib.pyplot as plt
from typing import Dict, List, Optional
from PIL import Image
import gradio as gr
import torch
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from two_stream_shunt_adapter import TwoStreamShuntAdapter
from conditioning_shifter import ConditioningShifter, ShiftConfig, AdapterOutput
from configs import ShuntUtil

warnings.filterwarnings("ignore")
# ─── GLOBALS ─────────────────────────────────────────────────────────────
dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_t5_tok: Optional[T5Tokenizer] = None
_t5_mod: Optional[T5EncoderModel] = None
_pipe: Optional[StableDiffusionXLPipeline] = None

SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}

clip_l_shunts = ShuntUtil.get_shunts_by_clip_type("clip_l")
clip_g_shunts = ShuntUtil.get_shunts_by_clip_type("clip_g")
clip_l_opts = ["None"] + [s.name for s in clip_l_shunts]
clip_g_opts = ["None"] + [s.name for s in clip_g_shunts]
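# "None" disables an adapter slot; every other option names a shunt registered
# in configs.ShuntUtil, resolved to a Hub checkpoint at generation time.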
# ─── INIT ───────────────────────────────────────────────────────────────
def _init_t5():
    """Lazily load the frozen T5 encoder used to condition the shunt adapters."""
    global _t5_tok, _t5_mod
    if _t5_tok is None:
        _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
        _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

def _init_pipe():
    """Lazily load the SDXL base pipeline in fp16."""
    global _pipe
    if _pipe is None:
        _pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype, variant="fp16", use_safetensors=True
        ).to(device)
        try:
            _pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass  # xformers is optional; fall back to default attention
# ─── UTILITY ────────────────────────────────────────────────────────────
def load_adapter_by_name(name: str, device: torch.device) -> TwoStreamShuntAdapter:
    """Resolve a shunt by name, download its checkpoint, and load it for inference."""
    shunt = ShuntUtil.get_shunt_by_name(name)
    if shunt is None:
        raise ValueError(f"Shunt '{name}' not found.")
    path = hf_hub_download(repo_id=shunt.repo, filename=shunt.file)
    model = TwoStreamShuntAdapter(shunt.config).eval()
    model.load_state_dict(load_file(path))
    return model.to(device)
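# hf_hub_download caches checkpoint files locally, so repeat generations only
# pay the state-dict load. Usage sketch (shunt name is hypothetical):
#   adapter = load_adapter_by_name("t5-clip_l-shunt-v1", device)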
def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
    """Render a 2-D heatmap of a tensor and return it as an RGB array for Gradio."""
    if isinstance(mat, torch.Tensor):
        mat = mat.detach().float().cpu().numpy()  # upcast so fp16 survives the NumPy round-trip
    if mat.ndim == 1:
        mat = mat[None, :]
    while mat.ndim > 2:  # average leading (batch/head) dims down to 2-D
        mat = mat.mean(axis=0)
    plt.figure(figsize=(7, 3.3), dpi=110)
    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
    plt.title(title, fontsize=10)
    plt.colorbar(shrink=0.7)
    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close()
    buf.seek(0)
    return np.array(Image.open(buf))
def encode_prompt_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
    """Encode positive/negative prompts with both SDXL text encoders."""
    def _ids(tokenizer, text):
        return tokenizer(text, max_length=77, truncation=True,
                         padding="max_length", return_tensors="pt").input_ids.to(device)

    tok_l, tok_g = _ids(pipe.tokenizer, prompt), _ids(pipe.tokenizer_2, prompt)
    ntok_l, ntok_g = _ids(pipe.tokenizer, negative), _ids(pipe.tokenizer_2, negative)
    with torch.no_grad():
        clip_l = pipe.text_encoder(tok_l)[0]          # CLIP-L last hidden state (768-d)
        neg_clip_l = pipe.text_encoder(ntok_l)[0]
        # text_encoder_2 returns (text_embeds, last_hidden_state, ...):
        # index 0 is the pooled projection, index 1 the per-token states (1280-d)
        g_out = pipe.text_encoder_2(tok_g, output_hidden_states=False)
        clip_g, pl = g_out[1], g_out[0]
        ng_out = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
        neg_clip_g, npl = ng_out[1], ng_out[0]
    return {"clip_l": clip_l, "clip_g": clip_g,
            "neg_l": neg_clip_l, "neg_g": neg_clip_g,
            "pooled": pl, "neg_pooled": npl}
# ─── INFERENCE ───────────────────────────────────────────────────────────
def infer(prompt: str, negative_prompt: str,
          adapter_l_name: str, adapter_g_name: str,
          strength: float, delta_scale: float, sigma_scale: float,
          gpred_scale: float, noise: float, gate_prob: float, use_anchor: bool,
          steps: int, cfg_scale: float, scheduler_name: str,
          width: int, height: int, seed: int):
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    _init_t5()
    _init_pipe()

    if scheduler_name in SCHEDULERS:
        _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed) if seed != -1 else None

    cfg_shift = ShiftConfig(
        prompt=prompt,
        seed=seed,
        strength=strength,
        delta_scale=delta_scale,
        sigma_scale=sigma_scale,
        gate_probability=gate_prob,
        noise_injection=noise,
        use_anchor=use_anchor,
        guidance_scale=gpred_scale,
    )
    t5_seq = ConditioningShifter.extract_encoder_embeddings(
        {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},  # minimal encoder-config stub
        device, cfg_shift
    )
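    # t5_seq holds the per-token T5 embeddings of the prompt; each adapter pairs
    # them with one CLIP stream to predict a delta and gate for that stream.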
    embeds = encode_prompt_xl(_pipe, prompt, negative_prompt)

    outputs: List[AdapterOutput] = []
    if adapter_l_name and adapter_l_name != "None":
        ada_l = load_adapter_by_name(adapter_l_name, device)
        outputs.append(ConditioningShifter.run_adapter(
            ada_l, t5_seq, embeds["clip_l"],
            cfg_shift.guidance_scale, "clip_l", (0, 768)))
    if adapter_g_name and adapter_g_name != "None":
        ada_g = load_adapter_by_name(adapter_g_name, device)
        outputs.append(ConditioningShifter.run_adapter(
            ada_g, t5_seq, embeds["clip_g"],
            cfg_shift.guidance_scale, "clip_g", (768, 2048)))
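    # The (0, 768) / (768, 2048) ranges are the channel slices each adapter owns
    # in the fused SDXL context: CLIP-L supplies the first 768 dims, CLIP-G the
    # remaining 1280 of the 2048-wide cross-attention embedding.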
    clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
    delta_viz = {"clip_l": torch.zeros_like(clip_l_mod), "clip_g": torch.zeros_like(clip_g_mod)}
    gate_viz = {"clip_l": torch.zeros_like(clip_l_mod[..., :1]), "clip_g": torch.zeros_like(clip_g_mod[..., :1])}

    for out in outputs:
        target = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
        mod = ConditioningShifter.apply_modifications(target, [out], cfg_shift)
        if out.adapter_type == "clip_l":
            clip_l_mod = mod
        else:
            clip_g_mod = mod
        delta_viz[out.adapter_type] = out.delta.detach()
        gate_viz[out.adapter_type] = out.gate.detach()

    # Fuse the (possibly shifted) streams back into SDXL's conditioning tensor,
    # casting to the pipeline dtype in case an adapter ran in fp32.
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1).to(dtype)
    neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1).to(dtype)
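    # Only the positive conditioning is shunted; the negative embeddings pass
    # through from the frozen CLIP encoders untouched.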
    image = _pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        pooled_prompt_embeds=embeds["pooled"],
        negative_pooled_prompt_embeds=embeds["neg_pooled"],
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width, height=height, generator=generator
    ).images[0]

    delta_l_img = plot_heat(delta_viz["clip_l"].squeeze(), "Δ CLIP-L")
    gate_l_img = plot_heat(gate_viz["clip_l"].squeeze().mean(-1, keepdim=True), "Gate L")
    delta_g_img = plot_heat(delta_viz["clip_g"].squeeze(), "Δ CLIP-G")
    gate_g_img = plot_heat(gate_viz["clip_g"].squeeze().mean(-1, keepdim=True), "Gate G")

    # Report mean τ per stream, keyed by adapter type rather than by list
    # position so a lone CLIP-G adapter is still reported.
    tau = {o.adapter_type: o.tau.mean().item() for o in outputs}
    stats_l = f"τ̄_L = {tau['clip_l']:.3f}" if "clip_l" in tau else "-"
    stats_g = f"τ̄_G = {tau['clip_g']:.3f}" if "clip_g" in tau else "-"
    return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
# ─── GRADIO UI ───────────────────────────────────────────────────────────
def create_interface():
    with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Prompts")
                prompt = gr.Textbox(label="Prompt", lines=3)
                negative = gr.Textbox(label="Negative", lines=2)

                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(clip_l_opts,
                                        value=clip_l_opts[1] if len(clip_l_opts) > 1 else "None",
                                        label="CLIP-L Adapter")
                adapter_g = gr.Dropdown(clip_g_opts,
                                        value=clip_g_opts[1] if len(clip_g_opts) > 1 else "None",
                                        label="CLIP-G Adapter")

                gr.Markdown("### Adapter Controls")
                strength = gr.Slider(0, 10, value=4.0, step=0.05, label="Strength")
                delta_scale = gr.Slider(-15, 15, value=0.2, step=0.1, label="Δ scale")
                sigma_scale = gr.Slider(0, 15, value=0.1, step=0.1, label="σ scale")
                gpred_scale = gr.Slider(0, 20, value=2.0, step=0.05, label="Guidance scale")
                noise = gr.Slider(0, 1, value=0.55, step=0.01, label="Extra noise")
                gate_prob = gr.Slider(0, 1, value=0.27, step=0.01, label="Gate prob")
                use_anchor = gr.Checkbox(True, label="Use anchor mix")

                gr.Markdown("### Generation")
                with gr.Row():
                    steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
                    cfg_scale = gr.Slider(1, 15, value=7.5, step=0.1, label="CFG")
                scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler")
                with gr.Row():
                    width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
                seed = gr.Number(-1, label="Seed (-1 → random)", precision=0)
                run_btn = gr.Button("Generate", variant="primary")

            with gr.Column(scale=1):
                out_img = gr.Image(label="Result", height=400)
                gr.Markdown("### Diagnostics")
                delta_l = gr.Image(label="Δ L", height=180)
                gate_l = gr.Image(label="Gate L", height=180)
                delta_g = gr.Image(label="Δ G", height=180)
                gate_g = gr.Image(label="Gate G", height=180)
                stats_l = gr.Textbox(label="Stats L", interactive=False)
                stats_g = gr.Textbox(label="Stats G", interactive=False)

        run_btn.click(
            fn=infer,
            inputs=[prompt, negative, adapter_l, adapter_g, strength, delta_scale,
                    sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps,
                    cfg_scale, scheduler, width, height, seed],
            outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
        )
    return demo

if __name__ == "__main__":
    create_interface().launch()