import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import safetensors.torch as st
from diffusers import StableDiffusionXLPipeline
from transformers import T5TokenizerFast, T5EncoderModel
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# ─────────────────────────────────────────────────────────────
# ░ Two-Stream Shunt Adapter
# ─────────────────────────────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """
    Cross-attentive adapter that aligns T5 and CLIP token streams.

    Returns:
        anchor    : (B, Lc, clip_dim)
        delta     : (B, Lc, clip_dim)
        log_sigma : (B, Lc, clip_dim)  – log σ, always finite
        attn_t2c  : (B, heads, Lt, Lc)
        attn_c2t  : (B, heads, Lc, Lt)
        tau       : (heads, 1, 1)      – per-head threshold param
        g_pred    : (B, 1)             – guidance-scale prediction
        gate      : (B, Lc, 1)         – per-token gate ∈ (0, 1)
    """

    def __init__(
        self,
        t5_dim: int = 512,
        clip_dim: int = 768,
        bottleneck: int = 256,
        heads: int = 8,
        tau_init: float = 0.1,
        max_guidance: float = 10.0,
    ):
        super().__init__()
        self.heads = heads
        self.bneck = bottleneck
        self.max_guidance = max_guidance

        # projections into the shared bottleneck
        self.proj_t5 = nn.Linear(t5_dim, bottleneck)
        self.proj_clip = nn.Linear(clip_dim, bottleneck)

        # bidirectional cross-attention
        self.cross_t2c = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )
        self.cross_c2t = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )

        # head-wise τ
        self.tau = nn.Parameter(torch.full((heads, 1, 1), tau_init))

        # convolutional pocket residual (depth-wise)
        self.res1 = nn.Conv1d(bottleneck, bottleneck, 3, padding=1, groups=bottleneck)
        self.res2 = nn.Conv1d(bottleneck, bottleneck, 3, padding=1, groups=bottleneck)
        self.norm_res = nn.LayerNorm(bottleneck)

        # fusion + output projections
        self.fuse = nn.Linear(2 * bottleneck, bottleneck)
        self.anchor_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim),
        )
        self.delta_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim),
        )
        self.logsig_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim),
        )
        self.gate_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, 1), nn.Sigmoid(),
        )
        self.guidance_proj = nn.Sequential(
            nn.LayerNorm(bottleneck), nn.Linear(bottleneck, 1), nn.Sigmoid(),
        )
    def load_state_dict(self, state_dict, **kwargs):
        # strip the "_orig_mod." prefix that torch.compile adds to checkpoints
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
        super().load_state_dict(state_dict, **kwargs)

    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        B, Lt, _ = t5_seq.size()
        _, Lc, _ = clip_seq.size()

        # 1) project into bottleneck
        t5_b = self.proj_t5(t5_seq)        # (B, Lt, b)
        clip_b = self.proj_clip(clip_seq)  # (B, Lc, b)

        # 2) bidirectional cross-attention
        t2c, attn_t2c = self.cross_t2c(
            t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False
        )
        c2t, attn_c2t = self.cross_c2t(
            clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False
        )

        # 3) convolutional pocket on the T5→CLIP stream
        x = t2c.transpose(1, 2)                   # (B, b, Lt)
        x = F.gelu(self.res1(x))
        x = F.gelu(self.res2(x)).transpose(1, 2)  # (B, Lt, b)
        pocket = self.norm_res(t2c + x)           # (B, Lt, b)

        # 4) fuse pooled pocket with the CLIP→T5 stream
        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, Lc, -1)
        h = F.gelu(self.fuse(torch.cat([pocket_mean, c2t], -1)))  # (B, Lc, b)

        # 5) outputs
        anchor = self.anchor_proj(h)     # (B, Lc, clip_dim)
        delta_mean = self.delta_proj(h)  # (B, Lc, clip_dim)
        log_sigma = self.logsig_proj(h)  # (B, Lc, clip_dim)
        gate = self.gate_proj(h)         # (B, Lc, 1)
        delta = delta_mean * gate        # gated residual shift

        g_tok = self.guidance_proj(h).squeeze(-1)  # (B, Lc)
        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance

        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate


# --- 1. load pipeline ---------------------------------------------------
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# --- 2. load tiny-T5 & shunt (fp32) -------------------------------------
t5_tok = T5TokenizerFast.from_pretrained("t5-small")
t5_mod = T5EncoderModel.from_pretrained("t5-small").eval().to("cuda")

shunt = TwoStreamShuntAdapter().float().eval().to("cuda")
shunt.load_state_dict(
    st.load_file(
        "/content/drive/MyDrive/t5-clip-l-shunts/vitl14_t5small_shunt_vanilla_final.safetensors"
    )
)
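
# Quick shape sanity check (an added sketch, not part of the original flow):
# random tensors at the adapter's default dims confirm that Δ comes back
# aligned with CLIP-L's (B, 77, 768) layout. The 64-token T5 length is
# arbitrary; the cross-attention handles mismatched sequence lengths.
with torch.no_grad():
    t5_demo = torch.randn(1, 64, 512, device="cuda")    # (B, Lt, t5_dim)
    clip_demo = torch.randn(1, 77, 768, device="cuda")  # (B, Lc, clip_dim)
    demo_out = shunt(t5_demo, clip_demo)
    assert demo_out[1].shape == (1, 77, 768)  # delta matches the CLIP-L stream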
# --- 3. wrap encode_prompt once ------------------------------------------
orig_encode = pipe.encode_prompt

config = {  # reserved tuning knobs (not yet wired into the wrappers below)
    "strength": 1.0,
    "gate_gamma": 1.0,
    "tau_scale": 1.0,
    "guidance_gain": 1.0,
    "guidance_bias": 0.0,
}

gen = torch.Generator(device="cuda").manual_seed(420)
strength = 0.0  # global knob read by the wrappers at call time; must exist first


def stable_encode_prompt_shunted(self, *args, **kw):
    """Known-working fallback wrapper: shifts CLIP-L using T5 states of a dummy batch."""
    pe, ne, pool, npool = orig_encode(*args, **kw)  # regular call

    # split: first 768 dims are CLIP-L, remaining 1280 are CLIP-G
    clipL, clipG = pe[..., :768], pe[..., 768:]

    # build a T5 batch sized to match CLIP (CFG duplication is handled
    # upstream, since encode_prompt already concatenated negative &
    # positive embeddings when needed)
    bsz = clipL.shape[0]
    texts = ["tmp"] * bsz  # placeholder text; only the hidden states are used
    t5_ids = t5_tok(texts, return_tensors="pt").input_ids.to("cuda")
    t5_seq = t5_mod(t5_ids).last_hidden_state  # (B, L, 512)

    # run the adapter in fp32; the second output is Δ
    delta = shunt(t5_seq.float(), clipL.float())[1]
    delta = delta * strength  # strength knob
    clipL_shift = (clipL.float() + delta).to(clipL.dtype)

    pe_shifted = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shifted, ne, pool, npool


# --------------------------------------------------------------------------
def encode_prompt_shunted(self, *a, **k):
    # 1) run the normal encoder ("style" & "context" already split)
    pe, ne, pool, npool = orig_encode(*a, **k)  # (B, 77, 2048)

    # 2) split CLIP-L / CLIP-G
    clipL, clipG = pe[..., :768], pe[..., 768:]

    # 3) encode the *context* text with T5 (it arrives as k['prompt_2'])
    t5_ids = t5_tok([k.get("prompt_2")], return_tensors="pt").input_ids.to(pe.device)
    t5_seq = t5_mod(t5_ids).last_hidden_state.float()

    # 4) shunt → Δ (fp32, then cast back)
    delta = shunt(t5_seq, clipL.float())[1].to(clipL.dtype)
    clipL_shift = clipL + delta * strength

    # 5) concatenate back
    pe_shift = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shift, ne, pool, npool


pipe.encode_prompt = encode_prompt_shunted.__get__(pipe, type(pipe))

PROMPT = "a naturally lit and beautiful room with a photorealistic depiction of a woman"
PROMPT_2 = ("a realistic depiction of a woman sitting on a chair at a coffee shop "
            "sipping coffee, the environment is beautiful")
NEG = "blurry, distorted, monochrome, greyscale, watermark"
STEPS = 50

base_strength = 0.5
base_cfg = 7.5

for i in range(4):
    strength = base_strength + i * 0.25
    cfg = base_cfg - i * 0.25
    img = pipe(
        PROMPT,
        prompt_2=PROMPT_2,
        negative_prompt=NEG,
        num_inference_steps=STEPS,
        guidance_scale=cfg,  # SDXL pipelines take guidance_scale, not cfg_scale
        generator=torch.Generator(device="cuda").manual_seed(420),
    ).images[0]
    img.save(f"woman_cfg_{int(cfg * 100)}_{int(strength * 100)}.png")

# --- 4. earlier one-off runs (kept for reference; the commented block below
# was repeated verbatim at strength 0 (baseline) / 0.25 / 0.5 / 0.75, saving
# majestic_baseline.png, majestic_02.png, majestic_05.png, majestic_075.png) -
# img = pipe(
#     PROMPT,
#     negative_prompt=NEG,
#     num_inference_steps=STEPS,
#     generator=torch.Generator(device="cuda").manual_seed(420),
# ).images[0]
# img.save("majestic_baseline.png")
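
# stable_encode_prompt_shunted above is defined but never bound. A minimal
# sketch of swapping it in for runs without a separate prompt_2 (it derives
# the T5 batch from CLIP's batch size alone); the USE_DUMMY_T5_WRAPPER flag
# is hypothetical, added here only for illustration:
USE_DUMMY_T5_WRAPPER = False
if USE_DUMMY_T5_WRAPPER:
    pipe.encode_prompt = stable_encode_prompt_shunted.__get__(pipe, type(pipe))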
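
# The shunt also predicts a guidance scale (g_pred, output index 6, scaled to
# (0, max_guidance)), which the sweep above never uses. A hedged sketch of
# reading it for one prompt and letting it drive CFG; the encode_prompt
# keyword layout assumed here matches current diffusers:
with torch.no_grad():
    pe, _, _, _ = orig_encode(
        PROMPT, prompt_2=PROMPT_2, device="cuda",
        num_images_per_prompt=1, do_classifier_free_guidance=False,
    )
    t5_ids = t5_tok([PROMPT_2], return_tensors="pt").input_ids.to("cuda")
    t5_seq = t5_mod(t5_ids).last_hidden_state.float()
    g_pred = shunt(t5_seq, pe[..., :768].float())[6]  # (B, 1)
print(f"shunt-predicted guidance scale: {g_pred.item():.2f}")
img = pipe(
    PROMPT,
    prompt_2=PROMPT_2,
    negative_prompt=NEG,
    num_inference_steps=STEPS,
    guidance_scale=float(g_pred),
    generator=torch.Generator(device="cuda").manual_seed(420),
).images[0]
img.save("woman_gpred.png")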