# app.py ────────────────────────────────────────────────────────────────
import io, warnings, numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so heatmaps render without a display server
import matplotlib.pyplot as plt
from typing import Dict, List, Optional
from PIL import Image

import gradio as gr
import torch

from transformers import T5Tokenizer, T5EncoderModel
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from two_stream_shunt_adapter import TwoStreamShuntAdapter
from conditioning_shifter import ConditioningShifter, ShiftConfig, AdapterOutput
from configs import ShuntUtil

warnings.filterwarnings("ignore")

# ─── GLOBALS ─────────────────────────────────────────────────────────────
# fp16 requires CUDA; fall back to fp32 on CPU
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_t5_tok: Optional[T5Tokenizer] = None
_t5_mod: Optional[T5EncoderModel] = None
_pipe: Optional[StableDiffusionXLPipeline] = None

SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
}
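# Schedulers are swapped per run via `from_config`, which rebuilds the scheduler
# from the pipeline's current config so timestep settings carry over (see `infer`).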

clip_l_shunts = ShuntUtil.get_shunts_by_clip_type("clip_l")
clip_g_shunts = ShuntUtil.get_shunts_by_clip_type("clip_g")
clip_l_opts = ["None"] + [s.name for s in clip_l_shunts]
clip_g_opts = ["None"] + [s.name for s in clip_g_shunts]

# ─── INIT ───────────────────────────────────────────────────────────────
def _init_t5():
    global _t5_tok, _t5_mod
    if _t5_tok is None:
        _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
        _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

def _init_pipe():
    global _pipe
    if _pipe is None:
        _pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype, variant="fp16", use_safetensors=True
        ).to(device)
        try:
            _pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass  # xformers is optional; fall back to the default attention implementation

# ─── UTILITY ────────────────────────────────────────────────────────────
def load_adapter_by_name(name: str, device: torch.device) -> TwoStreamShuntAdapter:
    shunt = ShuntUtil.get_shunt_by_name(name)
    if shunt is None:
        raise ValueError(f"Shunt '{name}' not found.")
    path = hf_hub_download(repo_id=shunt.repo, filename=shunt.file)
    model = TwoStreamShuntAdapter(shunt.config).eval()
    model.load_state_dict(load_file(path))
    return model.to(device)
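
# hf_hub_download caches weights locally, but each call above still rebuilds the
# module and reloads the state dict. A minimal in-process cache, as a sketch
# (`_ADAPTER_CACHE` and `load_adapter_cached` are illustrative additions, not part
# of the original API):
_ADAPTER_CACHE: Dict[str, TwoStreamShuntAdapter] = {}

def load_adapter_cached(name: str, device: torch.device) -> TwoStreamShuntAdapter:
    """Load an adapter once per process and reuse it on later generations."""
    if name not in _ADAPTER_CACHE:
        _ADAPTER_CACHE[name] = load_adapter_by_name(name, device)
    return _ADAPTER_CACHE[name].to(device)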

def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
    if isinstance(mat, torch.Tensor):
        mat = mat.detach().cpu().numpy()
    if mat.ndim == 1:
        mat = mat[None, :]
    elif mat.ndim >= 3:
        mat = mat.mean(axis=0)

    plt.figure(figsize=(7, 3.3), dpi=110)
    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
    plt.title(title, fontsize=10)
    plt.colorbar(shrink=0.7)
    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close(); buf.seek(0)
    return np.array(Image.open(buf))
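
# Example (illustrative): plot_heat(torch.randn(77, 768), "Δ CLIP-L") returns an
# RGBA uint8 array of shape (H, W, 4) that gr.Image can display directly.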

def encode_prompt_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
    tok_l = pipe.tokenizer(prompt, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)
    tok_g = pipe.tokenizer_2(prompt, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)
    ntok_l = pipe.tokenizer(negative, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)
    ntok_g = pipe.tokenizer_2(negative, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        clip_l = pipe.text_encoder(tok_l)[0]
        neg_clip_l = pipe.text_encoder(ntok_l)[0]
        g_out = pipe.text_encoder_2(tok_g, output_hidden_states=False)
        clip_g, pl = g_out[1], g_out[0]
        ng_out = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
        neg_clip_g, npl = ng_out[1], ng_out[0]

    return {"clip_l": clip_l, "clip_g": clip_g, "neg_l": neg_clip_l, "neg_g": neg_clip_g, "pooled": pl, "neg_pooled": npl}

# ─── INFERENCE ───────────────────────────────────────────────────────────
def infer(prompt: str, negative_prompt: str,
          adapter_l_name: str, adapter_g_name: str,
          strength: float, delta_scale: float, sigma_scale: float,
          gpred_scale: float, noise: float, gate_prob: float, use_anchor: bool,
          steps: int, cfg_scale: float, scheduler_name: str,
          width: int, height: int, seed: int):

    torch.cuda.empty_cache()
    _init_t5(); _init_pipe()

    if scheduler_name in SCHEDULERS:
        _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

    generator = (torch.Generator(device=device).manual_seed(seed) if seed != -1 else None)

    cfg_shift = ShiftConfig(
        prompt=prompt,
        seed=seed,
        strength=strength,
        delta_scale=delta_scale,
        sigma_scale=sigma_scale,
        gate_probability=gate_prob,
        noise_injection=noise,
        use_anchor=use_anchor,
        guidance_scale=gpred_scale,
    )

    t5_seq = ConditioningShifter.extract_encoder_embeddings(
        {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
        device, cfg_shift
    )
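    # T5 encoder hidden states for the prompt: the second stream that the two-stream
    # shunt adapters attend against the CLIP embeddings.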

    embeds = encode_prompt_xl(_pipe, prompt, negative_prompt)
    outputs: List[AdapterOutput] = []

    if adapter_l_name and adapter_l_name != "None":
        ada_l = load_adapter_by_name(adapter_l_name, device)
        outputs.append(ConditioningShifter.run_adapter(
            ada_l, t5_seq, embeds["clip_l"],
            cfg_shift.guidance_scale, "clip_l", (0, 768)))

    if adapter_g_name and adapter_g_name != "None":
        ada_g = load_adapter_by_name(adapter_g_name, device)
        outputs.append(ConditioningShifter.run_adapter(
            ada_g, t5_seq, embeds["clip_g"],
            cfg_shift.guidance_scale, "clip_g", (768, 2048)))

    clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
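    # zero-filled Δ/gate maps so the diagnostic panels render even when an adapter is off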
    delta_viz = {"clip_l": torch.zeros_like(clip_l_mod), "clip_g": torch.zeros_like(clip_g_mod)}
    gate_viz = {"clip_l": torch.zeros_like(clip_l_mod[..., :1]), "clip_g": torch.zeros_like(clip_g_mod[..., :1])}

    for out in outputs:
        target = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
        mod = ConditioningShifter.apply_modifications(target, [out], cfg_shift)
        if out.adapter_type == "clip_l": clip_l_mod = mod
        else: clip_g_mod = mod
        delta_viz[out.adapter_type] = out.delta.detach()
        gate_viz[out.adapter_type] = out.gate.detach()

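    # stitch the two token streams back into SDXL's 2048-dim conditioning (768 + 1280)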
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
    neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)

    image = _pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_embeds,
        pooled_prompt_embeds=embeds["pooled"],
        negative_pooled_prompt_embeds=embeds["neg_pooled"],
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width, height=height, generator=generator
    ).images[0]

    delta_l_img = plot_heat(delta_viz["clip_l"].squeeze(), "Δ CLIP-L")
    gate_l_img  = plot_heat(gate_viz["clip_l"].squeeze().mean(-1, keepdim=True), "Gate L")
    delta_g_img = plot_heat(delta_viz["clip_g"].squeeze(), "Δ CLIP-G")
    gate_g_img  = plot_heat(gate_viz["clip_g"].squeeze().mean(-1, keepdim=True), "Gate G")

    stats_l = (f"Ο„Μ„_L = {outputs[0].tau.mean().item():.3f}" if outputs and outputs[0].adapter_type == "clip_l" else "-")
    stats_g = (f"Ο„Μ„_G = {outputs[-1].tau.mean().item():.3f}" if len(outputs) > 1 and outputs[-1].adapter_type == "clip_g" else "-")

    return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g

# ─── GRADIO UI ───────────────────────────────────────────────────────────
def create_interface():
    with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Prompts")
                prompt   = gr.Textbox(label="Prompt", lines=3)
                negative = gr.Textbox(label="Negative", lines=2)

                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(clip_l_opts, value=clip_l_opts[1] if len(clip_l_opts) > 1 else "None", label="CLIP-L Adapter")
                adapter_g = gr.Dropdown(clip_g_opts, value=clip_g_opts[1] if len(clip_g_opts) > 1 else "None", label="CLIP-G Adapter")

                gr.Markdown("### Adapter Controls")
                strength    = gr.Slider(0, 10, 4.0, 0.05, label="Strength")
                delta_scale = gr.Slider(-15, 15, 0.2, 0.1, label="Δ scale")
                sigma_scale = gr.Slider(0, 15, 0.1, 0.1, label="σ scale")
                gpred_scale = gr.Slider(0, 20, 2.0, 0.05, label="Guidance scale")
                noise       = gr.Slider(0, 1, 0.55, 0.01, label="Extra noise")
                gate_prob   = gr.Slider(0, 1, 0.27, 0.01, label="Gate prob")
                use_anchor  = gr.Checkbox(True, label="Use anchor mix")

                gr.Markdown("### Generation")
                with gr.Row():
                    steps     = gr.Slider(1, 50, 20, 1, label="Steps")
                    cfg_scale = gr.Slider(1, 15, 7.5, 0.1, label="CFG")
                scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler")
                with gr.Row():
                    width  = gr.Slider(512, 1536, 1024, 64, label="Width")
                    height = gr.Slider(512, 1536, 1024, 64, label="Height")
                seed = gr.Number(-1, label="Seed (-1 → random)", precision=0)

                run_btn = gr.Button("🚀 Generate", variant="primary")

            with gr.Column(scale=1):
                out_img  = gr.Image(label="Result", height=400)
                gr.Markdown("### Diagnostics")
                delta_l  = gr.Image(label="Δ L", height=180)
                gate_l   = gr.Image(label="Gate L", height=180)
                delta_g  = gr.Image(label="Δ G", height=180)
                gate_g   = gr.Image(label="Gate G", height=180)
                stats_l  = gr.Textbox(label="Stats L", interactive=False)
                stats_g  = gr.Textbox(label="Stats G", interactive=False)

        run_btn.click(
            fn=infer,
            inputs=[prompt, negative, adapter_l, adapter_g, strength, delta_scale,
                    sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps,
                    cfg_scale, scheduler, width, height, seed],
            outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
        )
    return demo

if __name__ == "__main__":
    create_interface().launch()