Spaces: Runtime error
Update app.py
app.py (CHANGED)
Removed in this update (only these fragments of the old revision are recoverable from the diff): the module-level cfg_common dict, the encode_prompt_sd_xl helper, the manual torch.manual_seed(seed); np.random.seed(seed) seeding, the per-adapter g_pred/τ stat strings, and the inline shunt math that post-processed each adapter's output:

    gate = torch.sigmoid(gate * g_pred * cfg["gpred_scale"]) * cfg["gate_prob"]
    final_delta = delta * cfg["strength"] * gate
    mod = clip_seq + final_delta.to(dtype)

    if cfg["sigma_scale"] > 0:
        sigma = torch.exp(log_sigma * cfg["sigma_scale"])
        mod += torch.randn_like(mod) * sigma.to(dtype)
    if cfg["use_anchor"]:
        mod = mod * (1 - gate) + anchor.to(dtype) * gate
    if cfg["noise"] > 0:
        mod += torch.randn_like(mod) * cfg["noise"]
    return mod, final_delta, gate, g_pred, tau

All of this now lives behind the ConditioningShifter / ShiftConfig API imported from conditioning_shifter. The full updated file:
# app.py ──────────────────────────────────────────────────────────────────
import io, warnings, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch, torch.nn.functional as F
from PIL import Image
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import (
    StableDiffusionXLPipeline,
    DDIMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# local modules
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from conditioning_shifter import ConditioningShifter, ShiftConfig, AdapterOutput
from embedding_manager import get_bank
from configs import T5_SHUNT_REPOS

warnings.filterwarnings("ignore")

# ─── GLOBALS ──────────────────────────────────────────────────────────────
dtype  = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_bank = get_bank()  # shared singleton (optional caching)

_t5_tok: Optional[T5Tokenizer]               = None
_t5_mod: Optional[T5EncoderModel]            = None
_pipe  : Optional[StableDiffusionXLPipeline] = None

SCHEDULERS = {
    "DPM++ 2M": DPMSolverMultistepScheduler,
    "DDIM":     DDIMScheduler,
    "Euler":    EulerDiscreteScheduler,
}

# adapter metadata from configs.py
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l, conf_l = T5_SHUNT_REPOS["clip_l"]["repo"], T5_SHUNT_REPOS["clip_l"]["config"]
repo_g, conf_g = T5_SHUNT_REPOS["clip_g"]["repo"], T5_SHUNT_REPOS["clip_g"]["config"]
# ─── INITIALISERS ─────────────────────────────────────────────────────────
def _init_t5():
    global _t5_tok, _t5_mod
    if _t5_tok is None:
        _t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
        _t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base") \
                                .to(device).eval()


def _init_pipe():
    global _pipe  # context line elided by the diff; restored so assignment works
    if _pipe is None:
        _pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=dtype, variant="fp16", use_safetensors=True
        ).to(device)
        _pipe.enable_xformers_memory_efficient_attention()
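Both initialisers are lazy singletons: the first request pays the download and VRAM cost, later requests reuse the globals. A minimal smoke test outside Gradio (hypothetical file name; assumes the checkpoints download and fit in memory):

    # smoke_test.py (hypothetical) - exercises the lazy initialisers
    import app

    app._init_t5(); app._init_t5()   # second call is a no-op
    app._init_pipe()
    assert app._t5_tok is not None and app._t5_mod is not None
    assert app._pipe.scheduler is not None
    print("models ready on", app.device)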
# ─── HELPERS ──────────────────────────────────────────────────────────────
def load_adapter(repo: str, filename: str, cfg: dict,
                 device: torch.device) -> TwoStreamShuntAdapter:
    path  = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(cfg).eval()
    model.load_state_dict(load_file(path))
    return model.to(device)


def plot_heat(mat: torch.Tensor | np.ndarray, title: str) -> np.ndarray:
    if isinstance(mat, torch.Tensor):
        mat = mat.detach().cpu().numpy()
    if mat.ndim == 1:
        mat = mat[None, :]
    elif mat.ndim >= 3:
        mat = mat.mean(axis=0)

    plt.figure(figsize=(7, 3.3), dpi=110)
    plt.imshow(mat, aspect="auto", cmap="RdBu_r", origin="upper")
    plt.title(title, fontsize=10)
    plt.colorbar(shrink=0.7)
    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close(); buf.seek(0)
    return np.array(Image.open(buf))


def encode_prompt_xl(pipe, prompt: str, negative: str) -> Dict[str, torch.Tensor]:
    tok_l  = pipe.tokenizer  (prompt,   max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)
    tok_g  = pipe.tokenizer_2(prompt,   max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)
    ntok_l = pipe.tokenizer  (negative, max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)
    ntok_g = pipe.tokenizer_2(negative, max_length=77, truncation=True,
                              padding="max_length", return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        clip_l     = pipe.text_encoder(tok_l)[0]
        neg_clip_l = pipe.text_encoder(ntok_l)[0]

        g_out  = pipe.text_encoder_2(tok_g,  output_hidden_states=False)
        clip_g, pl      = g_out[1], g_out[0]
        ng_out = pipe.text_encoder_2(ntok_g, output_hidden_states=False)
        neg_clip_g, npl = ng_out[1], ng_out[0]

    return {"clip_l": clip_l, "clip_g": clip_g,
            "neg_l": neg_clip_l, "neg_g": neg_clip_g,
            "pooled": pl, "neg_pooled": npl}
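For orientation, the dictionary returned by encode_prompt_xl carries these tensor shapes on SDXL base (CLIP-L hidden size 768, CLIP-G hidden size 1280, 77-token padding). A sketch one could append to app.py temporarily, assuming the stock SDXL text encoders:

    # Hypothetical shape check (stock SDXL text encoders assumed)
    _init_pipe()
    e = encode_prompt_xl(_pipe, "a cat", "blurry")
    assert e["clip_l"].shape == (1, 77, 768)    # CLIP-L token sequence
    assert e["clip_g"].shape == (1, 77, 1280)   # CLIP-G token sequence
    assert e["pooled"].shape == (1, 1280)       # pooled CLIP-G embedding
    # torch.cat([clip_l, clip_g], dim=-1) later yields the (1, 77, 2048) SDXL
    # context, which is why the adapters slice (0, 768) and (768, 2048)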
# ─── INFERENCE ────────────────────────────────────────────────────────────
def infer(prompt: str, negative_prompt: str,
          adapter_l_file: str, adapter_g_file: str,
          strength: float, delta_scale: float, sigma_scale: float,
          gpred_scale: float, noise: float, gate_prob: float, use_anchor: bool,
          steps: int, cfg_scale: float, scheduler_name: str,
          width: int, height: int, seed: int):

    torch.cuda.empty_cache()
    _init_t5(); _init_pipe()

    if scheduler_name in SCHEDULERS:
        _pipe.scheduler = SCHEDULERS[scheduler_name].from_config(_pipe.scheduler.config)

    generator = (torch.Generator(device=device).manual_seed(seed)
                 if seed != -1 else None)

    # build ShiftConfig (one per request)
    cfg_shift = ShiftConfig(
        prompt         = prompt,
        seed           = seed,
        # the five fields below were elided as diff context; reconstructed
        # from infer()'s signature (names assumed to match the sliders)
        strength       = strength,
        delta_scale    = delta_scale,
        sigma_scale    = sigma_scale,
        noise          = noise,
        gate_prob      = gate_prob,
        use_anchor     = use_anchor,
        guidance_scale = gpred_scale,
    )

    # encoder (T5) embeddings
    t5_seq = ConditioningShifter.extract_encoder_embeddings(
        {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
        device, cfg_shift
    )

    # CLIP embeddings
    embeds = encode_prompt_xl(_pipe, prompt, negative_prompt)

    # run adapters ---------------------------------------------------------
    outputs: List[AdapterOutput] = []
    if adapter_l_file and adapter_l_file != "None":
        ada_l = load_adapter(repo_l, adapter_l_file, conf_l, device)
        outputs.append(ConditioningShifter.run_adapter(
            ada_l, t5_seq, embeds["clip_l"],
            cfg_shift.guidance_scale, "clip_l", (0, 768)))

    if adapter_g_file and adapter_g_file != "None":
        ada_g = load_adapter(repo_g, adapter_g_file, conf_g, device)
        outputs.append(ConditioningShifter.run_adapter(
            ada_g, t5_seq, embeds["clip_g"],
            cfg_shift.guidance_scale, "clip_g", (768, 2048)))

    # apply modifications ----------------------------------------------------
    clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
    delta_viz = {"clip_l": torch.zeros_like(clip_l_mod),
                 "clip_g": torch.zeros_like(clip_g_mod)}
    gate_viz  = {"clip_l": torch.zeros_like(clip_l_mod[..., :1]),
                 "clip_g": torch.zeros_like(clip_g_mod[..., :1])}

    for out in outputs:
        target = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
        mod    = ConditioningShifter.apply_modifications(target, [out], cfg_shift)
        if out.adapter_type == "clip_l":
            clip_l_mod = mod
        else:
            clip_g_mod = mod
        delta_viz[out.adapter_type] = out.delta.detach()
        gate_viz [out.adapter_type] = out.gate.detach()

    # prepare for SDXL -------------------------------------------------------
    prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
    neg_embeds    = torch.cat([embeds["neg_l"], embeds["neg_g"]], dim=-1)

    image = _pipe(
        prompt_embeds                 = prompt_embeds,
        negative_prompt_embeds        = neg_embeds,
        pooled_prompt_embeds          = embeds["pooled"],
        negative_pooled_prompt_embeds = embeds["neg_pooled"],
        num_inference_steps           = steps,
        guidance_scale                = cfg_scale,
        width = width, height = height, generator = generator
    ).images[0]

    # diagnostics ------------------------------------------------------------
    # (torch uses keepdim, not numpy's keepdims)
    delta_l_img = plot_heat(delta_viz["clip_l"].squeeze(), "Δ CLIP-L")
    gate_l_img  = plot_heat(gate_viz ["clip_l"].squeeze().mean(-1, keepdim=True), "Gate L")
    delta_g_img = plot_heat(delta_viz["clip_g"].squeeze(), "Δ CLIP-G")
    gate_g_img  = plot_heat(gate_viz ["clip_g"].squeeze().mean(-1, keepdim=True), "Gate G")

    stats_l = next((f"τ̄_L = {o.tau.mean().item():.3f}" for o in outputs
                    if o.adapter_type == "clip_l"), "-")
    stats_g = next((f"τ̄_G = {o.tau.mean().item():.3f}" for o in outputs
                    if o.adapter_type == "clip_g"), "-")

    return image, delta_l_img, gate_l_img, delta_g_img, gate_g_img, stats_l, stats_g
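infer depends on the local conditioning_shifter module, which this commit does not show. Reconstructed from the call sites alone, the interface it assumes looks roughly like the stub below; every name beyond those visible above is a guess, not the module's actual code:

    # Hypothetical stubs for conditioning_shifter, inferred only from
    # infer()'s call sites; the real module ships with the Space.
    from dataclasses import dataclass
    import torch

    @dataclass
    class AdapterOutput:
        adapter_type: str     # "clip_l" or "clip_g"
        delta: torch.Tensor   # additive shift applied to the CLIP sequence
        gate: torch.Tensor    # per-token gate weights
        tau: torch.Tensor     # temperature diagnostic shown in the stats boxes
        slice_range: tuple    # (0, 768) or (768, 2048); field name guessed

    class ConditioningShifter:
        @staticmethod
        def extract_encoder_embeddings(bundle, device, cfg):
            """Tokenise cfg.prompt with bundle['tokenizer'], embed with bundle['model']."""

        @staticmethod
        def run_adapter(adapter, t5_seq, clip_seq, guidance_scale,
                        adapter_type, slice_range) -> "AdapterOutput":
            """Forward pass of one shunt adapter; packs delta/gate/tau into AdapterOutput."""

        @staticmethod
        def apply_modifications(clip_seq, adapter_outputs, cfg):
            """Blend each AdapterOutput into clip_seq according to the ShiftConfig."""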
# ─── GRADIO UI ────────────────────────────────────────────────────────────
def create_interface():
    with gr.Blocks(title="SDXL Dual-Shunt Tester", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 SDXL Dual-Shunt Tester")
        # (one context line elided in the diff)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Prompts")
                prompt   = gr.Textbox(label="Prompt", lines=3,
                                      value="a futuristic control station with holographic displays")
                negative = gr.Textbox(label="Negative", lines=2,
                                      value="blurry, low quality, distorted")

                gr.Markdown("### Adapters")
                adapter_l = gr.Dropdown(["None"] + clip_l_opts,
                                        value="t5-vit-l-14-dual_shunt_caption.safetensors",
                                        label="CLIP-L Adapter")
                adapter_g = gr.Dropdown(["None"] + clip_g_opts,
                                        value="dual_shunt_omega_no_caption_noised_e1_step_10000.safetensors",
                                        label="CLIP-G Adapter")

                gr.Markdown("### Adapter Controls")
                # step is keyword-only on gr.Slider, so value/step are passed
                # by name here instead of positionally
                strength    = gr.Slider(0, 10,   value=4.0,  step=0.05, label="Strength")
                delta_scale = gr.Slider(-15, 15, value=0.2,  step=0.1,  label="Δ scale")
                sigma_scale = gr.Slider(0, 15,   value=0.1,  step=0.1,  label="σ scale")
                gpred_scale = gr.Slider(0, 20,   value=2.0,  step=0.05, label="Guidance scale")
                noise       = gr.Slider(0, 1,    value=0.55, step=0.01, label="Extra noise")
                gate_prob   = gr.Slider(0, 1,    value=0.27, step=0.01, label="Gate prob")
                use_anchor  = gr.Checkbox(True, label="Use anchor mix")

                gr.Markdown("### Generation")
                with gr.Row():
                    steps     = gr.Slider(1, 50, value=20,  step=1,   label="Steps")
                    cfg_scale = gr.Slider(1, 15, value=7.5, step=0.1, label="CFG")
                    scheduler = gr.Dropdown(list(SCHEDULERS.keys()),
                                            value="DPM++ 2M", label="Scheduler")
                with gr.Row():
                    width  = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                    height = gr.Slider(512, 1536, value=1024, step=64, label="Height")
                    seed   = gr.Number(-1, label="Seed (-1 → random)", precision=0)

                run_btn = gr.Button("🚀 Generate", variant="primary")

            with gr.Column(scale=1):
                out_img = gr.Image(label="Result", height=400)
                gr.Markdown("### Diagnostics")
                delta_l = gr.Image(label="Δ L",    height=180)
                gate_l  = gr.Image(label="Gate L", height=180)
                delta_g = gr.Image(label="Δ G",    height=180)
                gate_g  = gr.Image(label="Gate G", height=180)
                stats_l = gr.Textbox(label="Stats L", interactive=False)
                stats_g = gr.Textbox(label="Stats G", interactive=False)

        def _run(*args):
            pl, npl = args[0], args[1]
            al, ag  = (None if v == "None" else v for v in args[2:4])
            return infer(pl, npl, al, ag, *args[4:])

        run_btn.click(
            fn=_run,
            inputs=[prompt, negative, adapter_l, adapter_g, strength, delta_scale,
                    sigma_scale, gpred_scale, noise, gate_prob, use_anchor, steps,
                    cfg_scale, scheduler, width, height, seed],
            outputs=[out_img, delta_l, gate_l, delta_g, gate_g, stats_l, stats_g]
        )
    return demo

if __name__ == "__main__":
    create_interface().launch()
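create_interface() is only a thin wrapper around infer, so the pipeline can also be driven headless. A sketch mirroring the UI defaults (adapter filenames must exist in the configured repos; "None" skips that adapter):

    # Hypothetical headless call; values mirror the slider defaults above
    image, *diagnostics = infer(
        prompt="a futuristic control station with holographic displays",
        negative_prompt="blurry, low quality, distorted",
        adapter_l_file="None", adapter_g_file="None",   # skip both shunts
        strength=4.0, delta_scale=0.2, sigma_scale=0.1,
        gpred_scale=2.0, noise=0.55, gate_prob=0.27, use_anchor=True,
        steps=20, cfg_scale=7.5, scheduler_name="DPM++ 2M",
        width=1024, height=1024, seed=42,
    )
    image.save("out.png")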