import torch
import gradio as gr

from transformers import T5Tokenizer, T5EncoderModel
from diffusers import DiffusionPipeline
from safetensors.torch import safe_open
from huggingface_hub import hf_hub_download

from two_stream_shunt_adapter import TwoStreamShuntAdapter
from adapter_config import T5_SHUNT_REPOS

# ─── Device & Model Setup ─────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
).to(device)

# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]

# ─── Loader ───────────────────────────────────────────────────
def load_adapter(repo, filename, config):
    """Download a shunt checkpoint from the Hub and load it onto `device`."""
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    model.to(device)
    return model

# ─── Inference ────────────────────────────────────────────────
@torch.no_grad()
def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
    # Load both adapters together with their configs.
    adapter_list = [
        {"adapter": load_adapter(repo_l, adapter_l_file, config_l), "config": config_l},
        {"adapter": load_adapter(repo_g, adapter_g_file, config_g), "config": config_g},
    ]

    # Encode the prompt via T5; this is the shared guidance stream both
    # adapters condition on.
    t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
    t5_seq = t5_mod(t5_ids).last_hidden_state  # (B, L, 768)

    # Encode the prompt through SDXL's own text encoders to get the
    # concatenated CLIP-L + CLIP-G sequence embeddings. The SDXL pipeline
    # exposes encode_prompt (not the older _encode_prompt) and returns four
    # tensors; the negative pair is None with classifier-free guidance off.
    prompt_embeds, _, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=False,
    )

    total_dim = prompt_embeds.shape[-1]  # 2048 for SDXL (768 + 1280)
    cond_tensor = prompt_embeds.clone()

    for adapter_info in adapter_list:
        adapter_model = adapter_info["adapter"]
        adapter_config = adapter_info["config"]
        clip_dim = adapter_config["clip"]["hidden_size"]

        # Each adapter edits its own slice of the concatenated embedding:
        # CLIP-L occupies [:768], CLIP-G occupies [768:2048].
        if clip_dim == 768:
            slice_start, slice_end = 0, 768
        elif clip_dim == 1280:
            slice_start, slice_end = 768, min(2048, total_dim)
        else:
            continue

        # Adapter weights are loaded in fp32, so run the slice in fp32 and
        # cast back when writing into the (possibly fp16) conditioning tensor.
        clip_slice = cond_tensor[:, :, slice_start:slice_end].float()

        anchor, delta_mean_adapter, log_sigma_adapter, _, _, _, g_pred_adapter, gate_adapter = adapter_model(t5_seq, clip_slice)

        gate = gate_adapter * gate_prob
        delta = delta_mean_adapter * strength * gate

        # Deltas come back at the T5 sequence length; align them to the
        # CLIP sequence length before applying.
        if delta.shape[1] != clip_slice.shape[1]:
            delta = torch.nn.functional.interpolate(
                delta.transpose(1, 2), size=clip_slice.size(1), mode="nearest"
            ).transpose(1, 2)

        if use_anchor:
            clip_slice = clip_slice * (1 - gate) + anchor * gate

        if noise > 0:
            clip_slice = clip_slice + torch.randn_like(clip_slice) * noise

        cond_tensor[:, :, slice_start:slice_end] = (clip_slice + delta).type_as(cond_tensor)
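    # NOTE: the adapters only edit the per-token sequence, so we keep the
    # pooled vector that encode_prompt() returned. Re-pooling cond_tensor
    # (e.g. cond_tensor.mean(dim=1)) would yield a 2048-d vector, while
    # SDXL's added-condition embedder expects the 1280-d CLIP-G pooled output.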
    pooled_embed = pooled_prompt_embeds

    # Zero-filled negatives stand in for an empty negative prompt.
    image = pipe(
        prompt_embeds=cond_tensor,
        pooled_prompt_embeds=pooled_embed,
        negative_prompt_embeds=torch.zeros_like(cond_tensor),
        negative_pooled_prompt_embeds=torch.zeros_like(pooled_embed),
        num_inference_steps=20,
        guidance_scale=5.0,
    ).images[0]

    return image

# ─── Gradio App ───────────────────────────────────────────────
with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
    gr.Markdown("# 🧠 Dual Shunt Adapter • SDXL Inference")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a futuristic control station")
            adapter_l = gr.Dropdown(choices=clip_l_opts, label="CLIP-L (768d) Adapter")
            adapter_g = gr.Dropdown(choices=clip_g_opts, label="CLIP-G (1280d) Adapter")
            strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
            noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
            gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
            use_anchor = gr.Checkbox(label="Use Anchor", value=True)
            run_btn = gr.Button("Run")
        with gr.Column():
            out_img = gr.Image(label="Generated Image")

    run_btn.click(
        fn=infer,
        inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
        outputs=out_img,
    )

if __name__ == "__main__":
    demo.launch(share=True)
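# ─── Headless usage (sketch) ──────────────────────────────────
# A minimal sketch of exercising infer() without the UI, assuming the first
# entry of each shunt list names a valid .safetensors file in its repo.
# Uncomment to generate a single image instead of launching Gradio:
#
# if clip_l_opts and clip_g_opts:
#     img = infer(
#         "a futuristic control station",
#         clip_l_opts[0], clip_g_opts[0],
#         strength=1.0, noise=0.0, gate_prob=1.0, use_anchor=True,
#     )
#     img.save("preview.png")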