import safetensors.torch as st
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from diffusers import StableDiffusionXLPipeline
from transformers import T5TokenizerFast, T5EncoderModel


class TwoStreamShuntAdapter(nn.Module):
    """
    Cross-attentive adapter that aligns T5 and CLIP token streams.

    Returns:
        anchor    : (B, Lc, clip_dim)
        delta     : (B, Lc, clip_dim)
        log_sigma : (B, Lc, clip_dim) - log sigma, always finite
        attn_t2c  : (B, heads, Lt, Lc)
        attn_c2t  : (B, heads, Lc, Lt)
        tau       : (heads, 1, 1) - per-head threshold param
        g_pred    : (B, 1) - guidance-scale prediction
        gate      : (B, Lc, 1) - per-token gate in (0, 1)
    """

    def __init__(
        self,
        t5_dim: int = 512,
        clip_dim: int = 768,
        bottleneck: int = 256,
        heads: int = 8,
        tau_init: float = 0.1,
        max_guidance: float = 10.0,
    ):
        super().__init__()
        print("TwoStreamShuntAdapter init")
        self.heads = heads
        self.bneck = bottleneck
        self.max_guidance = max_guidance

        # Project both token streams into the shared bottleneck width.
        self.proj_t5 = nn.Linear(t5_dim, bottleneck)
        self.proj_clip = nn.Linear(clip_dim, bottleneck)

        # Bidirectional cross-attention between the T5 and CLIP streams.
        self.cross_t2c = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )
        self.cross_c2t = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )

        # Per-head threshold parameter (returned as-is from forward()).
        self.tau = nn.Parameter(torch.full((heads, 1, 1), tau_init))

        # Depthwise-conv residual "pocket" refining the T5->CLIP stream.
        self.res1 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.res2 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.norm_res = nn.LayerNorm(bottleneck)

        # Fuse the pooled pocket features with the CLIP->T5 stream.
        self.fuse = nn.Linear(2 * bottleneck, bottleneck)

        # Output heads.
        self.anchor_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.delta_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.logsig_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.gate_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, 1), nn.Sigmoid()
        )
        self.guidance_proj = nn.Sequential(
            nn.LayerNorm(bottleneck), nn.Linear(bottleneck, 1), nn.Sigmoid()
        )

    def load_state_dict(self, state_dict, **kwargs):
        # Strip the "_orig_mod." prefix that torch.compile adds to checkpoint keys.
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
        super().load_state_dict(state_dict, **kwargs)

    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        print("SHUNT FORWARD CALLED")

        B, Lt, _ = t5_seq.size()
        _, Lc, _ = clip_seq.size()

        # Project both streams into the bottleneck space.
        t5_b = self.proj_t5(t5_seq)
        clip_b = self.proj_clip(clip_seq)

        # Cross-attention in both directions, keeping per-head attention maps.
        t2c, attn_t2c = self.cross_t2c(
            t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False
        )
        c2t, attn_c2t = self.cross_c2t(
            clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False
        )

        # Depthwise-conv residual refinement of the T5->CLIP stream.
        x = t2c.transpose(1, 2)
        x = F.gelu(self.res1(x))
        x = F.gelu(self.res2(x)).transpose(1, 2)
        pocket = self.norm_res(t2c + x)

        # Pool the pocket over the T5 axis and fuse with the CLIP->T5 stream.
        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, Lc, -1)
        h = F.gelu(self.fuse(torch.cat([pocket_mean, c2t], -1)))

        # Output heads; the delta is gated per token.
        anchor = self.anchor_proj(h)
        delta_mean = self.delta_proj(h)
        log_sigma = self.logsig_proj(h)
        gate = self.gate_proj(h)
        delta = delta_mean * gate

        # Guidance-scale prediction, scaled to (0, max_guidance).
        g_tok = self.guidance_proj(h).squeeze(-1)
        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance

        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate
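
# Minimal shape sanity check (a sketch; the batch/sequence sizes below are
# illustrative assumptions, not values used by the SDXL pipeline further down):
#
#   _adapter = TwoStreamShuntAdapter()
#   _t5 = torch.randn(2, 20, 512)     # (B, Lt, t5_dim)
#   _clip = torch.randn(2, 77, 768)   # (B, Lc, clip_dim)
#   _anchor, _delta, _log_sigma, *_rest = _adapter(_t5, _clip)
#   assert _anchor.shape == _delta.shape == _log_sigma.shape == (2, 77, 768)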

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

t5_tok = T5TokenizerFast.from_pretrained("t5-small")
t5_mod = T5EncoderModel.from_pretrained("t5-small").eval().to("cuda")

shunt = TwoStreamShuntAdapter().float().eval().to("cuda")
shunt.load_state_dict(
    st.load_file("/content/drive/MyDrive/t5-clip-l-shunts/vitl14_t5small_shunt_vanilla_final.safetensors")
)
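# Note: the adapter stays in float32 while the SDXL pipeline runs in float16;
# the encode_prompt wrappers below cast activations to float32 before calling
# the shunt and cast the shifted embeddings back to the pipeline dtype.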

# Keep a handle to the original encoder so the wrappers below can delegate to it.
orig_encode = pipe.encode_prompt

# Shunt blending hyperparameters (kept for reference; not read by the code below).
config = {
    "strength": 1.0,
    "gate_gamma": 1.0,
    "tau_scale": 1.0,
    "guidance_gain": 1.0,
    "guidance_bias": 0.0,
}

gen = torch.Generator(device="cuda").manual_seed(420)

strength = 0.0
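# `strength` is a module-level knob: encode_prompt_shunted reads it on every call,
# and the sweep loop at the bottom reassigns it before each generation.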


# Alternative wrapper (not bound below): shifts CLIP-L using a placeholder "tmp"
# prompt through T5 rather than the real prompt text.
def stable_encode_prompt_shunted(self, *args, **kw):
    pe, ne, pool, npool = orig_encode(*args, **kw)

    # SDXL concatenates CLIP-L (768) and CLIP-G (1280) along the channel axis.
    clipL, clipG = pe[..., :768], pe[..., 768:]

    bsz = clipL.shape[0]
    texts = ["tmp"] * bsz
    t5_ids = t5_tok(texts, return_tensors="pt").input_ids.to("cuda")
    t5_seq = t5_mod(t5_ids).last_hidden_state

    # Blend the shunt's gated delta into the CLIP-L stream.
    delta = shunt(t5_seq.float(), clipL.float())[1]
    delta = delta * strength
    clipL_shift = (clipL.float() + delta).to(clipL.dtype)

    pe_shifted = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shifted, ne, pool, npool


# Wrapper actually bound to the pipeline: shifts CLIP-L using prompt_2 through T5
# (assumes prompt_2 is always provided, as in the sweep below).
def encode_prompt_shunted(self, *a, **k):
    pe, ne, pool, npool = orig_encode(*a, **k)

    # Split the SDXL embedding into its CLIP-L and CLIP-G halves.
    clipL, clipG = pe[..., :768], pe[..., 768:]

    t5_ids = t5_tok([k.get("prompt_2")], return_tensors="pt").input_ids.to(pe.device)
    t5_seq = t5_mod(t5_ids).last_hidden_state.float()

    # Blend the shunt's gated delta into the CLIP-L stream.
    delta = shunt(t5_seq, clipL.float())[1].to(clipL.dtype)
    clipL_shift = clipL + delta * strength

    pe_shift = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shift, ne, pool, npool

pipe.encode_prompt = encode_prompt_shunted.__get__(pipe, type(pipe))
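# To undo the patch (e.g. for an unshunted baseline), rebind the original method:
#   pipe.encode_prompt = orig_encode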

PROMPT = "a naturally lit and beautiful room with a photorealistic depiction of a woman"
PROMPT_2 = "a realistic depiction of a woman sitting on a chair at a coffee shop sipping coffee, the environment is beautiful"
NEG = "blurry, distorted, monochrome, greyscale, watermark"
STEPS = 50
base_strength = 0.5
base_cfg = 7.5
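# Sweep: step the shunt strength up by 0.25 per image while stepping CFG down
# by 0.25, with a fixed seed so the images stay comparable.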

for i in range(0, 4):
    strength = base_strength + (i * 0.25)
    cfg = base_cfg - (i * 0.25)
    img = pipe(
        PROMPT,
        prompt_2=PROMPT_2,
        negative_prompt=NEG,
        num_inference_steps=STEPS,
        guidance_scale=cfg,
        generator=torch.Generator(device="cuda").manual_seed(420),
    ).images[0]
    img.save(f"woman_cfg_{int(cfg*100)}_{int(strength*100)}.png")