import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import DiffusionPipeline
from safetensors.torch import safe_open
from huggingface_hub import hf_hub_download
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from adapter_config import T5_SHUNT_REPOS
# ─── Device & Model Setup ─────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
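# Note: the T5 encoder stays in fp32 here while the SDXL pipeline runs in
# fp16 on CUDA; the adapter forward pass in infer() casts accordingly.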
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=dtype,
variant="fp16" if dtype == torch.float16 else None
).to(device)
# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]
# ─── Loader ───────────────────────────────────────────────────
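# Each adapter checkpoint is a safetensors file on the Hub; load_adapter
# downloads it, instantiates the local TwoStreamShuntAdapter, and copies the
# tensors in. Usage sketch (assuming the shunt lists are non-empty):
#   adapter = load_adapter(repo_l, clip_l_opts[0], config_l)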
def load_adapter(repo, filename, config):
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    model.to(device)
    return model
# ─── Inference ────────────────────────────────────────────────
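# Flow: encode the prompt once with T5 and once with SDXL's own CLIP encoders,
# then let each shunt adapter predict a gated delta for its slice of the
# concatenated CLIP embedding before handing the edited tensor to the pipeline.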
@torch.no_grad()
def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
    adapter_list = []
    # Load both adapters with their configs
    adapter_list.append({
        "adapter": load_adapter(repo_l, adapter_l_file, config_l),
        "config": config_l,
    })
    adapter_list.append({
        "adapter": load_adapter(repo_g, adapter_g_file, config_g),
        "config": config_g,
    })
    # Encode prompt via T5
    t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
    t5_seq = t5_mod(t5_ids).last_hidden_state  # (B, L, 768)
    # Encode prompt via SDXL's own text encoders to get CLIP-L and CLIP-G outputs.
    # The SDXL pipeline exposes encode_prompt(), which returns a 4-tuple:
    # (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds).
    prompt_embeds, _, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=False,
    )
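    # SDXL prompt_embeds are the CLIP-L (768-d) and CLIP-G (1280-d) hidden
    # states concatenated along the feature axis, giving (B, 77, 2048); each
    # adapter below edits only its own slice of that tensor.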
    total_dim = prompt_embeds.shape[-1]
    cond_tensor = prompt_embeds.clone()
    for adapter_info in adapter_list:
        adapter_model = adapter_info["adapter"]
        adapter_config = adapter_info["config"]
        clip_dim = adapter_config["clip"]["hidden_size"]
        if clip_dim == 768:
            slice_start, slice_end = 0, 768
        elif clip_dim == 1280:
            slice_start, slice_end = 768, min(2048, total_dim)
        else:
            continue
        clip_slice = cond_tensor[:, :, slice_start:slice_end]
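        # The adapter forward returns an 8-tuple; only the anchor, the
        # predicted delta mean, and the gate are consumed here.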
        # The adapter weights are loaded in fp32 while the SDXL embeddings may
        # be fp16, so run the adapter in fp32 and cast back on write-back.
        clip_slice = clip_slice.float()
        anchor, delta_mean, _, _, _, _, _, gate_raw = adapter_model(t5_seq, clip_slice)
        gate = gate_raw * gate_prob
        delta = delta_mean * strength * gate
        # Resample to the CLIP sequence length if the adapter emits a different one.
        if delta.shape[1] != clip_slice.shape[1]:
            def resample(t):
                return torch.nn.functional.interpolate(
                    t.transpose(1, 2), size=clip_slice.size(1), mode="nearest"
                ).transpose(1, 2)
            delta, anchor, gate = resample(delta), resample(anchor), resample(gate)
        if use_anchor:
            clip_slice = clip_slice * (1 - gate) + anchor * gate
        if noise > 0:
            clip_slice = clip_slice + torch.randn_like(clip_slice) * noise
        cond_tensor[:, :, slice_start:slice_end] = (clip_slice + delta).type_as(cond_tensor)
    # Use the pooled embedding from the CLIP-G encoder; a mean over the 2048-d
    # sequence would have the wrong shape (SDXL expects a 1280-d pooled vector).
    pooled_embed = pooled_prompt_embeds
    image = pipe(
        prompt_embeds=cond_tensor,
        pooled_prompt_embeds=pooled_embed,
        negative_prompt_embeds=torch.zeros_like(cond_tensor),
        negative_pooled_prompt_embeds=torch.zeros_like(pooled_embed),
        num_inference_steps=20,
        guidance_scale=5.0,
    ).images[0]
    return image
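# Programmatic usage sketch (assuming each shunt list is non-empty):
#   img = infer("a cat", clip_l_opts[0], clip_g_opts[0],
#               strength=1.0, noise=0.0, gate_prob=1.0, use_anchor=True)
#   img.save("out.png")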
# ─── Gradio App ───────────────────────────────────────────────
with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
    gr.Markdown("# 🧠 Dual Shunt Adapter • SDXL Inference")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a futuristic control station")
            adapter_l = gr.Dropdown(choices=clip_l_opts, label="CLIP-L (768d) Adapter")
            adapter_g = gr.Dropdown(choices=clip_g_opts, label="CLIP-G (1280d) Adapter")
            strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
            noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
            gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
            use_anchor = gr.Checkbox(label="Use Anchor", value=True)
            run_btn = gr.Button("Run")
        with gr.Column():
            out_img = gr.Image(label="Generated Image")
    run_btn.click(
        fn=infer,
        inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
        outputs=out_img,
    )
if __name__ == "__main__":
    demo.launch(share=True)