Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -187,26 +187,49 @@ def infer(prompt, negative_prompt,
|
|
187 |
use_anchor=use_anchor,
|
188 |
)
|
189 |
|
190 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
if adapter_l_file and adapter_l_file != "None":
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
else:
|
198 |
-
clip_l_mod = embeds["clip_l"]; delta_l = torch.zeros_like(clip_l_mod)
|
199 |
-
gate_l = torch.zeros_like(clip_l_mod[..., :1]); g_pred_l = tau_l = torch.tensor(0.)
|
200 |
-
|
201 |
-
# CLIP-G
|
202 |
if adapter_g_file and adapter_g_file != "None":
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
# concatenate for SDXL
|
212 |
prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
|
|
|
187 |
use_anchor=use_anchor,
|
188 |
)
|
189 |
|
190 |
+
# --- STEP 0: build shift config -----------------------------------------
|
191 |
+
cfg_shift = ShiftConfig(
|
192 |
+
prompt = prompt,
|
193 |
+
seed = seed,
|
194 |
+
strength = strength,
|
195 |
+
delta_scale = delta_scale,
|
196 |
+
sigma_scale = sigma_scale,
|
197 |
+
gate_probability = gate_prob,
|
198 |
+
noise_injection = noise,
|
199 |
+
use_anchor = use_anchor,
|
200 |
+
guidance_scale = gpred_scale,
|
201 |
+
)
|
202 |
+
|
203 |
+
# --- STEP 1: encoder embeddings -----------------------------------------
|
204 |
+
t5_seq = ConditioningShifter.extract_encoder_embeddings(
|
205 |
+
{"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
|
206 |
+
device, cfg_shift
|
207 |
+
)
|
208 |
+
|
209 |
+
# --- STEP 2: run adapters -----------------------------------------------
|
210 |
+
outputs = []
|
211 |
if adapter_l_file and adapter_l_file != "None":
|
212 |
+
ada_l = load_adapter(repo_l, adapter_l_file, conf_l, device)
|
213 |
+
outputs.append(ConditioningShifter.run_adapter(
|
214 |
+
ada_l, t5_seq, embeds["clip_l"],
|
215 |
+
cfg_shift.guidance_scale, "clip_l", (0, 768)))
|
216 |
+
|
|
|
|
|
|
|
|
|
|
|
217 |
if adapter_g_file and adapter_g_file != "None":
|
218 |
+
ada_g = load_adapter(repo_g, adapter_g_file, conf_g, device)
|
219 |
+
outputs.append(ConditioningShifter.run_adapter(
|
220 |
+
ada_g, t5_seq, embeds["clip_g"],
|
221 |
+
cfg_shift.guidance_scale, "clip_g", (768, 2048)))
|
222 |
+
|
223 |
+
# --- STEP 3: apply mods --------------------------------------------------
|
224 |
+
clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
|
225 |
+
for out in outputs:
|
226 |
+
tgt = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
|
227 |
+
mod = ConditioningShifter.apply_modifications(tgt, [out], cfg_shift)
|
228 |
+
if out.adapter_type == "clip_l":
|
229 |
+
clip_l_mod = mod
|
230 |
+
else:
|
231 |
+
clip_g_mod = mod
|
232 |
+
|
233 |
|
234 |
# concatenate for SDXL
|
235 |
prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
|