AbstractPhil committed (verified)
Commit 535b292 · 1 Parent(s): 788f431

Update app.py

Files changed (1): app.py +41 -18
app.py CHANGED
@@ -187,26 +187,49 @@ def infer(prompt, negative_prompt,
         use_anchor=use_anchor,
     )
 
-    # CLIP-L
+    # --- STEP 0: build shift config -----------------------------------------
+    cfg_shift = ShiftConfig(
+        prompt           = prompt,
+        seed             = seed,
+        strength         = strength,
+        delta_scale      = delta_scale,
+        sigma_scale      = sigma_scale,
+        gate_probability = gate_prob,
+        noise_injection  = noise,
+        use_anchor       = use_anchor,
+        guidance_scale   = gpred_scale,
+    )
+
+    # --- STEP 1: encoder embeddings -----------------------------------------
+    t5_seq = ConditioningShifter.extract_encoder_embeddings(
+        {"tokenizer": _t5_tok, "model": _t5_mod, "config": {"config": {}}},
+        device, cfg_shift
+    )
+
+    # --- STEP 2: run adapters -----------------------------------------------
+    outputs = []
     if adapter_l_file and adapter_l_file != "None":
-        cfg_l = conf_l.copy(); cfg_l.update(cfg_common)
-        if "booru" in adapter_l_file: cfg_l["heads"] = 4
-        adapter_l = load_adapter(repo_l, adapter_l_file, conf_l, device)
-        clip_l_mod, delta_l, gate_l, g_pred_l, tau_l = adapter_forward(
-            adapter_l, t5_seq, embeds["clip_l"], cfg_l)
-    else:
-        clip_l_mod = embeds["clip_l"]; delta_l = torch.zeros_like(clip_l_mod)
-        gate_l = torch.zeros_like(clip_l_mod[..., :1]); g_pred_l = tau_l = torch.tensor(0.)
-
-    # CLIP-G
+        ada_l = load_adapter(repo_l, adapter_l_file, conf_l, device)
+        outputs.append(ConditioningShifter.run_adapter(
+            ada_l, t5_seq, embeds["clip_l"],
+            cfg_shift.guidance_scale, "clip_l", (0, 768)))
+
     if adapter_g_file and adapter_g_file != "None":
-        cfg_g = conf_g.copy(); cfg_g.update(cfg_common)
-        adapter_g = load_adapter(repo_g, adapter_g_file, conf_g, device)
-        clip_g_mod, delta_g, gate_g, g_pred_g, tau_g = adapter_forward(
-            adapter_g, t5_seq, embeds["clip_g"], cfg_g)
-    else:
-        clip_g_mod = embeds["clip_g"]; delta_g = torch.zeros_like(clip_g_mod)
-        gate_g = torch.zeros_like(clip_g_mod[..., :1]); g_pred_g = tau_g = torch.tensor(0.)
+        ada_g = load_adapter(repo_g, adapter_g_file, conf_g, device)
+        outputs.append(ConditioningShifter.run_adapter(
+            ada_g, t5_seq, embeds["clip_g"],
+            cfg_shift.guidance_scale, "clip_g", (768, 2048)))
+
+    # --- STEP 3: apply mods --------------------------------------------------
+    clip_l_mod, clip_g_mod = embeds["clip_l"], embeds["clip_g"]
+    for out in outputs:
+        tgt = clip_l_mod if out.adapter_type == "clip_l" else clip_g_mod
+        mod = ConditioningShifter.apply_modifications(tgt, [out], cfg_shift)
+        if out.adapter_type == "clip_l":
+            clip_l_mod = mod
+        else:
+            clip_g_mod = mod
+
 
     # concatenate for SDXL
     prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
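
Note on STEP 1: the new code delegates T5 encoding to ConditioningShifter.extract_encoder_embeddings. A minimal sketch of what such a helper plausibly does, assuming it just tokenizes the prompt and returns the encoder's last hidden state; the checkpoint name, the no-grad wrapper, and the omitted nested "config" dict are illustrative assumptions, not the repo's actual implementation:

# Sketch only: a plausible stand-in for extract_encoder_embeddings.
import torch
from transformers import T5Tokenizer, T5EncoderModel

_t5_tok = T5Tokenizer.from_pretrained("t5-small")   # stand-in checkpoint
_t5_mod = T5EncoderModel.from_pretrained("t5-small")

def extract_encoder_embeddings(prompt: str, device: str = "cpu") -> torch.Tensor:
    # Tokenize the prompt and return the T5 encoder's token-level embeddings.
    tokens = _t5_tok(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        return _t5_mod(**tokens).last_hidden_state  # (1, seq_len, d_model)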
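Note on STEPs 2-3: the old adapter_forward returned a five-tuple (modified, delta, gate, g_pred, tau); run_adapter appears to bundle the equivalent data into a tagged result that apply_modifications consumes per adapter. A minimal sketch of that contract, where the container name AdapterOutput, its fields, and the gated-residual blend formula are assumptions inferred from the old return values, not the repo's actual API:

# Sketch only: an assumed contract between run_adapter and apply_modifications.
from dataclasses import dataclass
import torch

@dataclass
class AdapterOutput:                 # hypothetical container
    adapter_type: str                # "clip_l" or "clip_g"
    delta: torch.Tensor              # predicted shift, same shape as the base
    gate: torch.Tensor               # per-token gate in [0, 1], shape (..., 1)
    slice_range: tuple               # e.g. (0, 768) or (768, 2048)

def apply_modifications(base: torch.Tensor, out: AdapterOutput,
                        strength: float, delta_scale: float) -> torch.Tensor:
    # Gated residual update: the base embedding is untouched where the gate
    # is closed and shifted by the scaled delta where the gate is open.
    return base + strength * delta_scale * (out.gate * out.delta)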
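Note on the final concatenation: SDXL conditions on two text encoders, CLIP ViT-L (768-dim hidden states) and OpenCLIP bigG (1280-dim), which is why the adapters address slices (0, 768) and (768, 2048) of the fused embedding. A quick shape check, assuming batch size 1 and the standard 77-token sequence:

import torch

clip_l_mod = torch.randn(1, 77, 768)    # CLIP ViT-L hidden size
clip_g_mod = torch.randn(1, 77, 1280)   # OpenCLIP bigG hidden size
prompt_embeds = torch.cat([clip_l_mod, clip_g_mod], dim=-1)
assert prompt_embeds.shape == (1, 77, 2048)  # matches the (0, 768) / (768, 2048) slices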