AbstractPhil committed
Commit 19e2e87 · verified · 1 Parent(s): d138161

Update app.py

Files changed (1)
  1. app.py +142 -186

app.py CHANGED
@@ -1,69 +1,57 @@
- # app.py ────────────────────────────────────────────────────────────────────
- import io, os, json, math, random, warnings, gc, functools, hashlib
  from pathlib import Path
- from typing import Dict, List, Optional

  import gradio as gr
- import numpy as np
- import matplotlib.pyplot as plt
  from PIL import Image
-
- import torch
- import torch.nn.functional as F
  from transformers import T5Tokenizer, T5EncoderModel
  from diffusers import (
      StableDiffusionXLPipeline,
-     DDIMScheduler,
-     EulerDiscreteScheduler,
-     DPMSolverMultistepScheduler,
  )
  from huggingface_hub import hf_hub_download
  from safetensors.torch import load_file

- # -------------------------------------------------------------------------
  # local modules
  from two_stream_shunt_adapter import TwoStreamShuntAdapter
- from .conditioning_shifter import ConditioningShifter
  from configs import T5_SHUNT_REPOS
- from embedding_manager import get_bank   # ← NEW

  warnings.filterwarnings("ignore")


- # ───────────────────────────────────────────────────────────────────────────
- # GLOBALS
- # ───────────────────────────────────────────────────────────────────────────
- dtype  = torch.float16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- bank   = get_bank()          # shared singleton

- _t5_tok: Optional[T5Tokenizer] = None
- _t5_mod: Optional[T5EncoderModel] = None
- _pipe:   Optional[StableDiffusionXLPipeline] = None

  SCHEDULERS = {
-     "DPM++ 2M": DPMSolverMultistepScheduler,
-     "DDIM": DDIMScheduler,
-     "Euler": EulerDiscreteScheduler,
  }

- # easy access to adapter repo metadata
  clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
  clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
- repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
- repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
- conf_l = T5_SHUNT_REPOS["clip_l"]["config"]
- conf_g = T5_SHUNT_REPOS["clip_g"]["config"]


- # ───────────────────────────────────────────────────────────────────────────
- # HELPERs
- # ───────────────────────────────────────────────────────────────────────────
  def _init_t5():
      global _t5_tok, _t5_mod
      if _t5_tok is None:
          _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
-         _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()


  def _init_pipe():
@@ -71,124 +59,82 @@ def _init_pipe():
      if _pipe is None:
          _pipe = StableDiffusionXLPipeline.from_pretrained(
              "stabilityai/stable-diffusion-xl-base-1.0",
-             torch_dtype=dtype,
-             use_safetensors=True,
-             variant="fp16",
          ).to(device)
          _pipe.enable_xformers_memory_efficient_attention()


- def load_adapter(repo: str, filename: str, cfg: dict):
-     """load a TwoStreamShuntAdapter from HF Hub safetensors"""
-     path = hf_hub_download(repo_id=repo, filename=filename)
-     model = TwoStreamShuntAdapter(cfg).eval()
-     tensors = load_file(path)
-     model.load_state_dict(tensors)
      return model.to(device)


  def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
      if isinstance(mat, torch.Tensor):
          mat = mat.detach().cpu().numpy()
-
      if mat.ndim == 1:
          mat = mat[None, :]
-     elif mat.ndim >= 3:          # (B,T,D) → mean over B
          mat = mat.mean(axis=0)

-     plt.figure(figsize=(8, 4), dpi=120)
      plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
-     plt.title(title)
      plt.colorbar(shrink=0.7)
      plt.tight_layout()

      buf = io.BytesIO()
-     plt.savefig(buf, format="png")
-     plt.close()
-     buf.seek(0)
      return np.array(Image.open(buf))


- def encode_prompt_sd_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
-     """Return CLIP-L, CLIP-G (and negative) embeddings from SDXL pipeline."""
-     tok_l  = pipe.tokenizer  (prompt,   max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
-     tok_g  = pipe.tokenizer_2(prompt,   max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
-     ntok_l = pipe.tokenizer  (negative, max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
-     ntok_g = pipe.tokenizer_2(negative, max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)

      with torch.no_grad():
-         clip_l  = pipe.text_encoder(tok_l)[0]          # (1,77,768)
-         nclip_l = pipe.text_encoder(ntok_l)[0]
-         out_g   = pipe.text_encoder_2(tok_g, output_hidden_states=False)
-         clip_g, pooled   = out_g[1], out_g[0]
-         nout_g  = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
-         nclip_g, npooled = nout_g[1], nout_g[0]

      return {"clip_l": clip_l, "clip_g": clip_g,
-             "neg_l": nclip_l, "neg_g": nclip_g,
-             "pooled": pooled, "neg_pooled": npooled}


- def adapter_forward(adapter, t5_seq, clip_seq, cfg):
-     with torch.no_grad():
-         out = adapter(t5_seq.float(), clip_seq.float())
-     # unify outputs
-     anchor, delta, log_sigma, *_, tau, g_pred, gate = (
-         out + (None,) * 8)[:8]                         # pad to length 8
-     delta = delta * cfg["delta_scale"]
-     gate  = torch.sigmoid(gate * g_pred * cfg["gpred_scale"]) * cfg["gate_prob"]
-     final_delta = delta * cfg["strength"] * gate
-     mod = clip_seq + final_delta.to(dtype)
-
-     if cfg["sigma_scale"] > 0:
-         sigma = torch.exp(log_sigma * cfg["sigma_scale"])
-         mod  += torch.randn_like(mod) * sigma.to(dtype)
-     if cfg["use_anchor"]:
-         mod = mod * (1 - gate) + anchor.to(dtype) * gate
-     if cfg["noise"] > 0:
-         mod += torch.randn_like(mod) * cfg["noise"]
-     return mod, final_delta, gate, g_pred, tau
-
-
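The gating math in the removed adapter_forward reduces to two lines: the adapter's gate head and guidance prediction are squashed through a sigmoid and scaled by gate_prob, and the result throttles how much of the (scaled) delta reaches the CLIP sequence. A minimal numeric sketch, with toy tensor shapes assumed and the UI default values hard-coded:

    import torch

    # toy stand-ins: a (1, 77, 768) CLIP sequence plus adapter outputs
    clip_seq = torch.randn(1, 77, 768)
    delta    = torch.randn(1, 77, 768)
    gate_raw = torch.randn(1, 77, 1)
    g_pred   = torch.tensor(1.3)

    gate = torch.sigmoid(gate_raw * g_pred * 2.0) * 0.27   # gpred_scale, gate_prob
    mod  = clip_seq + delta * 0.2 * 4.0 * gate             # delta_scale, strength
    print(mod.shape)                                       # torch.Size([1, 77, 768])

The anchor mix then blends mod back toward the adapter's anchor sequence with the same gate, so a fully open gate reproduces the anchor and a closed gate leaves the original CLIP embedding untouched.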
- # ───────────────────────────────────────────────────────────────────────────
- # MAIN INFERENCE
- # ───────────────────────────────────────────────────────────────────────────
- def infer(prompt, negative_prompt,
-           adapter_l_file, adapter_g_file,
-           strength, delta_scale, sigma_scale,
-           gpred_scale, noise, gate_prob, use_anchor,
-           steps, cfg_scale, scheduler_name,
-           width, height, seed):

      torch.cuda.empty_cache()
      _init_t5(); _init_pipe()

-     # scheduler
      if scheduler_name in SCHEDULERS:
          _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

-     # RNG
-     generator = None
-     if seed != -1:
-         generator = torch.Generator(device=device).manual_seed(seed)
-         torch.manual_seed(seed); np.random.seed(seed)
-
-     # T5 embeddings (semantic guidance)
-     t5_ids = _t5_tok(prompt, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)
-     t5_seq = _t5_mod(t5_ids).last_hidden_state        # (1,77,768)
-
-     # CLIP embeddings from SDXL
-     embeds = encode_prompt_sd_xl(_pipe, prompt, negative_prompt)
-
-     # ------------------------------------------------------------------
-     # LOAD adapters (if any)
-     cfg_common = dict(
-         strength=strength, delta_scale=delta_scale, sigma_scale=sigma_scale,
-         gpred_scale=gpred_scale, noise=noise, gate_prob=gate_prob,
-         use_anchor=use_anchor,
-     )

-     # --- STEP 0: build shift config -----------------------------------------
      cfg_shift = ShiftConfig(
          prompt = prompt,
          seed   = seed,
@@ -200,66 +146,76 @@ def infer(prompt, negative_prompt,
          use_anchor     = use_anchor,
          guidance_scale = gpred_scale,
      )
-
-     # --- STEP 1: encoder embeddings -----------------------------------------
      t5_seq = ConditioningShifter.extract_encoder_embeddings(
          {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
          device, cfg_shift
      )
-
-     # --- STEP 2: run adapters -----------------------------------------------
-     outputs = []
      if adapter_l_file and adapter_l_file != "None":
          ada_l = load_adapter(repo_l, adapter_l_file, conf_l, device)
          outputs.append(ConditioningShifter.run_adapter(
              ada_l, t5_seq, embeds["clip_l"],
              cfg_shift.guidance_scale, "clip_l", (0, 768)))
-
      if adapter_g_file and adapter_g_file != "None":
          ada_g = load_adapter(repo_g, adapter_g_file, conf_g, device)
          outputs.append(ConditioningShifter.run_adapter(
              ada_g, t5_seq, embeds["clip_g"],
              cfg_shift.guidance_scale, "clip_g", (768, 2048)))
-
-     # --- STEP 3: apply mods --------------------------------------------------
      clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]

      for out in outputs:
-         tgt = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
-         mod = ConditioningShifter.apply_modifications(tgt, [out], cfg_shift)
          if out.adapter_type == "clip_l":
              clip_l_mod = mod
          else:
              clip_g_mod = mod

-
-     # concatenate for SDXL
      prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
      neg_embeds    = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)

-     # SDXL generation
      image = _pipe(
-         prompt_embeds          = prompt_embeds,
-         negative_prompt_embeds = neg_embeds,
-         pooled_prompt_embeds   = embeds["pooled"],
          negative_pooled_prompt_embeds = embeds["neg_pooled"],
-         num_inference_steps=steps, guidance_scale=cfg_scale,
-         width=width, height=height, generator=generator
      ).images[0]

-     # viz
-     delta_l_img = plot_heat(delta_l.squeeze(), "Δ CLIP-L")
-     gate_l_img  = plot_heat(gate_l.squeeze().mean(-1, keepdims=True), "Gate L")
-     delta_g_img = plot_heat(delta_g.squeeze(), "Δ CLIP-G")
-     gate_g_img  = plot_heat(gate_g.squeeze().mean(-1, keepdims=True), "Gate G")

-     stats_l = f"g_pred_L={g_pred_l.item():.3f} | τ_L={tau_l.item():.3f}"
-     stats_g = f"g_pred_G={g_pred_g.item():.3f} | τ_G={tau_g.item():.3f}"
      return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g


- # ───────────────────────────────────────────────────────────────────────────
- # GRADIO UI
- # ───────────────────────────────────────────────────────────────────────────
  def create_interface():
      with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
          gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")
@@ -267,65 +223,65 @@ def create_interface():
          with gr.Row():
              with gr.Column(scale=1):
                  gr.Markdown("### Prompts")
-                 prompt = gr.Textbox(label="Prompt", lines=3,
-                                     value="a futuristic control station with holographic displays")
-                 negative_prompt = gr.Textbox(label="Negative", lines=2,
-                                              value="blurry, low quality, distorted")

                  gr.Markdown("### Adapters")
-                 adapter_l = gr.Dropdown(["None"]+clip_l_opts, value="t5-vit-l-14-dual_shunt_caption.safetensors",
-                                         label="CLIP-L Adapter")
-                 adapter_g = gr.Dropdown(["None"]+clip_g_opts, value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
-                                         label="CLIP-G Adapter")

                  gr.Markdown("### Adapter Controls")
-                 strength    = gr.Slider(0, 10, 4.0, 0.01, label="Strength")
-                 delta_scale = gr.Slider(-15, 15, 0.2, 0.1, label="Δ scale")
-                 sigma_scale = gr.Slider(0, 15, 0.1, 0.1, label="σ scale")
-                 gpred_scale = gr.Slider(0, 20, 2.0, 0.01, label="g_pred scale")
-                 noise       = gr.Slider(0, 1, 0.55, 0.01, label="Extra noise")
-                 gate_prob   = gr.Slider(0, 1, 0.27, 0.01, label="Gate prob")
-                 use_anchor  = gr.Checkbox(True, label="Use anchor mix")

                  gr.Markdown("### Generation")
                  with gr.Row():
-                     steps     = gr.Slider(1, 50, 20, 1, label="Steps")
-                     cfg_scale = gr.Slider(1, 15, 7.5, 0.1, label="CFG")
-                     scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler")
                  with gr.Row():
-                     width  = gr.Slider(512, 1536, 1024, 64, label="Width")
-                     height = gr.Slider(512, 1536, 1024, 64, label="Height")
-                     seed   = gr.Number(-1, label="Seed (-1=random)")

-                 go_btn = gr.Button("🚀 Generate", variant="primary")

              with gr.Column(scale=1):
-                 out_img   = gr.Image(label="Result", height=400)
-                 gr.Markdown("### Adapter Diagnostics")
-                 delta_l_i = gr.Image(label="Δ L", height=180)
-                 gate_l_i  = gr.Image(label="Gate L", height=180)
-                 delta_g_i = gr.Image(label="Δ G", height=180)
-                 gate_g_i  = gr.Image(label="Gate G", height=180)
-                 stats_l   = gr.Textbox(label="Stats L", interactive=False)
-                 stats_g   = gr.Textbox(label="Stats G", interactive=False)

          def _run(*args):
-             pl, npl = args[0], args[1]
-             al, ag  = (None if v=="None" else v for v in args[2:4])
              return infer(pl, npl, al, ag, *args[4:])

-         go_btn.click(
-             _run,
-             inputs=[prompt, negative_prompt, adapter_l, adapter_g,
-                     strength, delta_scale, sigma_scale, gpred_scale,
-                     noise, gate_prob, use_anchor, steps, cfg_scale,
-                     scheduler, width, height, seed],
-             outputs=[out_img, delta_l_i, gate_l_i, delta_g_i, gate_g_i,
-                      stats_l, stats_g]
          )
      return demo


- # ───────────────────────────────────────────────────────────────────────────
  if __name__ == "__main__":
      create_interface().launch()

+ # app.py ────────────────────────────────────────────────────────────────
+ import io, warnings, numpy as np, matplotlib.pyplot as plt
  from pathlib import Path
+ from typing import Dict, List, Optional, Tuple

  import gradio as gr
+ import torch, torch.nn.functional as F
  from PIL import Image
  from transformers import T5Tokenizer, T5EncoderModel
  from diffusers import (
      StableDiffusionXLPipeline,
+     DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler,
  )
  from huggingface_hub import hf_hub_download
  from safetensors.torch import load_file

  # local modules
  from two_stream_shunt_adapter import TwoStreamShuntAdapter
+ from conditioning_shifter import ConditioningShifter, ShiftConfig, AdapterOutput
+ from embedding_manager import get_bank
  from configs import T5_SHUNT_REPOS

  warnings.filterwarnings("ignore")

+ # ─── GLOBALS ────────────────────────────────────────────────────────────
+ dtype  = torch.float16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ _bank = get_bank()            # singleton – optional caching

+ _t5_tok: Optional[T5Tokenizer] = None
+ _t5_mod: Optional[T5EncoderModel] = None
+ _pipe  : Optional[StableDiffusionXLPipeline] = None

  SCHEDULERS = {
+     "DPM++ 2M": DPMSolverMultistepScheduler,
+     "DDIM":     DDIMScheduler,
+     "Euler":    EulerDiscreteScheduler,
  }

+ # adapter-meta from configs.py
  clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
  clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
+ repo_l, conf_l = T5_SHUNT_REPOS["clip_l"]["repo"], T5_SHUNT_REPOS["clip_l"]["config"]
+ repo_g, conf_g = T5_SHUNT_REPOS["clip_g"]["repo"], T5_SHUNT_REPOS["clip_g"]["config"]

+ # ─── INITIALISERS ────────────────────────────────────────────────────────
  def _init_t5():
      global _t5_tok, _t5_mod
      if _t5_tok is None:
          _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
+         _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base") \
+                                 .to(device).eval()


  def _init_pipe():

      if _pipe is None:
          _pipe = StableDiffusionXLPipeline.from_pretrained(
              "stabilityai/stable-diffusion-xl-base-1.0",
+             torch_dtype=dtype, variant="fp16", use_safetensors=True
          ).to(device)
          _pipe.enable_xformers_memory_efficient_attention()

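Note that enable_xformers_memory_efficient_attention raises when xformers is not installed, so the pipeline init fails on stock PyTorch. A defensive variant (a sketch, not part of this commit) guards the call:

    try:
        _pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        # xformers unavailable; fall back to PyTorch's built-in attention
        pass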
+ # ─── HELPERS ─────────────────────────────────────────────────────────────
+ def load_adapter(repo: str, filename: str, cfg: dict,
+                  device: torch.device) -> TwoStreamShuntAdapter:
+     path  = hf_hub_download(repo_id=repo, filename=filename)
+     model = TwoStreamShuntAdapter(cfg).eval()
+     model.load_state_dict(load_file(path))
      return model.to(device)


  def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
      if isinstance(mat, torch.Tensor):
          mat = mat.detach().cpu().numpy()
      if mat.ndim == 1:
          mat = mat[None, :]
+     elif mat.ndim >= 3:
          mat = mat.mean(axis=0)

+     plt.figure(figsize=(7, 3.3), dpi=110)
      plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
+     plt.title(title, fontsize=10)
      plt.colorbar(shrink=0.7)
      plt.tight_layout()

      buf = io.BytesIO()
+     plt.savefig(buf, format="png", bbox_inches="tight")
+     plt.close(); buf.seek(0)
      return np.array(Image.open(buf))
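plot_heat hands the rendered heat-map back as a numpy array (the PNG is decoded through PIL, so RGBA uint8), which gr.Image displays directly. A quick standalone check, with a made-up 77×768 matrix:

    arr = plot_heat(torch.randn(77, 768), "demo")
    print(arr.shape, arr.dtype)            # e.g. (H, W, 4) uint8
    Image.fromarray(arr).save("heat_demo.png")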

+ def encode_prompt_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
+     tok_l  = pipe.tokenizer  (prompt,   max_length=77, truncation=True,
+                               padding="max_length", return_tensors="pt").input_ids.to(device)
+     tok_g  = pipe.tokenizer_2(prompt,   max_length=77, truncation=True,
+                               padding="max_length", return_tensors="pt").input_ids.to(device)
+     ntok_l = pipe.tokenizer  (negative, max_length=77, truncation=True,
+                               padding="max_length", return_tensors="pt").input_ids.to(device)
+     ntok_g = pipe.tokenizer_2(negative, max_length=77, truncation=True,
+                               padding="max_length", return_tensors="pt").input_ids.to(device)

      with torch.no_grad():
+         clip_l     = pipe.text_encoder(tok_l)[0]
+         neg_clip_l = pipe.text_encoder(ntok_l)[0]
+
+         g_out  = pipe.text_encoder_2(tok_g, output_hidden_states=False)
+         clip_g, pl = g_out[1], g_out[0]
+         ng_out = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
+         neg_clip_g, npl = ng_out[1], ng_out[0]

      return {"clip_l": clip_l, "clip_g": clip_g,
+             "neg_l": neg_clip_l, "neg_g": neg_clip_g,
+             "pooled": pl, "neg_pooled": npl}

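With the stock SDXL encoders, clip_l comes out as (1, 77, 768), clip_g as (1, 77, 1280), and the pooled vector as (1, 1280); the (0, 768) and (768, 2048) slices passed to run_adapter below assume exactly this 768 + 1280 = 2048 concatenation. A sanity check along those lines (sketch):

    embeds = encode_prompt_xl(_pipe, "test", "")
    assert embeds["clip_l"].shape[-1] == 768     # CLIP-L hidden size
    assert embeds["clip_g"].shape[-1] == 1280    # CLIP-G (bigG) hidden size
    assert embeds["pooled"].shape[-1] == 1280    # pooled projection from encoder 2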
+ # ─── INFERENCE ───────────────────────────────────────────────────────────
+ def infer(prompt: str, negative_prompt: str,
+           adapter_l_file: str, adapter_g_file: str,
+           strength: float, delta_scale: float, sigma_scale: float,
+           gpred_scale: float, noise: float, gate_prob: float, use_anchor: bool,
+           steps: int, cfg_scale: float, scheduler_name: str,
+           width: int, height: int, seed: int):

      torch.cuda.empty_cache()
      _init_t5(); _init_pipe()

      if scheduler_name in SCHEDULERS:
          _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

+     generator = (torch.Generator(device=device).manual_seed(seed)
+                  if seed != -1 else None)

+     # build ShiftConfig (one per request)
      cfg_shift = ShiftConfig(
          prompt = prompt,
          seed   = seed,

          use_anchor     = use_anchor,
          guidance_scale = gpred_scale,
      )
+
+     # encoder (T5) embeddings
      t5_seq = ConditioningShifter.extract_encoder_embeddings(
          {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
          device, cfg_shift
      )
+
+     # CLIP embeddings
+     embeds = encode_prompt_xl(_pipe, prompt, negative_prompt)
+
+     # run adapters --------------------------------------------------------
+     outputs: List[AdapterOutput] = []
      if adapter_l_file and adapter_l_file != "None":
          ada_l = load_adapter(repo_l, adapter_l_file, conf_l, device)
          outputs.append(ConditioningShifter.run_adapter(
              ada_l, t5_seq, embeds["clip_l"],
              cfg_shift.guidance_scale, "clip_l", (0, 768)))
+
      if adapter_g_file and adapter_g_file != "None":
          ada_g = load_adapter(repo_g, adapter_g_file, conf_g, device)
          outputs.append(ConditioningShifter.run_adapter(
              ada_g, t5_seq, embeds["clip_g"],
              cfg_shift.guidance_scale, "clip_g", (768, 2048)))
+
+     # apply modifications -------------------------------------------------
      clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
+     delta_viz = {"clip_l": torch.zeros_like(clip_l_mod),
+                  "clip_g": torch.zeros_like(clip_g_mod)}
+     gate_viz  = {"clip_l": torch.zeros_like(clip_l_mod[..., :1]),
+                  "clip_g": torch.zeros_like(clip_g_mod[..., :1])}
+
      for out in outputs:
+         target = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
+         mod    = ConditioningShifter.apply_modifications(target, [out], cfg_shift)
          if out.adapter_type == "clip_l":
              clip_l_mod = mod
          else:
              clip_g_mod = mod
+         delta_viz[out.adapter_type] = out.delta.detach()
+         gate_viz [out.adapter_type] = out.gate.detach()

+     # prepare for SDXL ----------------------------------------------------
      prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
      neg_embeds    = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)

      image = _pipe(
+         prompt_embeds          = prompt_embeds,
+         negative_prompt_embeds = neg_embeds,
+         pooled_prompt_embeds   = embeds["pooled"],
          negative_pooled_prompt_embeds = embeds["neg_pooled"],
+         num_inference_steps = steps,
+         guidance_scale      = cfg_scale,
+         width = width, height = height, generator = generator
      ).images[0]

+     # diagnostics ---------------------------------------------------------
+     delta_l_img = plot_heat(delta_viz["clip_l"].squeeze(), "Δ CLIP-L")
+     gate_l_img  = plot_heat(gate_viz ["clip_l"].squeeze().mean(-1, keepdim=True), "Gate L")
+     delta_g_img = plot_heat(delta_viz["clip_g"].squeeze(), "Δ CLIP-G")
+     gate_g_img  = plot_heat(gate_viz ["clip_g"].squeeze().mean(-1, keepdim=True), "Gate G")
+
+     stats_l = (f"τ̄_L = {outputs[0].tau.mean().item():.3f}"
+                if outputs and outputs[0].adapter_type == "clip_l" else "-")
+     stats_g = (f"τ̄_G = {outputs[-1].tau.mean().item():.3f}"
+                if len(outputs) > 1 and outputs[-1].adapter_type == "clip_g" else "-")

      return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
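For a headless smoke test, infer can be called directly with the UI defaults; a sketch assuming a CUDA machine and the two default adapter files (keyword names mirror the signature above):

    image, *diagnostics = infer(
        "a futuristic control station with holographic displays",
        "blurry, low quality, distorted",
        "t5-vit-l-14-dual_shunt_caption.safetensors",
        "dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
        strength=4.0, delta_scale=0.2, sigma_scale=0.1, gpred_scale=2.0,
        noise=0.55, gate_prob=0.27, use_anchor=True,
        steps=20, cfg_scale=7.5, scheduler_name="DPM++ 2M",
        width=1024, height=1024, seed=42,
    )
    image.save("shunt_test.png")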

+ # ─── GRADIO UI ────────────────────────────────────────────────────────────
  def create_interface():
      with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
          gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")

          with gr.Row():
              with gr.Column(scale=1):
                  gr.Markdown("### Prompts")
+                 prompt   = gr.Textbox(label="Prompt", lines=3,
+                                       value="a futuristic control station with holographic displays")
+                 negative = gr.Textbox(label="Negative", lines=2,
+                                       value="blurry, low quality, distorted")

                  gr.Markdown("### Adapters")
+                 adapter_l = gr.Dropdown(["None"] + clip_l_opts,
+                                         value="t5-vit-l-14-dual_shunt_caption.safetensors",
+                                         label="CLIP-L Adapter")
+                 adapter_g = gr.Dropdown(["None"] + clip_g_opts,
+                                         value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
+                                         label="CLIP-G Adapter")

                  gr.Markdown("### Adapter Controls")
+                 strength    = gr.Slider(0, 10, 4.0, 0.05, label="Strength")
+                 delta_scale = gr.Slider(-15, 15, 0.2, 0.1, label="Δ scale")
+                 sigma_scale = gr.Slider(0, 15, 0.1, 0.1, label="σ scale")
+                 gpred_scale = gr.Slider(0, 20, 2.0, 0.05, label="Guidance scale")
+                 noise       = gr.Slider(0, 1, 0.55, 0.01, label="Extra noise")
+                 gate_prob   = gr.Slider(0, 1, 0.27, 0.01, label="Gate prob")
+                 use_anchor  = gr.Checkbox(True, label="Use anchor mix")

                  gr.Markdown("### Generation")
                  with gr.Row():
+                     steps     = gr.Slider(1, 50, 20, 1, label="Steps")
+                     cfg_scale = gr.Slider(1, 15, 7.5, 0.1, label="CFG")
+                     scheduler = gr.Dropdown(list(SCHEDULERS.keys()),
+                                             value="DPM++ 2M", label="Scheduler")
                  with gr.Row():
+                     width  = gr.Slider(512, 1536, 1024, 64, label="Width")
+                     height = gr.Slider(512, 1536, 1024, 64, label="Height")
+                     seed   = gr.Number(-1, label="Seed (-1 → random)", precision=0)

+                 run_btn = gr.Button("🚀 Generate", variant="primary")

              with gr.Column(scale=1):
+                 out_img = gr.Image(label="Result", height=400)
+                 gr.Markdown("### Diagnostics")
+                 delta_l = gr.Image(label="Δ L", height=180)
+                 gate_l  = gr.Image(label="Gate L", height=180)
+                 delta_g = gr.Image(label="Δ G", height=180)
+                 gate_g  = gr.Image(label="Gate G", height=180)
+                 stats_l = gr.Textbox(label="Stats L", interactive=False)
+                 stats_g = gr.Textbox(label="Stats G", interactive=False)

          def _run(*args):
+             pl, npl = args[0], args[1]
+             al, ag  = (None if v == "None" else v for v in args[2:4])
              return infer(pl, npl, al, ag, *args[4:])

+         run_btn.click(
+             fn=_run,
+             inputs=[prompt, negative, adapter_l, adapter_g, strength, delta_scale,
+                     sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps,
+                     cfg_scale, scheduler, width, height, seed],
+             outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
          )
      return demo


  if __name__ == "__main__":
      create_interface().launch()
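
SDXL generation can take tens of seconds per request, which strains Gradio's default request handling on a busy Space; enabling the built-in queue is a one-line change (an optional variant, not part of this commit):

    if __name__ == "__main__":
        create_interface().queue(max_size=8).launch()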