Spaces: Running on Zero
Commit · ce712b4 · 1 Parent(s): cae6d82
yes

app.py CHANGED
@@ -107,11 +107,11 @@ def encode_sdxl_prompt(prompt, negative_prompt=""):
     ).input_ids.to(device)

     with torch.no_grad():
-        # CLIP-L embeddings (768d)
+        # CLIP-L embeddings (768d) - [0] is sequence, [1] is pooled
         clip_l_embeds = pipe.text_encoder(tokens_l)[0]
         neg_clip_l_embeds = pipe.text_encoder(neg_tokens_l)[0]

-        # CLIP-G embeddings (1280d) -
+        # CLIP-G embeddings (1280d) - [0] is sequence, [1] is pooled
         clip_g_embeds = pipe.text_encoder_2(tokens_g)[0]
         neg_clip_g_embeds = pipe.text_encoder_2(neg_tokens_g)[0]

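A note on the new comments: in current transformers, the two SDXL text encoders are different classes with different output orderings, so what `[0]` returns is worth verifying rather than assuming. `CLIPTextModel` (CLIP-L) returns (last_hidden_state, pooler_output), while `CLIPTextModelWithProjection` (CLIP-G) returns (text_embeds, last_hidden_state), in which case `[0]` on `text_encoder_2` would be the pooled 1280-d projection, not the token sequence. A minimal sanity check, assuming `pipe`, `tokens_l`, and `tokens_g` as defined in the surrounding function (not part of the commit):

import torch

# Inspect what index [0] actually holds for each encoder. If text_encoder_2
# is the usual CLIPTextModelWithProjection, out_g[0] may be pooled, not sequence.
with torch.no_grad():
    out_l = pipe.text_encoder(tokens_l)
    out_g = pipe.text_encoder_2(tokens_g)
print(type(out_l).__name__, tuple(out_l[0].shape))  # sequence: (batch, 77, 768)
print(type(out_g).__name__, tuple(out_g[0].shape))  # (batch, 1280) would mean pooled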
@@ -142,30 +142,28 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
     if scheduler_name in SCHEDULERS:
         pipe.scheduler = SCHEDULERS[scheduler_name].from_config(pipe.scheduler.config)

-    # Get T5 embeddings for semantic understanding
-    t5_ids = t5_tok(
+    # Get T5 embeddings for semantic understanding - standardize to 77 tokens like CLIP
+    t5_ids = t5_tok(
+        prompt,
+        return_tensors="pt",
+        padding="max_length",
+        max_length=77,
+        truncation=True
+    ).input_ids.to(device)
     t5_seq = t5_mod(t5_ids).last_hidden_state

     # Get proper SDXL CLIP embeddings
     clip_embeds = encode_sdxl_prompt(prompt, negative_prompt)

+    # Debug shapes
+    print(f"T5 seq shape: {t5_seq.shape}")
+    print(f"CLIP-L shape: {clip_embeds['clip_l'].shape}")
+    print(f"CLIP-G shape: {clip_embeds['clip_g'].shape}")
+
    # Load adapters
     adapter_l = load_adapter(repo_l, adapter_l_file, config_l) if adapter_l_file else None
     adapter_g = load_adapter(repo_g, adapter_g_file, config_g) if adapter_g_file else None

-    # Ensure all embeddings have the same sequence length (77 tokens)
-    seq_len = 77
-
-    # Resize T5 to match CLIP sequence length
-    if t5_seq.size(1) != seq_len:
-        t5_seq = torch.nn.functional.interpolate(
-            t5_seq.transpose(1, 2),
-            size=seq_len,
-            mode="nearest"
-        ).transpose(1, 2)
-
-    print(f"After resize - T5: {t5_seq.shape}, CLIP-L: {clip_embeds['clip_l'].shape}, CLIP-G: {clip_embeds['clip_g'].shape}")
-
     # Apply CLIP-L adapter
     if adapter_l is not None:
         anchor_l, delta_l, log_sigma_l, attn_l1, attn_l2, tau_l, g_pred_l, gate_l = adapter_l(t5_seq, clip_embeds["clip_l"])
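This hunk pads and truncates the T5 prompt to the same 77-token length CLIP uses, so the adapters see aligned sequence lengths and the nearest-neighbor interpolation block deleted above is no longer needed. A self-contained sketch of the same padding behavior, using google/flan-t5-base as a stand-in for whatever checkpoint t5_tok and t5_mod actually wrap:

import torch
from transformers import AutoTokenizer, T5EncoderModel

tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
enc = T5EncoderModel.from_pretrained("google/flan-t5-base")

ids = tok(
    "a photo of a cat",
    return_tensors="pt",
    padding="max_length",   # always pad out to max_length
    max_length=77,          # match CLIP's 77-token context
    truncation=True,        # clip longer prompts instead of overflowing
).input_ids

with torch.no_grad():
    seq = enc(ids).last_hidden_state
print(seq.shape)  # torch.Size([1, 77, 768]) for this checkpoint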
@@ -193,23 +191,6 @@ def infer(prompt, negative_prompt, adapter_l_file, adapter_g_file, strength, noi
             clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
         if noise > 0:
             clip_g_mod += torch.randn_like(clip_g_mod) * noise
-    else:
-        clip_g_mod = clip_embeds["clip_g"]
-        delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
-        gate_g_scaled = torch.zeros_like(clip_embeds["clip_g"])
-        g_pred_g = torch.tensor(0.0)
-        tau_g = torch.tensor(0.0) 2)
-    else:
-        t5_seq_resized = t5_seq
-
-        anchor_g, delta_g, log_sigma_g, attn_g1, attn_g2, tau_g, g_pred_g, gate_g = adapter_g(t5_seq_resized, clip_embeds["clip_g"])
-        gate_g_scaled = gate_g * gate_prob
-        delta_g_final = delta_g * strength * gate_g_scaled
-        clip_g_mod = clip_embeds["clip_g"] + delta_g_final
-        if use_anchor:
-            clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
-        if noise > 0:
-            clip_g_mod += torch.randn_like(clip_g_mod) * noise
     else:
         clip_g_mod = clip_embeds["clip_g"]
         delta_g_final = torch.zeros_like(clip_embeds["clip_g"])
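The block removed here looks like a stale duplicate of the CLIP-G branch (note the two consecutive else: clauses and the stray "2)" remnant, which would not parse); the commit keeps a single fallback that passes the base embedding through. The surviving math is a gated residual update plus an optional convex blend toward an anchor. A minimal illustration with assumed shapes, where strength, gate_prob, use_anchor, and noise mirror the app's UI parameters:

import torch

clip_g = torch.randn(1, 77, 1280)      # base CLIP-G sequence embedding
delta_g = torch.randn_like(clip_g)     # adapter-predicted correction
anchor_g = torch.randn_like(clip_g)    # adapter-predicted anchor embedding
gate_g = torch.rand_like(clip_g)       # per-element gate in [0, 1]

strength, gate_prob, use_anchor, noise = 1.0, 0.5, True, 0.0

gate_g_scaled = gate_g * gate_prob
delta_g_final = delta_g * strength * gate_g_scaled
clip_g_mod = clip_g + delta_g_final                 # gated residual update
if use_anchor:
    # element-wise lerp toward the anchor, weighted by the scaled gate
    clip_g_mod = clip_g_mod * (1 - gate_g_scaled) + anchor_g * gate_g_scaled
if noise > 0:
    clip_g_mod += torch.randn_like(clip_g_mod) * noise
print(clip_g_mod.shape)  # torch.Size([1, 77, 1280])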