AbstractPhil committed on
Commit e543e33 · verified · 1 Parent(s): 40a1f37

Update app.py

Files changed (1):
  app.py +266 -391

app.py CHANGED
@@ -1,432 +1,307 @@
- import spaces
- import torch
  import gradio as gr
  import numpy as np
  import matplotlib.pyplot as plt
  from PIL import Image
  from transformers import T5Tokenizer, T5EncoderModel
- from diffusers import StableDiffusionXLPipeline, DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
- from safetensors.torch import load_file
  from huggingface_hub import hf_hub_download
  from two_stream_shunt_adapter import TwoStreamShuntAdapter
  from configs import T5_SHUNT_REPOS
- import io

- # ─── Global Variables ─────────────────────────────────────────
- t5_tok = None
- t5_mod = None
- pipe = None

- # Available schedulers
  SCHEDULERS = {
-     "DPM++ 2M": DPMSolverMultistepScheduler,
-     "DDIM": DDIMScheduler,
-     "Euler": EulerDiscreteScheduler,
  }

- # ─── Adapter Configs ──────────────────────────────────────────
  clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
  clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
- repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
- repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
- config_l = T5_SHUNT_REPOS["clip_l"]["config"]
- config_g = T5_SHUNT_REPOS["clip_g"]["config"]
-
- # ─── Helper Functions ─────────────────────────────────────────
- def load_adapter(repo, filename, config, device):
-     """Load adapter from safetensors file"""
-     from safetensors.torch import safe_open
      path = hf_hub_download(repo_id=repo, filename=filename)
-
-     model = TwoStreamShuntAdapter(config).eval()
-     tensors = {}
-     with safe_open(path, framework="pt", device="cpu") as f:
-         for key in f.keys():
-             tensors[key] = f.get_tensor(key)
      model.load_state_dict(tensors)
      return model.to(device)

- def plot_heat(mat, title):
-     """Create heatmap visualization with proper shape handling"""
-     # Handle different input shapes
      if isinstance(mat, torch.Tensor):
          mat = mat.detach().cpu().numpy()
-
-     # Ensure we have a 2D array for visualization
-     if len(mat.shape) == 1:
-         # 1D array - reshape to single row
-         mat = mat.reshape(1, -1)
-     elif len(mat.shape) == 3:
-         # 3D array - average over batch dimension
-         if mat.shape[0] == 1:
-             mat = mat.squeeze(0)
-         else:
-             mat = mat.mean(axis=0)
-     elif len(mat.shape) > 3:
-         # Flatten higher dimensions
-         mat = mat.reshape(-1, mat.shape[-1])
-
-     # Create figure with proper DPI
-     plt.figure(figsize=(8, 4), dpi=100)
-     plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper", interpolation='nearest')
-     plt.title(title, fontsize=12, fontweight='bold')
-     plt.xlabel("Token Position")
-     plt.ylabel("Feature Dimension")
-     plt.colorbar(shrink=0.8)
      plt.tight_layout()
-
-     # Convert to PIL Image
      buf = io.BytesIO()
-     plt.savefig(buf, format="png", bbox_inches='tight', dpi=100)
-     buf.seek(0)
-     pil_image = Image.open(buf)
      plt.close()
-
-     # Convert to numpy array for Gradio
-     return np.array(pil_image)
-
- def encode_sdxl_prompt(pipe, prompt, negative_prompt, device):
-     """Generate CLIP-L and CLIP-G embeddings using SDXL's text encoders"""
-
-     # Tokenize for both encoders
-     tokens_l = pipe.tokenizer(
-         prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
-     ).input_ids.to(device)
-
-     tokens_g = pipe.tokenizer_2(
-         prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
-     ).input_ids.to(device)
-
-     neg_tokens_l = pipe.tokenizer(
-         negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
-     ).input_ids.to(device)
-
-     neg_tokens_g = pipe.tokenizer_2(
-         negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
-     ).input_ids.to(device)
-
      with torch.no_grad():
-         # CLIP-L: [0] = sequence, [1] = pooled
-         clip_l_output = pipe.text_encoder(tokens_l, output_hidden_states=False)
-         clip_l_embeds = clip_l_output[0]
-
-         neg_clip_l_output = pipe.text_encoder(neg_tokens_l, output_hidden_states=False)
-         neg_clip_l_embeds = neg_clip_l_output[0]
-
-         # CLIP-G: [0] = pooled, [1] = sequence
-         clip_g_output = pipe.text_encoder_2(tokens_g, output_hidden_states=False)
-         clip_g_embeds = clip_g_output[1]  # sequence embeddings
-         pooled_embeds = clip_g_output[0]  # pooled embeddings
-
-         neg_clip_g_output = pipe.text_encoder_2(neg_tokens_g, output_hidden_states=False)
-         neg_clip_g_embeds = neg_clip_g_output[1]
-         neg_pooled_embeds = neg_clip_g_output[0]
-
-     return {
-         "clip_l": clip_l_embeds,
-         "clip_g": clip_g_embeds,
-         "neg_clip_l": neg_clip_l_embeds,
-         "neg_clip_g": neg_clip_g_embeds,
-         "pooled": pooled_embeds,
-         "neg_pooled": neg_pooled_embeds
-     }
-
- # ─── Main Inference Function ──────────────────────────────────
- @spaces.GPU
- def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, delta_scale,
-           sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps, cfg_scale,
-           scheduler_name, width, height, seed):
-
-     global t5_tok, t5_mod, pipe
-     device = torch.device("cuda")
-     dtype = torch.float16
-
-     # Initialize models
-     if t5_tok is None:
-         t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
-         t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
-
-     if pipe is None:
-         pipe = StableDiffusionXLPipeline.from_pretrained(
-             "stabilityai/stable-diffusion-xl-base-1.0",
-             torch_dtype=dtype,
-             variant="fp16",
-             use_safetensors=True
-         ).to(device)
-
-     # Set seed
      if seed != -1:
-         torch.manual_seed(seed)
-         np.random.seed(seed)
          generator = torch.Generator(device=device).manual_seed(seed)
      else:
-         generator = None
-
-     # Set scheduler
-     if scheduler_name in SCHEDULERS:
-         pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
-
-     # Get T5 embeddings
-     t5_ids = t5_tok(
-         prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True
-     ).input_ids.to(device)
-
-     with torch.no_grad():
-         t5_seq = t5_mod(t5_ids).last_hidden_state
-
-     # Get CLIP embeddings
-     clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)
-
-     # Load and apply adapters
-     if(adapter_l_file == "t5-vit-l-14-dual_shunt_booru_13_000_000.safetensors" or adapter_l_file == "t5-vit-l-14-dual_shunt_booru_51_200_000.safetensors"):
-         config_l["heads"] = 4
-     else:
-         config_l["heads"] = 12
-     adapter_l = load_adapter(repo_l, adapter_l_file, config_l, device) if adapter_l_file else None
-     adapter_g = load_adapter(repo_g, adapter_g_file, config_g, device) if adapter_g_file else None
-
-     # Apply CLIP-L adapter
-     if adapter_l is not None:
-         with torch.no_grad():
-             # Run adapter forward pass
-             adapter_output = adapter_l(t5_seq.float(), clip_embeds["clip_l"].float())
-
-             # Unpack outputs (ensure correct number of outputs)
-             if len(adapter_output) == 8:
-                 anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_output
-             else:
-                 # Handle different return formats
-                 anchor_l = adapter_output[0]
-                 delta_l = adapter_output[1]
-                 log_sigma_l = adapter_output[2] if len(adapter_output) > 2 else torch.zeros_like(delta_l)
-                 gate_l = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_l)
-                 tau_l = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
-                 g_pred_l = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)
-
-             # Scale delta values
-             delta_l = delta_l * delta_scale
-
-             # Apply g_pred scaling to gate
-             gate_l = gate_l * g_pred_l * gpred_scale
-
-             # Apply gate scaling
-             gate_l_scaled = torch.sigmoid(gate_l) * gate_prob
-
-             # Compute final delta with strength and gate
-             delta_l_final = delta_l * strength * gate_l_scaled
-
-             # Apply delta to embeddings
-             clip_l_mod = clip_embeds["clip_l"] + delta_l_final.to(dtype)
-
-             # Apply sigma-based noise if specified
-             if sigma_scale > 0:
-                 sigma_l = torch.exp(log_sigma_l * sigma_scale)
-                 clip_l_mod += torch.randn_like(clip_l_mod) * sigma_l.to(dtype)
-
-             # Apply anchor mixing if enabled
-             if use_anchor:
-                 clip_l_mod = clip_l_mod * (1 - gate_l_scaled.to(dtype)) + anchor_l.to(dtype) * gate_l_scaled.to(dtype)
-
-             # Add additional noise if specified
-             if noise > 0:
-                 clip_l_mod += torch.randn_like(clip_l_mod) * noise
-     else:
-         clip_l_mod = clip_embeds["clip_l"]
-         delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
-         gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
-         g_pred_l = torch.tensor(0.0)
-         tau_l = torch.tensor(0.0)
-
-     # Apply CLIP-G adapter
-     if adapter_g is not None:
-         with torch.no_grad():
-             # Run adapter forward pass
-             adapter_output = adapter_g(t5_seq.float(), clip_embeds["clip_g"].float())
-
-             # Unpack outputs (ensure correct number of outputs)
-             if len(adapter_output) == 8:
-                 anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_output
-             else:
-                 # Handle different return formats
-                 anchor_g = adapter_output[0]
-                 delta_g = adapter_output[1]
-                 log_sigma_g = adapter_output[2] if len(adapter_output) > 2 else torch.zeros_like(delta_g)
-                 gate_g = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_g)
-                 tau_g = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
-                 g_pred_g = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)
-
-             # Scale delta values
-             delta_g = delta_g * delta_scale
-
-             # Apply g_pred scaling to gate
-             gate_g = gate_g * g_pred_g * gpred_scale
-
-             # Apply gate scaling
-             gate_g_scaled = torch.sigmoid(gate_g) * gate_prob
-
-             # Compute final delta with strength and gate
-             delta_g_final = delta_g * strength * gate_g_scaled
-
-             # Apply delta to embeddings
-             clip_g_mod = clip_embeds["clip_g"] + delta_g_final.to(dtype)
-
-             # Apply sigma-based noise if specified
-             if sigma_scale > 0:
-                 sigma_g = torch.exp(log_sigma_g * sigma_scale)
-                 clip_g_mod += torch.randn_like(clip_g_mod) * sigma_g.to(dtype)
-
-             # Apply anchor mixing if enabled
-             if use_anchor:
-                 clip_g_mod = clip_g_mod * (1 - gate_g_scaled.to(dtype)) + anchor_g.to(dtype) * gate_g_scaled.to(dtype)
-
-             # Add additional noise if specified
-             if noise > 0:
-                 clip_g_mod += torch.randn_like(clip_g_mod) * noise
      else:
-         clip_g_mod = clip_embeds["clip_g"]
-         delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
-         gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
-         g_pred_g = torch.tensor(0.0)
-         tau_g = torch.tensor(0.0)
-
-     # Combine embeddings for SDXL: [CLIP-L(768) + CLIP-G(1280)] = 2048
      prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
-     neg_embeds = torch.cat([clip_embeds["neg_clip_l"], clip_embeds["neg_clip_g"]], dim=-1)
-
-     # Generate image
-     image = pipe(
-         prompt_embeds=prompt_embeds,
-         pooled_prompt_embeds=clip_embeds["pooled"],
-         negative_prompt_embeds=neg_embeds,
-         negative_pooled_prompt_embeds=clip_embeds["neg_pooled"],
-         num_inference_steps=steps,
-         guidance_scale=cfg_scale,
-         width=width,
-         height=height,
-         num_images_per_prompt=1,
-         generator=generator
      ).images[0]
-
-     # Create visualizations
-     delta_l_viz = plot_heat(delta_l_final.squeeze(), "CLIP-L Delta Values")
-     gate_l_viz = plot_heat(gate_l_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-L Gate Activations")
-     delta_g_viz = plot_heat(delta_g_final.squeeze(), "CLIP-G Delta Values")
-     gate_g_viz = plot_heat(gate_g_scaled.squeeze().mean(dim=-1, keepdim=True), "CLIP-G Gate Activations")
-
-     # Statistics
-     stats_l = f"g_pred_l: {float(g_pred_l.mean().item() if hasattr(g_pred_l, 'mean') else g_pred_l):.3f}, τ_l: {float(tau_l.mean().item() if hasattr(tau_l, 'mean') else tau_l):.3f}"
-     stats_g = f"g_pred_g: {float(g_pred_g.mean().item() if hasattr(g_pred_g, 'mean') else g_pred_g):.3f}, τ_g: {float(tau_g.mean().item() if hasattr(tau_g, 'mean') else tau_g):.3f}"
-
-     return image, delta_l_viz, gate_l_viz, delta_g_viz, gate_g_viz, stats_l, stats_g
-
- # ─── Gradio Interface ─────────────────────────────────────────
 
  def create_interface():
-     with gr.Blocks(title="SDXL Dual Shunt Adapter", theme=gr.themes.Soft()) as demo:
-         gr.Markdown("# 🧠 SDXL Dual Shunt Adapter")
-         gr.Markdown("*Enhance SDXL generation using T5 semantic understanding to modify CLIP embeddings*")
-
          with gr.Row():
              with gr.Column(scale=1):
-                 # Prompts
-                 gr.Markdown("### 📝 Prompts")
-                 prompt = gr.Textbox(
-                     label="Prompt",
-                     value="a futuristic control station with holographic displays",
-                     lines=3,
-                     placeholder="Describe what you want to generate..."
-                 )
-                 negative_prompt = gr.Textbox(
-                     label="Negative Prompt",
-                     value="blurry, low quality, distorted",
-                     lines=2,
-                     placeholder="Describe what you want to avoid..."
-                 )
-
-                 # Adapters
-                 gr.Markdown("### ⚙️ Adapters")
-                 adapter_l = gr.Dropdown(
-                     choices=["None"] + clip_l_opts,
-                     label="CLIP-L (768d) Adapter",
-                     value="t5-vit-l-14-dual_shunt_caption.safetensors",
-                     info="Choose adapter for CLIP-L embeddings"
-                 )
-                 adapter_g = gr.Dropdown(
-                     choices=["None"] + clip_g_opts,
-                     label="CLIP-G (1280d) Adapter",
-                     value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
-                     info="Choose adapter for CLIP-G embeddings"
-                 )
-
-                 # Controls
-                 gr.Markdown("### 🎛️ Adapter Controls")
-                 strength = gr.Slider(0.0, 10.0, value=4.0, step=0.01, label="Adapter Strength")
-                 delta_scale = gr.Slider(-15.0, 15.0, value=0.2, step=0.1, label="Delta Scale", info="Scales the delta values, recommended 1")
-                 sigma_scale = gr.Slider(0, 15.0, value=0.1, step=0.1, label="Sigma Scale", info="Scales the noise variance, recommended 1")
-                 gpred_scale = gr.Slider(0.0, 20.0, value=2.0, step=0.01, label="G-Pred Scale", info="Scales the gate prediction, recommended 2")
-                 noise = gr.Slider(0.0, 1.0, value=0.55, step=0.01, label="Noise Injection")
-                 gate_prob = gr.Slider(0.0, 1.0, value=0.27, step=0.01, label="Gate Probability")
-                 use_anchor = gr.Checkbox(label="Use Anchor Points", value=True)
-
-                 # Generation Settings
-                 gr.Markdown("### 🎨 Generation Settings")
                  with gr.Row():
-                     steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
-                     cfg_scale = gr.Slider(1.0, 15.0, value=7.5, step=0.1, label="CFG Scale")
-
-                 scheduler_name = gr.Dropdown(
-                     choices=list(SCHEDULERS.keys()),
-                     value="DPM++ 2M",
-                     label="Scheduler"
-                 )
-
                  with gr.Row():
-                     width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
-                     height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
-
-                 seed = gr.Number(value=-1, label="Seed (-1 for random)", precision=0)
-
-                 generate_btn = gr.Button("🚀 Generate Image", variant="primary", size="lg")
-
              with gr.Column(scale=1):
-                 # Output
-                 gr.Markdown("### 🖼️ Generated Image")
-                 output_image = gr.Image(label="Result", height=400, show_label=False)
-
-                 # Visualizations
-                 gr.Markdown("### 📊 Adapter Analysis")
-                 with gr.Row():
-                     delta_l_img = gr.Image(label="CLIP-L Deltas", height=200)
-                     gate_l_img = gr.Image(label="CLIP-L Gates", height=200)
-                 with gr.Row():
-                     delta_g_img = gr.Image(label="CLIP-G Deltas", height=200)
-                     gate_g_img = gr.Image(label="CLIP-G Gates", height=200)
-
-                 # Statistics
-                 gr.Markdown("### 📈 Statistics")
-                 stats_l_text = gr.Textbox(label="CLIP-L Metrics", interactive=False)
-                 stats_g_text = gr.Textbox(label="CLIP-G Metrics", interactive=False)
-
-         # Event handler
-         def run_generation(*args):
-             # Process adapter selections
-             processed_args = list(args)
-             processed_args[2] = None if args[2] == "None" else args[2]  # adapter_l
-             processed_args[3] = None if args[3] == "None" else args[3]  # adapter_g
-             return infer(*processed_args)
-
-         generate_btn.click(
-             fn=run_generation,
-             inputs=[
-                 prompt, negative_prompt, adapter_l, adapter_g, strength, delta_scale,
-                 sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps, cfg_scale,
-                 scheduler_name, width, height, seed
-             ],
-             outputs=[output_image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l_text, stats_g_text]
          )
-
      return demo

- # ─── Launch ────────────────────────────────────────────────────
  if __name__ == "__main__":
-     demo = create_interface()
-     demo.launch()
+ # app.py ────────────────────────────────────────────────────────────────────
+ import io, os, json, math, random, warnings, gc, functools, hashlib
+ from pathlib import Path
+ from typing import Dict, List, Optional
+
  import gradio as gr
  import numpy as np
  import matplotlib.pyplot as plt
  from PIL import Image
+
+ import torch
+ import torch.nn.functional as F
  from transformers import T5Tokenizer, T5EncoderModel
+ from diffusers import (
+     StableDiffusionXLPipeline,
+     DDIMScheduler,
+     EulerDiscreteScheduler,
+     DPMSolverMultistepScheduler,
+ )
  from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+
+ # -------------------------------------------------------------------------
+ # local modules
  from two_stream_shunt_adapter import TwoStreamShuntAdapter
  from configs import T5_SHUNT_REPOS
+ from embedding_manager import get_bank  # ← NEW
+
+ warnings.filterwarnings("ignore")
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # GLOBALS
+ # ───────────────────────────────────────────────────────────────────────────
+ dtype = torch.float16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ bank = get_bank()  # shared singleton
+
+ _t5_tok: Optional[T5Tokenizer] = None
+ _t5_mod: Optional[T5EncoderModel] = None
+ _pipe: Optional[StableDiffusionXLPipeline] = None

  SCHEDULERS = {
+     "DPM++ 2M": DPMSolverMultistepScheduler,
+     "DDIM": DDIMScheduler,
+     "Euler": EulerDiscreteScheduler,
  }
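Editor's note: the SCHEDULERS dict maps UI names to diffusers scheduler classes, and `infer` swaps them in via `from_config`, which rebuilds the sampler from the pipeline's existing scheduler config. A minimal sketch of that pattern, assuming a loaded SDXL pipeline named `pipe`:

```python
from diffusers import DPMSolverMultistepScheduler

# Rebuild the scheduler from the pipeline's current config so beta schedule,
# timestep spacing, etc. carry over to the new sampler.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
```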

+ # easy access to adapter repo metadata
  clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
  clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
+ repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
+ repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
+ conf_l = T5_SHUNT_REPOS["clip_l"]["config"]
+ conf_g = T5_SHUNT_REPOS["clip_g"]["config"]
+
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # HELPERS
+ # ───────────────────────────────────────────────────────────────────────────
+ def _init_t5():
+     global _t5_tok, _t5_mod
+     if _t5_tok is None:
+         _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
+         _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
+
+
+ def _init_pipe():
+     global _pipe
+     if _pipe is None:
+         _pipe = StableDiffusionXLPipeline.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0",
+             torch_dtype=dtype,
+             use_safetensors=True,
+             variant="fp16",
+         ).to(device)
+         try:
+             _pipe.enable_xformers_memory_efficient_attention()
+         except Exception:
+             pass  # xformers is optional; fall back to default attention
+
+
+ def load_adapter(repo: str, filename: str, cfg: dict):
+     """Load a TwoStreamShuntAdapter from HF Hub safetensors."""
      path = hf_hub_download(repo_id=repo, filename=filename)
+     model = TwoStreamShuntAdapter(cfg).eval()
+     tensors = load_file(path)
      model.load_state_dict(tensors)
      return model.to(device)

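Editor's note: the rewrite replaces the old `safe_open` loop with `safetensors.torch.load_file`, which returns the whole state dict in one call. A minimal equivalence sketch (the `path` variable is illustrative):

```python
from safetensors.torch import load_file, safe_open

tensors = load_file(path)  # one call: {tensor_name: tensor}

# equivalent to the old pattern removed above:
tensors = {}
with safe_open(path, framework="pt", device="cpu") as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)
```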
+
+ def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
      if isinstance(mat, torch.Tensor):
          mat = mat.detach().cpu().numpy()
+
+     if mat.ndim == 1:
+         mat = mat[None, :]
+     while mat.ndim > 2:  # e.g. (B,T,D) → mean over leading dims
+         mat = mat.mean(axis=0)
+
+     plt.figure(figsize=(8, 4), dpi=120)
+     plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
+     plt.title(title)
+     plt.colorbar(shrink=0.7)
      plt.tight_layout()
+
      buf = io.BytesIO()
+     plt.savefig(buf, format="png")
      plt.close()
+     buf.seek(0)
+     return np.array(Image.open(buf))
+
+
+ def encode_prompt_sd_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
+     """Return CLIP-L and CLIP-G (and negative) embeddings from the SDXL pipeline."""
+     tok_l = pipe.tokenizer(prompt, max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
+     tok_g = pipe.tokenizer_2(prompt, max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
+     ntok_l = pipe.tokenizer(negative, max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
+     ntok_g = pipe.tokenizer_2(negative, max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
+
      with torch.no_grad():
+         clip_l = pipe.text_encoder(tok_l)[0]   # (1, 77, 768) sequence
+         nclip_l = pipe.text_encoder(ntok_l)[0]
+         out_g = pipe.text_encoder_2(tok_g, output_hidden_states=False)
+         clip_g, pooled = out_g[1], out_g[0]    # sequence, pooled
+         nout_g = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
+         nclip_g, npooled = nout_g[1], nout_g[0]
+
+     return {"clip_l": clip_l, "clip_g": clip_g,
+             "neg_l": nclip_l, "neg_g": nclip_g,
+             "pooled": pooled, "neg_pooled": npooled}
+
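Editor's note: the `[1]`/`[0]` indexing reflects the two encoders' output ordering noted in the code comments: SDXL's `text_encoder` (CLIP-L) returns the sequence first, while `text_encoder_2` (CLIP-G, a projection model) returns the pooled vector first. A small sanity check one could run, assuming an initialized pipeline `pipe`:

```python
# CLIP-L:  output[0] -> (1, 77, 768) token sequence
# CLIP-G:  output[0] -> (1, 1280) pooled vector, output[1] -> (1, 77, 1280) sequence
emb = encode_prompt_sd_xl(pipe, "a photo of a cat", "")
assert emb["clip_l"].shape[-1] == 768
assert emb["clip_g"].shape[-1] == 1280
assert emb["pooled"].ndim == 2
```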
+
+ def adapter_forward(adapter, t5_seq, clip_seq, cfg):
+     with torch.no_grad():
+         out = adapter(t5_seq.float(), clip_seq.float())
+         # unify outputs: pad the tuple to length 8 so adapters that return
+         # fewer values still unpack; `*_` absorbs the two attention maps
+         anchor, delta, log_sigma, *_, tau, g_pred, gate = (
+             tuple(out) + (None,) * 8)[:8]
+         delta = delta * cfg["delta_scale"]
+         gate = torch.sigmoid(gate * g_pred * cfg["gpred_scale"]) * cfg["gate_prob"]
+         final_delta = delta * cfg["strength"] * gate
+         mod = clip_seq + final_delta.to(dtype)
+
+         if cfg["sigma_scale"] > 0:
+             sigma = torch.exp(log_sigma * cfg["sigma_scale"])
+             mod += torch.randn_like(mod) * sigma.to(dtype)
+         if cfg["use_anchor"]:
+             g = gate.to(dtype)
+             mod = mod * (1 - g) + anchor.to(dtype) * g
+         if cfg["noise"] > 0:
+             mod += torch.randn_like(mod) * cfg["noise"]
+     return mod, final_delta, gate, g_pred, tau
+
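Editor's note: `adapter_forward` collapses both per-encoder branches of the old `infer` into one gated update: the CLIP sequence is shifted by a scaled delta, where the gate is a sigmoid of the adapter's gate logits scaled by its confidence prediction. A toy illustration of the arithmetic on random tensors (shapes mimic CLIP-L; all values are made up):

```python
import torch

clip = torch.randn(1, 77, 768)                               # base CLIP sequence
delta = torch.randn(1, 77, 768) * 0.2                        # Δ scale
gate = torch.sigmoid(torch.randn(1, 77, 1) * 2.0) * 0.27     # g_pred scale, gate prob
mod = clip + 4.0 * gate * delta                              # strength = 4.0

# anchor mixing pulls the result toward a learned anchor where the gate is high
anchor = torch.randn(1, 77, 768)
mod = mod * (1 - gate) + anchor * gate
```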
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # MAIN INFERENCE
+ # ───────────────────────────────────────────────────────────────────────────
+ def infer(prompt, negative_prompt,
+           adapter_l_file, adapter_g_file,
+           strength, delta_scale, sigma_scale,
+           gpred_scale, noise, gate_prob, use_anchor,
+           steps, cfg_scale, scheduler_name,
+           width, height, seed):
+
+     torch.cuda.empty_cache()
+     _init_t5(); _init_pipe()
+
+     # scheduler
+     if scheduler_name in SCHEDULERS:
+         _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)
+
+     # RNG
+     generator = None
      if seed != -1:
          generator = torch.Generator(device=device).manual_seed(seed)
+         torch.manual_seed(seed)
+         np.random.seed(seed)
+
+     # T5 embeddings (semantic guidance)
+     t5_ids = _t5_tok(prompt, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)
+     with torch.no_grad():
+         t5_seq = _t5_mod(t5_ids).last_hidden_state  # (1, 77, 768)
+
+     # CLIP embeddings from SDXL
+     embeds = encode_prompt_sd_xl(_pipe, prompt, negative_prompt)
+
+     # ------------------------------------------------------------------
+     # LOAD adapters (if any)
+     cfg_common = dict(
+         strength=strength, delta_scale=delta_scale, sigma_scale=sigma_scale,
+         gpred_scale=gpred_scale, noise=noise, gate_prob=gate_prob,
+         use_anchor=use_anchor,
+     )
+
+     # CLIP-L
+     if adapter_l_file and adapter_l_file != "None":
+         cfg_l = conf_l.copy()
+         cfg_l["heads"] = 4 if "booru" in adapter_l_file else 12
+         adapter_l = load_adapter(repo_l, adapter_l_file, cfg_l)
+         clip_l_mod, delta_l, gate_l, g_pred_l, tau_l = adapter_forward(
+             adapter_l, t5_seq, embeds["clip_l"], cfg_common)
      else:
+         clip_l_mod = embeds["clip_l"]
+         delta_l = torch.zeros_like(clip_l_mod)
+         gate_l = torch.zeros_like(clip_l_mod[..., :1])
+         g_pred_l = tau_l = torch.tensor(0.)
+
+     # CLIP-G
+     if adapter_g_file and adapter_g_file != "None":
+         adapter_g = load_adapter(repo_g, adapter_g_file, conf_g)
+         clip_g_mod, delta_g, gate_g, g_pred_g, tau_g = adapter_forward(
+             adapter_g, t5_seq, embeds["clip_g"], cfg_common)
      else:
+         clip_g_mod = embeds["clip_g"]
+         delta_g = torch.zeros_like(clip_g_mod)
+         gate_g = torch.zeros_like(clip_g_mod[..., :1])
+         g_pred_g = tau_g = torch.tensor(0.)
+
+     # concatenate for SDXL: [CLIP-L(768) | CLIP-G(1280)] → 2048
      prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
+     neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)
+
+     # SDXL generation
+     image = _pipe(
+         prompt_embeds=prompt_embeds,
+         negative_prompt_embeds=neg_embeds,
+         pooled_prompt_embeds=embeds["pooled"],
+         negative_pooled_prompt_embeds=embeds["neg_pooled"],
+         num_inference_steps=steps, guidance_scale=cfg_scale,
+         width=width, height=height, generator=generator
      ).images[0]
+
+     # viz
+     delta_l_img = plot_heat(delta_l.squeeze(), "Δ CLIP-L")
+     gate_l_img = plot_heat(gate_l.squeeze().mean(-1, keepdim=True), "Gate L")
+     delta_g_img = plot_heat(delta_g.squeeze(), "Δ CLIP-G")
+     gate_g_img = plot_heat(gate_g.squeeze().mean(-1, keepdim=True), "Gate G")
+
+     stats_l = f"g_pred_L={g_pred_l.mean().item():.3f} | τ_L={tau_l.mean().item():.3f}"
+     stats_g = f"g_pred_G={g_pred_g.mean().item():.3f} | τ_G={tau_g.mean().item():.3f}"
+     return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
+
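Editor's note: since `infer` is now a plain function, it can be smoke-tested without the UI. A hypothetical direct call (argument values mirror the slider defaults; the adapter filename is one of the dropdown choices):

```python
if torch.cuda.is_available():
    image, *diagnostics = infer(
        prompt="a futuristic control station with holographic displays",
        negative_prompt="blurry, low quality, distorted",
        adapter_l_file="t5-vit-l-14-dual_shunt_caption.safetensors",
        adapter_g_file=None,  # skip the CLIP-G adapter
        strength=4.0, delta_scale=0.2, sigma_scale=0.1,
        gpred_scale=2.0, noise=0.55, gate_prob=0.27, use_anchor=True,
        steps=20, cfg_scale=7.5, scheduler_name="DPM++ 2M",
        width=1024, height=1024, seed=42,
    )
    image.save("smoke_test.png")
```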
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # GRADIO UI
+ # ───────────────────────────────────────────────────────────────────────────
  def create_interface():
+     with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")
+
          with gr.Row():
              with gr.Column(scale=1):
+                 gr.Markdown("### Prompts")
+                 prompt = gr.Textbox(label="Prompt", lines=3,
+                                     value="a futuristic control station with holographic displays")
+                 negative_prompt = gr.Textbox(label="Negative", lines=2,
+                                              value="blurry, low quality, distorted")
+
+                 gr.Markdown("### Adapters")
+                 adapter_l = gr.Dropdown(["None"] + clip_l_opts, value="t5-vit-l-14-dual_shunt_caption.safetensors",
+                                         label="CLIP-L Adapter")
+                 adapter_g = gr.Dropdown(["None"] + clip_g_opts, value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
+                                         label="CLIP-G Adapter")
+
+                 gr.Markdown("### Adapter Controls")
+                 strength = gr.Slider(0, 10, value=4.0, step=0.01, label="Strength")
+                 delta_scale = gr.Slider(-15, 15, value=0.2, step=0.1, label="Δ scale")
+                 sigma_scale = gr.Slider(0, 15, value=0.1, step=0.1, label="σ scale")
+                 gpred_scale = gr.Slider(0, 20, value=2.0, step=0.01, label="g_pred scale")
+                 noise = gr.Slider(0, 1, value=0.55, step=0.01, label="Extra noise")
+                 gate_prob = gr.Slider(0, 1, value=0.27, step=0.01, label="Gate prob")
+                 use_anchor = gr.Checkbox(True, label="Use anchor mix")
+
+                 gr.Markdown("### Generation")
                  with gr.Row():
+                     steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
+                     cfg_scale = gr.Slider(1, 15, value=7.5, step=0.1, label="CFG")
+                 scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler")
                  with gr.Row():
+                     width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
+                     height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
+                 seed = gr.Number(-1, label="Seed (-1=random)", precision=0)
+
+                 go_btn = gr.Button("🚀 Generate", variant="primary")
+
              with gr.Column(scale=1):
+                 out_img = gr.Image(label="Result", height=400)
+                 gr.Markdown("### Adapter Diagnostics")
+                 delta_l_i = gr.Image(label="Δ L", height=180)
+                 gate_l_i = gr.Image(label="Gate L", height=180)
+                 delta_g_i = gr.Image(label="Δ G", height=180)
+                 gate_g_i = gr.Image(label="Gate G", height=180)
+                 stats_l = gr.Textbox(label="Stats L", interactive=False)
+                 stats_g = gr.Textbox(label="Stats G", interactive=False)
+
+         def _run(*args):
+             pl, npl = args[0], args[1]
+             al, ag = (None if v == "None" else v for v in args[2:4])
+             return infer(pl, npl, al, ag, *args[4:])
+
+         go_btn.click(
+             _run,
+             inputs=[prompt, negative_prompt, adapter_l, adapter_g,
+                     strength, delta_scale, sigma_scale, gpred_scale,
+                     noise, gate_prob, use_anchor, steps, cfg_scale,
+                     scheduler, width, height, seed],
+             outputs=[out_img, delta_l_i, gate_l_i, delta_g_i, gate_g_i,
+                      stats_l, stats_g]
          )

      return demo

+
+ # ───────────────────────────────────────────────────────────────────────────
  if __name__ == "__main__":
+     create_interface().launch()