Create pipeline.py
Browse files- pipeline.py +284 -0
pipeline.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import safetensors.torch as st
|
2 |
+
import torch
|
3 |
+
from diffusers import StableDiffusionXLPipeline
|
4 |
+
from transformers import T5TokenizerFast, T5EncoderModel
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
|
13 |
+
# ─────────────────────────────────────────────────────────────
|
14 |
+
# ░ Two-Stream Shunt Adapter
|
15 |
+
# ─────────────────────────────────────────────────────────────
|
16 |
+
class TwoStreamShuntAdapter(nn.Module):
|
17 |
+
"""
|
18 |
+
Cross-attentive adapter that aligns T5 and CLIP token streams.
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
anchor : (B, Lc, clip_dim)
|
22 |
+
delta : (B, Lc, clip_dim)
|
23 |
+
log_sigma : (B, Lc, clip_dim) – log σ, always finite
|
24 |
+
attn_t2c : (B, heads, Lt, Lc)
|
25 |
+
attn_c2t : (B, heads, Lc, Lt)
|
26 |
+
tau : (heads, 1, 1) – per-head threshold param
|
27 |
+
g_pred : (B, 1) – guidance-scale prediction
|
28 |
+
gate : (B, Lc, 1) – per-token gate ∈ (0,1)
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
t5_dim: int = 512,
|
34 |
+
clip_dim: int = 768,
|
35 |
+
bottleneck: int = 256,
|
36 |
+
heads: int = 8,
|
37 |
+
tau_init: float = 0.1,
|
38 |
+
max_guidance: float = 10.0,
|
39 |
+
):
|
40 |
+
super().__init__()
|
41 |
+
print("TwoStreamShuntAdapter init")
|
42 |
+
self.heads = heads
|
43 |
+
self.bneck = bottleneck
|
44 |
+
self.max_guidance = max_guidance
|
45 |
+
|
46 |
+
# projections
|
47 |
+
self.proj_t5 = nn.Linear(t5_dim, bottleneck)
|
48 |
+
self.proj_clip = nn.Linear(clip_dim, bottleneck)
|
49 |
+
|
50 |
+
# cross-attention
|
51 |
+
self.cross_t2c = nn.MultiheadAttention(
|
52 |
+
bottleneck, heads, batch_first=True, dropout=0.1
|
53 |
+
)
|
54 |
+
self.cross_c2t = nn.MultiheadAttention(
|
55 |
+
bottleneck, heads, batch_first=True, dropout=0.1
|
56 |
+
)
|
57 |
+
|
58 |
+
# head-wise τ
|
59 |
+
self.tau = nn.Parameter(torch.full((heads, 1, 1), tau_init))
|
60 |
+
|
61 |
+
# convolutional pocket residual (depth-wise)
|
62 |
+
self.res1 = nn.Conv1d(
|
63 |
+
bottleneck, bottleneck, 3, padding=1, groups=bottleneck
|
64 |
+
)
|
65 |
+
self.res2 = nn.Conv1d(
|
66 |
+
bottleneck, bottleneck, 3, padding=1, groups=bottleneck
|
67 |
+
)
|
68 |
+
self.norm_res = nn.LayerNorm(bottleneck)
|
69 |
+
|
70 |
+
# fusion + projections
|
71 |
+
self.fuse = nn.Linear(2 * bottleneck, bottleneck)
|
72 |
+
|
73 |
+
self.anchor_proj = nn.Sequential(
|
74 |
+
nn.Linear(bottleneck, bottleneck), nn.GELU(),
|
75 |
+
nn.Linear(bottleneck, clip_dim)
|
76 |
+
)
|
77 |
+
self.delta_proj = nn.Sequential(
|
78 |
+
nn.Linear(bottleneck, bottleneck), nn.GELU(),
|
79 |
+
nn.Linear(bottleneck, clip_dim)
|
80 |
+
)
|
81 |
+
self.logsig_proj = nn.Sequential(
|
82 |
+
nn.Linear(bottleneck, bottleneck), nn.GELU(),
|
83 |
+
nn.Linear(bottleneck, clip_dim)
|
84 |
+
)
|
85 |
+
self.gate_proj = nn.Sequential(
|
86 |
+
nn.Linear(bottleneck, bottleneck), nn.GELU(),
|
87 |
+
nn.Linear(bottleneck, 1), nn.Sigmoid()
|
88 |
+
)
|
89 |
+
self.guidance_proj = nn.Sequential(
|
90 |
+
nn.LayerNorm(bottleneck), nn.Linear(bottleneck, 1), nn.Sigmoid()
|
91 |
+
)
|
92 |
+
|
93 |
+
def load_state_dict(self, args, **kwargs):
|
94 |
+
# remove _orig_mod from state dict before applying.
|
95 |
+
state_dict = {k.replace("_orig_mod.", ""): v for k, v in args.items()}
|
96 |
+
super().load_state_dict(state_dict, **kwargs)
|
97 |
+
|
98 |
+
def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
|
99 |
+
print("📣 SHUNT FORWARD CALLED")
|
100 |
+
|
101 |
+
B, Lt, _ = t5_seq.size()
|
102 |
+
_, Lc, _ = clip_seq.size()
|
103 |
+
|
104 |
+
# 1) project into bottleneck
|
105 |
+
t5_b = self.proj_t5(t5_seq) # (B, Lt, b)
|
106 |
+
clip_b = self.proj_clip(clip_seq) # (B, Lc, b)
|
107 |
+
|
108 |
+
# 2) cross-attention
|
109 |
+
t2c, attn_t2c = self.cross_t2c(
|
110 |
+
t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False
|
111 |
+
)
|
112 |
+
c2t, attn_c2t = self.cross_c2t(
|
113 |
+
clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False
|
114 |
+
)
|
115 |
+
|
116 |
+
# 3) convolutional pocket on T5→CLIP
|
117 |
+
x = t2c.transpose(1, 2) # (B, b, Lt)
|
118 |
+
x = F.gelu(self.res1(x))
|
119 |
+
x = F.gelu(self.res2(x)).transpose(1, 2) # (B, Lt, b)
|
120 |
+
pocket = self.norm_res(t2c + x) # (B, Lt, b)
|
121 |
+
|
122 |
+
# 4) fuse pocket avg with C2T
|
123 |
+
pocket_mean = pocket.mean(1, keepdim=True).expand(-1, Lc, -1)
|
124 |
+
h = F.gelu(self.fuse(torch.cat([pocket_mean, c2t], -1))) # (B, Lc, b)
|
125 |
+
|
126 |
+
# 5) outputs
|
127 |
+
anchor = self.anchor_proj(h) # (B,Lc,768)
|
128 |
+
delta_mean = self.delta_proj(h) # (B,Lc,768)
|
129 |
+
log_sigma = self.logsig_proj(h) # (B,Lc,768)
|
130 |
+
gate = self.gate_proj(h) # (B,Lc,1)
|
131 |
+
delta = delta_mean * gate # (B,Lc,768)
|
132 |
+
|
133 |
+
g_tok = self.guidance_proj(h).squeeze(-1) # (B,Lc)
|
134 |
+
g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance
|
135 |
+
|
136 |
+
#print(anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate)
|
137 |
+
|
138 |
+
return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate
|
139 |
+
|
140 |
+
# --- 1. load pipeline -------------------------------------------------
|
141 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(
|
142 |
+
"stabilityai/stable-diffusion-xl-base-1.0",
|
143 |
+
torch_dtype=torch.float16).to("cuda")
|
144 |
+
|
145 |
+
# --- 2. load tiny-T5 & shunt (fp32) -----------------------------------
|
146 |
+
t5_tok = T5TokenizerFast.from_pretrained("t5-small")
|
147 |
+
t5_mod = T5EncoderModel.from_pretrained("t5-small").eval().to("cuda")
|
148 |
+
shunt = TwoStreamShuntAdapter().float().eval().to("cuda")
|
149 |
+
shunt.load_state_dict( st.load_file("/content/drive/MyDrive/t5-clip-l-shunts/vitl14_t5small_shunt_vanilla_final.safetensors") )
|
150 |
+
|
151 |
+
# --- 3. wrap encode_prompt once ---------------------------------------
|
152 |
+
orig_encode = pipe.encode_prompt
|
153 |
+
|
154 |
+
config = {
|
155 |
+
"strength": 1.0,
|
156 |
+
"gate_gamma": 1.0,
|
157 |
+
"tau_scale": 1.0,
|
158 |
+
"guidance_gain": 1.0,
|
159 |
+
"guidance_bias": 0.0
|
160 |
+
}
|
161 |
+
|
162 |
+
|
163 |
+
gen = torch.Generator(device="cuda").manual_seed(420)
|
164 |
+
|
165 |
+
|
166 |
+
|
167 |
+
strength = 0
|
168 |
+
|
169 |
+
# the working version that can't be omitted,
|
170 |
+
def stable_encode_prompt_shunted(self, *args, **kw):
|
171 |
+
pe, ne, pool, npool = orig_encode(*args, **kw) # regular call
|
172 |
+
|
173 |
+
# 👉 split: first 768 dims are CLIP-L, rest 1280 are CLIP-G
|
174 |
+
clipL, clipG = pe[..., :768], pe[..., 768:]
|
175 |
+
|
176 |
+
# build T5 batch (handles CFG dup automatically because
|
177 |
+
# encode_prompt already concatenated negative & positive if needed)
|
178 |
+
bsz = clipL.shape[0]
|
179 |
+
texts = ["tmp"] * bsz # dummy, we only care about hidden states
|
180 |
+
t5_ids = t5_tok(texts, return_tensors="pt").input_ids.to("cuda")
|
181 |
+
t5_seq = t5_mod(t5_ids).last_hidden_state # (B,L,512)
|
182 |
+
|
183 |
+
# run adapter in fp32
|
184 |
+
delta = shunt(t5_seq.float(), clipL.float())[1] # second output is Δ
|
185 |
+
delta = delta * strength # << your strength knob
|
186 |
+
clipL_shift = (clipL.float() + delta).to(clipL.dtype)
|
187 |
+
|
188 |
+
pe_shifted = torch.cat([clipL_shift, clipG], dim=-1)
|
189 |
+
return pe_shifted, ne, pool, npool
|
190 |
+
#-----------------------------------------------------------------------------------------
|
191 |
+
|
192 |
+
def encode_prompt_shunted(self, *a, **k):
|
193 |
+
# 1) run the normal encoder with “style” & “context” already split
|
194 |
+
pe, ne, pool, npool = orig_encode(*a, **k) # (B,77,2048)
|
195 |
+
|
196 |
+
# 2) split CLIP-L / CLIP-G
|
197 |
+
clipL, clipG = pe[..., :768], pe[..., 768:]
|
198 |
+
|
199 |
+
# 3) build T5 on the *context* text (it’s in k['prompt_2'])
|
200 |
+
t5_ids = t5_tok([k.get("prompt_2")], return_tensors="pt").input_ids.to(pe.device)
|
201 |
+
t5_seq = t5_mod(t5_ids).last_hidden_state.float()
|
202 |
+
|
203 |
+
# 4) shunt → Δ (FP32 → back-cast)
|
204 |
+
Δ = shunt(t5_seq, clipL.float())[1].to(clipL.dtype)
|
205 |
+
clipL_shift = clipL + Δ * strength
|
206 |
+
|
207 |
+
# 5) concatenate back
|
208 |
+
pe_shift = torch.cat([clipL_shift, clipG], dim=-1)
|
209 |
+
return pe_shift, ne, pool, npool
|
210 |
+
|
211 |
+
pipe.encode_prompt = encode_prompt_shunted.__get__(pipe, type(pipe))
|
212 |
+
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
PROMPT = "a naturally lit and beautiful room with a photorealistic depiction of a woman"
|
217 |
+
PROMPT_2 = "a realistic depiction of a woman sitting on a chair at a coffee shop sipping coffee, the environment is beautiful"
|
218 |
+
NEG = "blurry, distorted, monochrome, greyscale, watermark"
|
219 |
+
STEPS = 50
|
220 |
+
base_strength = 0.5
|
221 |
+
base_cfg = 7.5
|
222 |
+
|
223 |
+
|
224 |
+
for i in range(0, 4):
|
225 |
+
strength = base_strength + (i * 0.25)
|
226 |
+
cfg = base_cfg - (i * 0.25)
|
227 |
+
img = pipe(
|
228 |
+
PROMPT,
|
229 |
+
prompt_2=PROMPT_2,
|
230 |
+
negative_prompt=NEG,
|
231 |
+
num_inference_steps=STEPS,
|
232 |
+
cfg_scale=cfg,
|
233 |
+
generator=torch.Generator(device="cuda").manual_seed(420)
|
234 |
+
).images[0]
|
235 |
+
img.save(f"woman_cfg_{int(cfg*100)}_{int(strength*100)}.png")
|
236 |
+
|
237 |
+
# --- 4. generate -------------------------------------------------------
|
238 |
+
#img = pipe(
|
239 |
+
# PROMPT,
|
240 |
+
# negative_prompt=NEG,
|
241 |
+
# num_inference_steps=STEPS,
|
242 |
+
# generator=torch.Generator(device="cuda").manual_seed(420)
|
243 |
+
# ).images[0]
|
244 |
+
#img.save("majestic_baseline.png")#
|
245 |
+
#
|
246 |
+
|
247 |
+
#strength = 0.25
|
248 |
+
## --- 4. generate -------------------------------------------------------
|
249 |
+
#img = pipe(
|
250 |
+
# PROMPT,
|
251 |
+
# negative_prompt=NEG,
|
252 |
+
# num_inference_steps=STEPS,
|
253 |
+
# generator=torch.Generator(device="cuda").manual_seed(420)
|
254 |
+
# ).images[0]
|
255 |
+
#img.save("majestic_02.png")#
|
256 |
+
|
257 |
+
#strength = 0.5
|
258 |
+
## --- 4. generate -------------------------------------------------------
|
259 |
+
#img = pipe(
|
260 |
+
# PROMPT,
|
261 |
+
# negative_prompt=NEG,
|
262 |
+
# num_inference_steps=STEPS,
|
263 |
+
# generator=torch.Generator(device="cuda").manual_seed(420)
|
264 |
+
# ).images[0]
|
265 |
+
#img.save("majestic_05.png")#
|
266 |
+
|
267 |
+
#strength = 0.75
|
268 |
+
## --- 4. generate -------------------------------------------------------
|
269 |
+
#img = pipe(
|
270 |
+
# PROMPT,
|
271 |
+
# negative_prompt=NEG,
|
272 |
+
# num_inference_steps=STEPS,
|
273 |
+
# generator=torch.Generator(device="cuda").manual_seed(420)
|
274 |
+
# ).images[0]
|
275 |
+
#img.save("majestic_075.png")
|
276 |
+
|
277 |
+
|
278 |
+
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
|
283 |
+
|
284 |
+
|