AbstractPhil committed on
Commit ce712b4 · 1 Parent(s): cae6d82
Files changed (1)
  1. app.py +15 -34
app.py CHANGED
@@ -107,11 +107,11 @@ def encode_sdxl_prompt(prompt, negative_prompt=""):
     ).input_ids.to(device)
 
     with torch.no_grad():
-        # CLIP-L embeddings (768d)
+        # CLIP-L embeddings (768d) - [0] is sequence, [1] is pooled
         clip_l_embeds = pipe.text_encoder(tokens_l)[0]
         neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]
 
-        # CLIP-G embeddings (1280d) - get the hidden states [0], not pooled [1]
+        # CLIP-G embeddings (1280d) - [0] is sequence, [1] is pooled
         clip_g_embeds = pipe.text_encoder_2(tokens_g)[0]
         neg_clip_g_embeds = pipe.text_encoder_2(neg_tokens_g)[0]
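
As an aside on that indexing: a minimal sketch of what each SDXL text encoder actually returns, assuming `pipe` is a stock diffusers `StableDiffusionXLPipeline` (so `text_encoder` is a `CLIPTextModel` and `text_encoder_2` a `CLIPTextModelWithProjection`); none of this is from the repo itself. For the projection variant, index `[0]` is the pooled `text_embeds`, so the two `[0]` lookups above are not symmetric:

```python
import torch

# Sketch only; `pipe`, `tokens_l`, `tokens_g` are the names used in app.py.
with torch.no_grad():
    out_l = pipe.text_encoder(tokens_l, output_hidden_states=True)
    seq_l = out_l.last_hidden_state     # (batch, 77, 768) token sequence == out_l[0]
    pooled_l = out_l.pooler_output      # (batch, 768) pooled vector

    out_g = pipe.text_encoder_2(tokens_g, output_hidden_states=True)
    pooled_g = out_g.text_embeds        # (batch, 1280) pooled projection == out_g[0]
    seq_g = out_g.last_hidden_state     # (batch, 77, 1280) token sequence
    penult_g = out_g.hidden_states[-2]  # the layer diffusers' encode_prompt uses
```

If `text_encoder_2` is indeed the projection variant, `pipe.text_encoder_2(tokens_g)[0]` returns the pooled 1280-d vector rather than a 77-token sequence; the debug prints added in the next hunk are a quick way to verify which one this pipeline actually gets.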
 
@@ -142,30 +142,28 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     if scheduler_name in SCHEDULERS:
         pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)
 
-    # Get T5 embeddings for semantic understanding
-    t5_ids = t5_tok(prompt, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
+    # Get T5 embeddings for semantic understanding - standardize to 77 tokens like CLIP
+    t5_ids = t5_tok(
+        prompt,
+        return_tensors="pt",
+        padding="max_length",
+        max_length=77,
+        truncation=True
+    ).input_ids.to(device)
     t5_seq = t5_mod(t5_ids).last_hidden_state
 
     # Get proper SDXL CLIP embeddings
     clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)
 
+    # Debug shapes
+    print(f"T5 seq shape: {t5_seq.shape}")
+    print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
+    print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
+
     # Load adapters
     adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
     adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None
 
-    # Ensure all embeddings have the same sequence length (77 tokens)
-    seq_len = 77
-
-    # Resize T5 to match CLIP sequence length
-    if t5_seq.size(1) != seq_len:
-        t5_seq = torch.nn.functional.interpolate(
-            t5_seq.transpose(1, 2),
-            size=seq_len,
-            mode="nearest"
-        ).transpose(1, 2)
-
-    print(f"After resize - T5: {t5_seq.shape}, CLIP-L: {clip_embeds['clip_l'].shape}, CLIP-G: {clip_embeds['clip_g'].shape}")
-
     # Apply CLIP-L adapter
     if adapter_l is not None:
         anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
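
The substantive change in this hunk: instead of encoding T5 at its natural length and resampling the hidden states afterwards, the prompt is now padded/truncated to 77 tokens up front, matching CLIP. A side-by-side sketch of the two strategies, reusing the `t5_tok`, `t5_mod`, and `device` names from app.py (assumed to be a Hugging Face T5 tokenizer and `T5EncoderModel`):

```python
import torch
import torch.nn.functional as F

def t5_embed_resized(prompt: str, seq_len: int = 77) -> torch.Tensor:
    """Old approach: encode at natural length, then interpolate to 77."""
    ids = t5_tok(prompt, return_tensors="pt", padding=True,
                 truncation=True).input_ids.to(device)
    seq = t5_mod(ids).last_hidden_state                 # (1, n, d), n varies
    if seq.size(1) != seq_len:
        # nearest-neighbor resample along the token axis
        seq = F.interpolate(seq.transpose(1, 2), size=seq_len,
                            mode="nearest").transpose(1, 2)
    return seq

def t5_embed_fixed(prompt: str, seq_len: int = 77) -> torch.Tensor:
    """New approach: pad/truncate at tokenization time, like CLIP."""
    ids = t5_tok(prompt, return_tensors="pt", padding="max_length",
                 max_length=seq_len, truncation=True).input_ids.to(device)
    return t5_mod(ids).last_hidden_state                # (1, 77, d), always
```

The fixed-length version keeps every position aligned with a real token (or an explicit pad token), whereas nearest-neighbor interpolation silently duplicates or drops token embeddings to hit the target length.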
@@ -193,23 +191,6 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
             clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
         if noise > 0:
             clip_g_mod += torch.randn_like(clip_g_mod) * noise
-    else:
-        clip_g_mod = clip_embeds["clip_g"]
-        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
-        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
-        g_pred_g = torch.tensor(0.0)
-        tau_g = torch.tensor(0.0) 2)
-    else:
-        t5_seq_resized = t5_seq
-
-        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq_resized, clip_embeds["clip_g"])
-        gate_g_scaled = gate_g * gate_prob
-        delta_g_final = delta_g * strength * gate_g_scaled
-        clip_g_mod = clip_embeds["clip_g"] + delta_g_final
-        if use_anchor:
-            clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
-        if noise > 0:
-            clip_g_mod += torch.randn_like(clip_g_mod) * noise
     else:
         clip_g_mod = clip_embeds["clip_g"]
         delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
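
The seventeen deleted lines were a dead duplicate (note the doubled `else:` and the syntactically broken `tau_g = torch.tensor(0.0) 2)`) of the gated update that the surviving branch already performs. That update rule, as a standalone sketch with a hypothetical `apply_adapter` helper and tensor names mirroring app.py:

```python
import torch

def apply_adapter(base, anchor, delta, gate, strength, gate_prob,
                  use_anchor=True, noise=0.0):
    """Gated residual update with optional anchor blend, as in app.py's infer()."""
    gate_scaled = gate * gate_prob                  # scale the per-token gate
    out = base + delta * strength * gate_scaled     # gated residual delta
    if use_anchor:
        # lerp toward the adapter's anchor where the gate is open
        out = out * (1 - gate_scaled) + anchor * gate_scaled
    if noise > 0:
        out = out + torch.randn_like(out) * noise   # optional exploration noise
    return out
```

The surviving CLIP-G path is then equivalent to `clip_g_mod = apply_adapter(clip_embeds["clip_g"], anchor_g, delta_g, gate_g, strength, gate_prob, use_anchor, noise)`.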
 