Any-to-Any
AbstractPhil committed on
Commit
99cb0de
verified · 1 Parent(s): 1bcc59e

Create pipeline.py

Files changed (1)
  1. pipeline.py +284 -0
pipeline.py ADDED
@@ -0,0 +1,284 @@
import safetensors.torch as st
import torch
from diffusers import StableDiffusionXLPipeline
from transformers import T5TokenizerFast, T5EncoderModel

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# ─────────────────────────────────────────────────────────────
# ░ Two-Stream Shunt Adapter
# ─────────────────────────────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """
    Cross-attentive adapter that aligns T5 and CLIP token streams.

    Returns:
        anchor    : (B, Lc, clip_dim)
        delta     : (B, Lc, clip_dim)
        log_sigma : (B, Lc, clip_dim) – log σ, always finite
        attn_t2c  : (B, heads, Lt, Lc)
        attn_c2t  : (B, heads, Lc, Lt)
        tau       : (heads, 1, 1)     – per-head threshold param
        g_pred    : (B, 1)            – guidance-scale prediction
        gate      : (B, Lc, 1)        – per-token gate ∈ (0, 1)
    """

    def __init__(
        self,
        t5_dim: int = 512,
        clip_dim: int = 768,
        bottleneck: int = 256,
        heads: int = 8,
        tau_init: float = 0.1,
        max_guidance: float = 10.0,
    ):
        super().__init__()
        print("TwoStreamShuntAdapter init")
        self.heads = heads
        self.bneck = bottleneck
        self.max_guidance = max_guidance

        # projections
        self.proj_t5 = nn.Linear(t5_dim, bottleneck)
        self.proj_clip = nn.Linear(clip_dim, bottleneck)

        # cross-attention
        self.cross_t2c = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )
        self.cross_c2t = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )

        # head-wise τ
        self.tau = nn.Parameter(torch.full((heads, 1, 1), tau_init))

        # convolutional pocket residual (depth-wise)
        self.res1 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.res2 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.norm_res = nn.LayerNorm(bottleneck)

        # fusion + projections
        self.fuse = nn.Linear(2 * bottleneck, bottleneck)

        self.anchor_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.delta_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.logsig_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.gate_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, 1), nn.Sigmoid()
        )
        self.guidance_proj = nn.Sequential(
            nn.LayerNorm(bottleneck), nn.Linear(bottleneck, 1), nn.Sigmoid()
        )

    def load_state_dict(self, state_dict, **kwargs):
        # strip the torch.compile "_orig_mod." prefix before loading
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
        super().load_state_dict(state_dict, **kwargs)

    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        print("📣 SHUNT FORWARD CALLED")

        B, Lt, _ = t5_seq.size()
        _, Lc, _ = clip_seq.size()

        # 1) project into bottleneck
        t5_b = self.proj_t5(t5_seq)        # (B, Lt, b)
        clip_b = self.proj_clip(clip_seq)  # (B, Lc, b)

        # 2) cross-attention
        t2c, attn_t2c = self.cross_t2c(
            t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False
        )
        c2t, attn_c2t = self.cross_c2t(
            clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False
        )

        # 3) convolutional pocket on T5→CLIP
        x = t2c.transpose(1, 2)                   # (B, b, Lt)
        x = F.gelu(self.res1(x))
        x = F.gelu(self.res2(x)).transpose(1, 2)  # (B, Lt, b)
        pocket = self.norm_res(t2c + x)           # (B, Lt, b)

        # 4) fuse pocket average with C2T
        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, Lc, -1)
        h = F.gelu(self.fuse(torch.cat([pocket_mean, c2t], -1)))  # (B, Lc, b)

        # 5) outputs
        anchor = self.anchor_proj(h)     # (B, Lc, 768)
        delta_mean = self.delta_proj(h)  # (B, Lc, 768)
        log_sigma = self.logsig_proj(h)  # (B, Lc, 768)
        gate = self.gate_proj(h)         # (B, Lc, 1)
        delta = delta_mean * gate        # (B, Lc, 768)

        g_tok = self.guidance_proj(h).squeeze(-1)  # (B, Lc)
        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance

        # print(anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate)

        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate

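# Quick shape check for the adapter (a minimal sketch; the batch size and the
# token counts, 64 T5 tokens / 77 CLIP tokens, are illustrative assumptions):
#
#   _adapter = TwoStreamShuntAdapter().eval()
#   with torch.no_grad():
#       _anchor, _delta, _log_sigma, _a_t2c, _a_c2t, _tau, _g_pred, _gate = _adapter(
#           torch.randn(2, 64, 512), torch.randn(2, 77, 768)
#       )
#   assert _delta.shape == (2, 77, 768) and _gate.shape == (2, 77, 1)
#   assert _g_pred.shape == (2, 1) and _a_t2c.shape == (2, 8, 64, 77)
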
# --- 1. load pipeline -------------------------------------------------
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16).to("cuda")

# --- 2. load tiny-T5 & shunt (fp32) -----------------------------------
t5_tok = T5TokenizerFast.from_pretrained("t5-small")
t5_mod = T5EncoderModel.from_pretrained("t5-small").eval().to("cuda")
shunt = TwoStreamShuntAdapter().float().eval().to("cuda")
shunt.load_state_dict(
    st.load_file("/content/drive/MyDrive/t5-clip-l-shunts/vitl14_t5small_shunt_vanilla_final.safetensors")
)

# --- 3. wrap encode_prompt once ---------------------------------------
orig_encode = pipe.encode_prompt

# NOTE: this config dict and `gen` are defined but not consumed below;
# the sweep loop sets the global `strength` and builds its own generator.
config = {
    "strength": 1.0,
    "gate_gamma": 1.0,
    "tau_scale": 1.0,
    "guidance_gain": 1.0,
    "guidance_bias": 0.0
}

gen = torch.Generator(device="cuda").manual_seed(420)

strength = 0

# The first working version, kept for reference; it runs the shunt against a
# dummy T5 prompt, so it only exercises the plumbing.
def stable_encode_prompt_shunted(self, *args, **kw):
    pe, ne, pool, npool = orig_encode(*args, **kw)  # regular call

    # 👉 split: first 768 dims are CLIP-L, remaining 1280 are CLIP-G
    clipL, clipG = pe[..., :768], pe[..., 768:]

    # build T5 batch (handles the CFG duplicate automatically because
    # encode_prompt already concatenated negative & positive if needed)
    bsz = clipL.shape[0]
    texts = ["tmp"] * bsz  # dummy text, we only care about hidden states
    t5_ids = t5_tok(texts, return_tensors="pt").input_ids.to("cuda")
    t5_seq = t5_mod(t5_ids).last_hidden_state  # (B, L, 512)

    # run adapter in fp32
    delta = shunt(t5_seq.float(), clipL.float())[1]  # second output is Δ
    delta = delta * strength                         # strength knob
    clipL_shift = (clipL.float() + delta).to(clipL.dtype)

    pe_shifted = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shifted, ne, pool, npool
# -----------------------------------------------------------------------------------------

def encode_prompt_shunted(self, *a, **k):
    # 1) run the normal encoder with "style" & "context" already split
    pe, ne, pool, npool = orig_encode(*a, **k)  # (B, 77, 2048)

    # 2) split CLIP-L / CLIP-G
    clipL, clipG = pe[..., :768], pe[..., 768:]

    # 3) build T5 input from the *context* text
    #    (assumes prompt_2 is always passed as a kwarg, as in the loop below)
    t5_ids = t5_tok([k.get("prompt_2")], return_tensors="pt").input_ids.to(pe.device)
    t5_seq = t5_mod(t5_ids).last_hidden_state.float()

    # 4) shunt → Δ (fp32, then cast back)
    delta = shunt(t5_seq, clipL.float())[1].to(clipL.dtype)
    clipL_shift = clipL + delta * strength

    # 5) concatenate back
    pe_shift = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shift, ne, pool, npool

pipe.encode_prompt = encode_prompt_shunted.__get__(pipe, type(pipe))

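# Optional sanity check (an illustrative sketch; kwarg names follow the stock
# diffusers encode_prompt signature): the wrapped encoder should still return
# SDXL-shaped embeddings.
#
#   _pe, _ne, _pool, _npool = pipe.encode_prompt(
#       prompt="test", prompt_2="test", device="cuda",
#       num_images_per_prompt=1, do_classifier_free_guidance=True,
#   )
#   print(_pe.shape)  # expected: torch.Size([1, 77, 2048])
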


PROMPT = "a naturally lit and beautiful room with a photorealistic depiction of a woman"
PROMPT_2 = "a realistic depiction of a woman sitting on a chair at a coffee shop sipping coffee, the environment is beautiful"
NEG = "blurry, distorted, monochrome, greyscale, watermark"
STEPS = 50
base_strength = 0.5
base_cfg = 7.5

for i in range(4):
    strength = base_strength + (i * 0.25)
    cfg = base_cfg - (i * 0.25)
    img = pipe(
        PROMPT,
        prompt_2=PROMPT_2,
        negative_prompt=NEG,
        num_inference_steps=STEPS,
        guidance_scale=cfg,  # diffusers expects guidance_scale, not cfg_scale
        generator=torch.Generator(device="cuda").manual_seed(420)
    ).images[0]
    img.save(f"woman_cfg_{int(cfg*100)}_{int(strength*100)}.png")

# --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG,
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_baseline.png")

#strength = 0.25
## --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG,
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_02.png")

#strength = 0.5
## --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG,
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_05.png")

#strength = 0.75
## --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG,
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_075.png")