Spaces (status: Runtime error)
Update app.py
app.py CHANGED
@@ -1,432 +1,307 @@
Before (removed; "…" marks lines truncated in the diff view):

-import …
 import gradio as gr
 import numpy as np
 import matplotlib.pyplot as plt
 from PIL import Image
 from transformers import T5Tokenizer, T5EncoderModel
-from diffusers import …
 from huggingface_hub import hf_hub_download
 from two_stream_shunt_adapter import TwoStreamShuntAdapter
 from configs import T5_SHUNT_REPOS
-import …
-# …
-
-# Available schedulers
 SCHEDULERS = {
-    "DPM++ 2M": …,
-    "DDIM": …,
-    "Euler": …,
 }
-
-# …
 clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
 clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
-repo_l …
-repo_g …
-…
     path = hf_hub_download(repo_id=repo, filename=filename)
-    tensors = {}
-    with safe_open(path, framework="pt", device="cpu") as f:
-        for key in f.keys():
-            tensors[key] = f.get_tensor(key)
     model.load_state_dict(tensors)
     return model.to(device)
-…
-    # Handle different input shapes
     if isinstance(mat, torch.Tensor):
         mat = mat.detach().cpu().numpy()
-    …
-    mat = mat.…
-        mat = mat.mean(axis=0)
-    elif len(mat.shape) > 3:
-        # Flatten higher dimensions
-        mat = mat.reshape(-1, mat.shape[-1])
-
-    # Create figure with proper DPI
-    plt.figure(figsize=(8, 4), dpi=100)
-    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper", interpolation='nearest')
-    plt.title(title, fontsize=12, fontweight='bold')
-    plt.xlabel("Token Position")
-    plt.ylabel("Feature Dimension")
-    plt.colorbar(shrink=0.8)
     plt.tight_layout()
-
-    # Convert to PIL Image
     buf = io.BytesIO()
-    plt.savefig(buf, format="png"…
-    buf.seek(0)
-    pil_image = Image.open(buf)
     plt.close()
-
-def …
-    """…
-    tokens_g = pipe.tokenizer_2(
-        prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
-    ).input_ids.to(device)
-    neg_tokens_l = pipe.tokenizer(
-        negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
-    ).input_ids.to(device)
-    neg_tokens_g = pipe.tokenizer_2(
-        negative_prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
-    ).input_ids.to(device)
     with torch.no_grad():
-        …
     if seed != -1:
-        torch.manual_seed(seed)
-        np.random.seed(seed)
         generator = torch.Generator(device=device).manual_seed(seed)
     else:
-        …
-    ).input_ids.to(device)
-    with torch.no_grad():
-        t5_seq = t5_mod(t5_ids).last_hidden_state
-
-    # Get CLIP embeddings
-    clip_embeds = encode_sdxl_prompt(pipe, prompt, negative_prompt, device)
-
-    # Load and apply adapters
-    if(adapter_l_file == "t5-vit-l-14-dual_shunt_booru_13_000_000.safetensors" or adapter_l_file == "t5-vit-l-14-dual_shunt_booru_51_200_000.safetensors"):
-        config_l["heads"] = 4
-    else:
-        config_l["heads"] = 12
-    adapter_l = load_adapter(repo_l, adapter_l_file, config_l, device) if adapter_l_file else None
-    adapter_g = load_adapter(repo_g, adapter_g_file, config_g, device) if adapter_g_file else None
-
-    # Apply CLIP-L adapter
-    if adapter_l is not None:
-        with torch.no_grad():
-            # Run adapter forward pass
-            adapter_output = adapter_l(t5_seq.float(), clip_embeds["clip_l"].float())
-
-            # Unpack outputs (ensure correct number of outputs)
-            if len(adapter_output) == 8:
-                anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_output
-            else:
-                # Handle different return formats
-                anchor_l = adapter_output[0]
-                delta_l = adapter_output[1]
-                log_sigma_l = adapter_output[2] if len(adapter_output) > 2 else torch.zeros_like(delta_l)
-                gate_l = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_l)
-                tau_l = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
-                g_pred_l = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)
-
-            # Scale delta values
-            delta_l = delta_l * delta_scale
-
-            # Apply g_pred scaling to gate
-            gate_l = gate_l * g_pred_l * gpred_scale
-
-            # Apply gate scaling
-            gate_l_scaled = torch.sigmoid(gate_l) * gate_prob
-
-            # Compute final delta with strength and gate
-            delta_l_final = delta_l * strength * gate_l_scaled
-
-            # Apply delta to embeddings
-            clip_l_mod = clip_embeds["clip_l"] + delta_l_final.to(dtype)
-
-            # Apply sigma-based noise if specified
-            if sigma_scale > 0:
-                sigma_l = torch.exp(log_sigma_l * sigma_scale)
-                clip_l_mod += torch.randn_like(clip_l_mod) * sigma_l.to(dtype)
-
-            # Apply anchor mixing if enabled
-            if use_anchor:
-                clip_l_mod = clip_l_mod * (1 - gate_l_scaled.to(dtype)) + anchor_l.to(dtype) * gate_l_scaled.to(dtype)
-
-            # Add additional noise if specified
-            if noise > 0:
-                clip_l_mod += torch.randn_like(clip_l_mod) * noise
-    else:
-        clip_l_mod = clip_embeds["clip_l"]
-        delta_l_final = torch.zeros_like(clip_embeds["clip_l"])
-        gate_l_scaled = torch.zeros_like(clip_embeds["clip_l"])
-        g_pred_l = torch.tensor(0.0)
-        tau_l = torch.tensor(0.0)
-
-    # Apply CLIP-G adapter
-    if adapter_g is not None:
-        with torch.no_grad():
-            # Run adapter forward pass
-            adapter_output = adapter_g(t5_seq.float(), clip_embeds["clip_g"].float())
-
-            # Unpack outputs (ensure correct number of outputs)
-            if len(adapter_output) == 8:
-                anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_output
-            else:
-                # Handle different return formats
-                anchor_g = adapter_output[0]
-                delta_g = adapter_output[1]
-                log_sigma_g = adapter_output[2] if len(adapter_output) > 2 else torch.zeros_like(delta_g)
-                gate_g = adapter_output[-1] if len(adapter_output) > 2 else torch.ones_like(delta_g)
-                tau_g = adapter_output[-2] if len(adapter_output) > 6 else torch.tensor(1.0)
-                g_pred_g = adapter_output[-3] if len(adapter_output) > 6 else torch.tensor(1.0)
-
-            # Scale delta values
-            delta_g = delta_g * delta_scale
-
-            # Apply g_pred scaling to gate
-            gate_g = gate_g * g_pred_g * gpred_scale
-
-            # Apply gate scaling
-            gate_g_scaled = torch.sigmoid(gate_g) * gate_prob
-
-            # Compute final delta with strength and gate
-            delta_g_final = delta_g * strength * gate_g_scaled
-
-            # Apply delta to embeddings
-            clip_g_mod = clip_embeds["clip_g"] + delta_g_final.to(dtype)
-
-            # Apply sigma-based noise if specified
-            if sigma_scale > 0:
-                sigma_g = torch.exp(log_sigma_g * sigma_scale)
-                clip_g_mod += torch.randn_like(clip_g_mod) * sigma_g.to(dtype)
-
-            # Apply anchor mixing if enabled
-            if use_anchor:
-                clip_g_mod = clip_g_mod * (1 - gate_g_scaled.to(dtype)) + anchor_g.to(dtype) * gate_g_scaled.to(dtype)
-
-            # Add additional noise if specified
-            if noise > 0:
-                clip_g_mod += torch.randn_like(clip_g_mod) * noise
     else:
-        clip_g_mod = …
-        …
-        tau_g = torch.tensor(0.0)
-
-    # Combine embeddings for SDXL: [CLIP-L(768) + CLIP-G(1280)] = 2048
     prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
-    neg_embeds …
-
-    # …
-    image = …
-        prompt_embeds=prompt_embeds,
-        …
-        negative_pooled_prompt_embeds=…
-        num_inference_steps=steps,
-        …
-        width=width,
-        height=height,
-        num_images_per_prompt=1,
-        generator=generator
     ).images[0]
-
-    # …
-    …
-
 def create_interface():
-    with gr.Blocks(title="SDXL Dual …
-        gr.Markdown("# 🧠 SDXL Dual …
         with gr.Row():
             with gr.Column(scale=1):
-                …
-                gr.…
-                adapter_g = gr.Dropdown(
-                    choices=["None"] + clip_g_opts,
-                    label="CLIP-G (1280d) Adapter",
-                    value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
-                    info="Choose adapter for CLIP-G embeddings"
-                )
-
-                # Controls
-                gr.Markdown("### 🎛️ Adapter Controls")
-                strength = gr.Slider(0.0, 10.0, value=4.0, step=0.01, label="Adapter Strength")
-                delta_scale = gr.Slider(-15.0, 15.0, value=0.2, step=0.1, label="Delta Scale", info="Scales the delta values, recommended 1")
-                sigma_scale = gr.Slider(0, 15.0, value=0.1, step=0.1, label="Sigma Scale", info="Scales the noise variance, recommended 1")
-                gpred_scale = gr.Slider(0.0, 20.0, value=2.0, step=0.01, label="G-Pred Scale", info="Scales the gate prediction, recommended 2")
-                noise = gr.Slider(0.0, 1.0, value=0.55, step=0.01, label="Noise Injection")
-                gate_prob = gr.Slider(0.0, 1.0, value=0.27, step=0.01, label="Gate Probability")
-                use_anchor = gr.Checkbox(label="Use Anchor Points", value=True)
-
-                # Generation Settings
-                gr.Markdown("### 🎨 Generation Settings")
                 with gr.Row():
-                    steps …
-                    cfg_scale …
-                scheduler_name = gr.Dropdown(
-                    choices=list(SCHEDULERS.keys()),
-                    value="DPM++ 2M",
-                    label="Scheduler"
-                )
                 with gr.Row():
-                    width …
-                    height …
             with gr.Column(scale=1):
-                gr.Markdown("### …
-                gr.…
-                …
-                processed_args[2] = None if args[2] == "None" else args[2]  # adapter_l
-                processed_args[3] = None if args[3] == "None" else args[3]  # adapter_g
-                return infer(*processed_args)
-
-        generate_btn.click(
-            fn=run_generation,
-            inputs=[
-                prompt, negative_prompt, adapter_l, adapter_g, strength, delta_scale,
-                sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps, cfg_scale,
-                scheduler_name, width, height, seed
-            ],
-            outputs=[output_image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l_text, stats_g_text]
         )
-
     return demo

 if __name__ == "__main__":
-    demo.launch()
Added (new version):

+# app.py ────────────────────────────────────────────────────────────────────
+import io, os, json, math, random, warnings, gc, functools, hashlib
+from pathlib import Path
+from typing import Dict, List, Optional
+
 import gradio as gr
 import numpy as np
 import matplotlib.pyplot as plt
 from PIL import Image
+
+import torch
+import torch.nn.functional as F
 from transformers import T5Tokenizer, T5EncoderModel
+from diffusers import (
+    StableDiffusionXLPipeline,
+    DDIMScheduler,
+    EulerDiscreteScheduler,
+    DPMSolverMultistepScheduler,
+)
 from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
+# -------------------------------------------------------------------------
+# local modules
 from two_stream_shunt_adapter import TwoStreamShuntAdapter
 from configs import T5_SHUNT_REPOS
+from embedding_manager import get_bank  # ← NEW
+
+warnings.filterwarnings("ignore")
+
+# ───────────────────────────────────────────────────────────────────────────
+# GLOBALS
+# ───────────────────────────────────────────────────────────────────────────
+dtype = torch.float16
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+bank = get_bank()  # shared singleton
+
+_t5_tok: Optional[T5Tokenizer] = None
+_t5_mod: Optional[T5EncoderModel] = None
+_pipe: Optional[StableDiffusionXLPipeline] = None

 SCHEDULERS = {
+    "DPM++ 2M": DPMSolverMultistepScheduler,
+    "DDIM": DDIMScheduler,
+    "Euler": EulerDiscreteScheduler,
 }
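Note: the name-to-class map lets the UI swap samplers without reloading the pipeline. A minimal sketch of that swap, assuming the SCHEDULERS dict above; `pipe` is illustrative, and infer() below applies the same from_config pattern to the global _pipe:

```python
# Sketch: hot-swap the scheduler on a loaded diffusers pipeline.
def set_scheduler(pipe, name: str) -> None:
    cls = SCHEDULERS.get(name)
    if cls is not None:
        # from_config reuses the current scheduler's settings (betas, timesteps, ...)
        pipe.scheduler = cls.from_config(pipe.scheduler.config)
```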
+# easy access to adapter repo metadata
 clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
 clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
+repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
+repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
+conf_l = T5_SHUNT_REPOS["clip_l"]["config"]
+conf_g = T5_SHUNT_REPOS["clip_g"]["config"]
+
+
+# ───────────────────────────────────────────────────────────────────────────
+# HELPERS
+# ───────────────────────────────────────────────────────────────────────────
+def _init_t5():
+    global _t5_tok, _t5_mod
+    if _t5_tok is None:
+        _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
+        _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
+
+
+def _init_pipe():
+    global _pipe
+    if _pipe is None:
+        _pipe = StableDiffusionXLPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            torch_dtype=dtype,
+            use_safetensors=True,
+            variant="fp16",
+        ).to(device)
+        _pipe.enable_xformers_memory_efficient_attention()
+
+
+def load_adapter(repo: str, filename: str, cfg: dict):
+    """Load a TwoStreamShuntAdapter from HF Hub safetensors."""
     path = hf_hub_download(repo_id=repo, filename=filename)
+    model = TwoStreamShuntAdapter(cfg).eval()
+    tensors = load_file(path)
     model.load_state_dict(tensors)
     return model.to(device)
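Note: the previous version read tensors key by key with `safe_open`; `load_file` collapses that loop into one call. A quick equivalence sketch, where `path` is any local safetensors file (the filename is illustrative):

```python
from safetensors import safe_open
from safetensors.torch import load_file

path = "adapter.safetensors"  # illustrative local file

tensors = load_file(path)  # new: one call, returns {name: tensor}

manual = {}
with safe_open(path, framework="pt", device="cpu") as f:  # old approach
    for key in f.keys():
        manual[key] = f.get_tensor(key)

assert set(tensors) == set(manual)
```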
+
+
+def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
+    """Render `mat` as a heat-map PNG and return it as a pixel array."""
     if isinstance(mat, torch.Tensor):
         mat = mat.detach().cpu().numpy()
+
+    if mat.ndim == 1:
+        mat = mat[None, :]
+    elif mat.ndim >= 3:  # (B, T, D) → mean over B
+        mat = mat.mean(axis=0)
+
+    plt.figure(figsize=(8, 4), dpi=120)
+    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
+    plt.title(title)
+    plt.colorbar(shrink=0.7)
     plt.tight_layout()
+
     buf = io.BytesIO()
+    plt.savefig(buf, format="png")
     plt.close()
+    buf.seek(0)
+    return np.array(Image.open(buf))
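Note: plot_heat renders to an in-memory PNG and returns a NumPy array, which gr.Image accepts directly. On a headless Space the non-interactive Agg backend avoids display errors; a hedged usage sketch (the smoke-test values are made up):

```python
import matplotlib
matplotlib.use("Agg")  # must run before matplotlib.pyplot is first imported

import torch
heat = plot_heat(torch.randn(77, 768), "Δ CLIP-L (smoke test)")
print(heat.shape)  # e.g. (480, 960, 4): RGBA pixels of an 8x4 in, 120 dpi figure
```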
+
+
+def encode_prompt_sd_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
+    """Return CLIP-L, CLIP-G (and negative) embeddings from SDXL pipeline."""
+    tok_l = pipe.tokenizer(prompt, max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
+    tok_g = pipe.tokenizer_2(prompt, max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
+    ntok_l = pipe.tokenizer(negative, max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
+    ntok_g = pipe.tokenizer_2(negative, max_length=77, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
+
     with torch.no_grad():
+        clip_l = pipe.text_encoder(tok_l)[0]  # (1, 77, 768)
+        nclip_l = pipe.text_encoder(ntok_l)[0]
+        out_g = pipe.text_encoder_2(tok_g, output_hidden_states=False)
+        clip_g, pooled = out_g[1], out_g[0]  # hidden states, pooled text_embeds
+        nout_g = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
+        nclip_g, npooled = nout_g[1], nout_g[0]
+
+    return {"clip_l": clip_l, "clip_g": clip_g,
+            "neg_l": nclip_l, "neg_g": nclip_g,
+            "pooled": pooled, "neg_pooled": npooled}
+
+
+def adapter_forward(adapter, t5_seq, clip_seq, cfg):
+    """Run one shunt adapter and modulate `clip_seq`; replaces the old per-branch code."""
+    with torch.no_grad():
+        out = adapter(t5_seq.float(), clip_seq.float())
+
+    # unify outputs: pad to length 8 so shorter return tuples still unpack
+    # (assumes the adapter returns a tuple ordered like the old 8-element format)
+    anchor, delta, log_sigma, *_, tau, g_pred, gate = (out + (None,) * 8)[:8]
+    delta = delta * cfg["delta_scale"]
+    gate = torch.sigmoid(gate * g_pred * cfg["gpred_scale"]) * cfg["gate_prob"]
+    final_delta = delta * cfg["strength"] * gate
+    mod = clip_seq + final_delta.to(dtype)
+
+    if cfg["sigma_scale"] > 0:
+        sigma = torch.exp(log_sigma * cfg["sigma_scale"])
+        mod += torch.randn_like(mod) * sigma.to(dtype)
+    if cfg["use_anchor"]:
+        g = gate.to(dtype)  # keep the mix in the pipeline dtype (fp16)
+        mod = mod * (1 - g) + anchor.to(dtype) * g
+    if cfg["noise"] > 0:
+        mod += torch.randn_like(mod) * cfg["noise"]
+    return mod, final_delta, gate, g_pred, tau
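Note: the gate arithmetic compresses the old per-branch code; the raw gate is scaled by the predicted gain and gpred_scale, squashed with a sigmoid, then capped by gate_prob. A worked example with the UI defaults (gpred_scale 2.0, gate_prob 0.27, strength 4.0, delta_scale 0.2); the input values are made up for illustration:

```python
import torch

gate_raw, g_pred = torch.tensor(0.5), torch.tensor(1.0)
gate = torch.sigmoid(gate_raw * g_pred * 2.0) * 0.27  # sigmoid(1.0)*0.27 ≈ 0.197
delta = torch.tensor(0.1) * 0.2                       # delta_scale = 0.2
final = delta * 4.0 * gate                            # strength = 4.0
print(float(gate), float(final))                      # ~0.197, ~0.0158
```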
+
+
+# ───────────────────────────────────────────────────────────────────────────
+# MAIN INFERENCE
+# ───────────────────────────────────────────────────────────────────────────
+def infer(prompt, negative_prompt,
+          adapter_l_file, adapter_g_file,
+          strength, delta_scale, sigma_scale,
+          gpred_scale, noise, gate_prob, use_anchor,
+          steps, cfg_scale, scheduler_name,
+          width, height, seed):
+
+    torch.cuda.empty_cache()
+    _init_t5(); _init_pipe()
+
+    # scheduler
+    if scheduler_name in SCHEDULERS:
+        _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)
+
+    # RNG
+    generator = None
     if seed != -1:
         generator = torch.Generator(device=device).manual_seed(seed)
+        torch.manual_seed(seed); np.random.seed(seed)
+
+    # T5 embeddings (semantic guidance)
+    t5_ids = _t5_tok(prompt, max_length=77, truncation=True, padding="max_length", return_tensors="pt").input_ids.to(device)
+    with torch.no_grad():
+        t5_seq = _t5_mod(t5_ids).last_hidden_state  # (1, 77, 768)
+
+    # CLIP embeddings from SDXL
+    embeds = encode_prompt_sd_xl(_pipe, prompt, negative_prompt)
+
+    # ------------------------------------------------------------------
+    # LOAD adapters (if any)
+    cfg_common = dict(
+        strength=strength, delta_scale=delta_scale, sigma_scale=sigma_scale,
+        gpred_scale=gpred_scale, noise=noise, gate_prob=gate_prob,
+        use_anchor=use_anchor,
+    )
+
+    # CLIP-L
+    if adapter_l_file and adapter_l_file != "None":
+        cfg_l = conf_l.copy()
+        if "booru" in adapter_l_file:
+            cfg_l["heads"] = 4  # booru checkpoints were trained with 4 heads
+        adapter_l = load_adapter(repo_l, adapter_l_file, cfg_l)
+        clip_l_mod, delta_l, gate_l, g_pred_l, tau_l = adapter_forward(
+            adapter_l, t5_seq, embeds["clip_l"], cfg_common)
     else:
+        clip_l_mod = embeds["clip_l"]; delta_l = torch.zeros_like(clip_l_mod)
+        gate_l = torch.zeros_like(clip_l_mod[..., :1]); g_pred_l = tau_l = torch.tensor(0.)
+
+    # CLIP-G
+    if adapter_g_file and adapter_g_file != "None":
+        cfg_g = conf_g.copy()
+        adapter_g = load_adapter(repo_g, adapter_g_file, cfg_g)
+        clip_g_mod, delta_g, gate_g, g_pred_g, tau_g = adapter_forward(
+            adapter_g, t5_seq, embeds["clip_g"], cfg_common)
     else:
+        clip_g_mod = embeds["clip_g"]; delta_g = torch.zeros_like(clip_g_mod)
+        gate_g = torch.zeros_like(clip_g_mod[..., :1]); g_pred_g = tau_g = torch.tensor(0.)
+
+    # concatenate for SDXL
     prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
+    neg_embeds = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)
+
+    # SDXL generation
+    image = _pipe(
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=neg_embeds,
+        pooled_prompt_embeds=embeds["pooled"],
+        negative_pooled_prompt_embeds=embeds["neg_pooled"],
+        num_inference_steps=steps, guidance_scale=cfg_scale,
+        width=width, height=height, generator=generator,
     ).images[0]
+
+    # viz
+    delta_l_img = plot_heat(delta_l.squeeze(), "Δ CLIP-L")
+    gate_l_img = plot_heat(gate_l.squeeze().mean(-1, keepdim=True), "Gate L")
+    delta_g_img = plot_heat(delta_g.squeeze(), "Δ CLIP-G")
+    gate_g_img = plot_heat(gate_g.squeeze().mean(-1, keepdim=True), "Gate G")
+
+    stats_l = f"g_pred_L={g_pred_l.item():.3f} | τ_L={tau_l.item():.3f}"
+    stats_g = f"g_pred_G={g_pred_g.item():.3f} | τ_G={tau_g.item():.3f}"
+    return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
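Note: the concatenation follows SDXL's expected layout, 768 (CLIP-L) plus 1280 (CLIP-G) for 2048 per token, alongside the 1280-d pooled vector. A self-contained dimension check:

```python
import torch

clip_l = torch.zeros(1, 77, 768)    # CLIP-L token embeddings
clip_g = torch.zeros(1, 77, 1280)   # CLIP-G token embeddings
assert torch.cat([clip_l, clip_g], dim=-1).shape == (1, 77, 2048)
```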
+
+
+# ───────────────────────────────────────────────────────────────────────────
+# GRADIO UI
+# ───────────────────────────────────────────────────────────────────────────
 def create_interface():
+    with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
+        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")

         with gr.Row():
             with gr.Column(scale=1):
+                gr.Markdown("### Prompts")
+                prompt = gr.Textbox(label="Prompt", lines=3,
+                                    value="a futuristic control station with holographic displays")
+                negative_prompt = gr.Textbox(label="Negative", lines=2,
+                                             value="blurry, low quality, distorted")
+
+                gr.Markdown("### Adapters")
+                adapter_l = gr.Dropdown(["None"] + clip_l_opts, value="t5-vit-l-14-dual_shunt_caption.safetensors",
+                                        label="CLIP-L Adapter")
+                adapter_g = gr.Dropdown(["None"] + clip_g_opts, value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
+                                        label="CLIP-G Adapter")
+
+                gr.Markdown("### Adapter Controls")
+                strength = gr.Slider(0, 10, 4.0, 0.01, label="Strength")
+                delta_scale = gr.Slider(-15, 15, 0.2, 0.1, label="Δ scale")
+                sigma_scale = gr.Slider(0, 15, 0.1, 0.1, label="σ scale")
+                gpred_scale = gr.Slider(0, 20, 2.0, 0.01, label="g_pred scale")
+                noise = gr.Slider(0, 1, 0.55, 0.01, label="Extra noise")
+                gate_prob = gr.Slider(0, 1, 0.27, 0.01, label="Gate prob")
+                use_anchor = gr.Checkbox(True, label="Use anchor mix")
+
+                gr.Markdown("### Generation")
                 with gr.Row():
+                    steps = gr.Slider(1, 50, 20, 1, label="Steps")
+                    cfg_scale = gr.Slider(1, 15, 7.5, 0.1, label="CFG")
+                scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="DPM++ 2M", label="Scheduler")
                 with gr.Row():
+                    width = gr.Slider(512, 1536, 1024, 64, label="Width")
+                    height = gr.Slider(512, 1536, 1024, 64, label="Height")
+                seed = gr.Number(-1, label="Seed (-1 = random)")
+
+                go_btn = gr.Button("🚀 Generate", variant="primary")

             with gr.Column(scale=1):
+                out_img = gr.Image(label="Result", height=400)
+                gr.Markdown("### Adapter Diagnostics")
+                delta_l_i = gr.Image(label="Δ L", height=180)
+                gate_l_i = gr.Image(label="Gate L", height=180)
+                delta_g_i = gr.Image(label="Δ G", height=180)
+                gate_g_i = gr.Image(label="Gate G", height=180)
+                stats_l = gr.Textbox(label="Stats L", interactive=False)
+                stats_g = gr.Textbox(label="Stats G", interactive=False)
+
+        def _run(*args):
+            pl, npl = args[0], args[1]
+            al, ag = (None if v == "None" else v for v in args[2:4])
+            return infer(pl, npl, al, ag, *args[4:])
+
+        go_btn.click(
+            _run,
+            inputs=[prompt, negative_prompt, adapter_l, adapter_g,
+                    strength, delta_scale, sigma_scale, gpred_scale,
+                    noise, gate_prob, use_anchor, steps, cfg_scale,
+                    scheduler, width, height, seed],
+            outputs=[out_img, delta_l_i, gate_l_i, delta_g_i, gate_g_i,
+                     stats_l, stats_g]
         )
     return demo
|
303 |
|
304 |
+
|
305 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
306 |
if __name__ == "__main__":
|
307 |
+
create_interface().launch()
|
|