import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import DiffusionPipeline
from safetensors.torch import safe_open
from huggingface_hub import hf_hub_download

from two_stream_shunt_adapter import TwoStreamShuntAdapter
from adapter_config import T5_SHUNT_REPOS
# ─── Device & Model Setup ─────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
).to(device)
# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]
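# For reference, T5_SHUNT_REPOS is expected to look roughly like the sketch
# below (illustrative placeholders only; the real entries live in adapter_config.py):
#
# T5_SHUNT_REPOS = {
#     "clip_l": {
#         "repo": "<hf-repo-id>",
#         "config": {...},   # adapter config; config["clip"]["hidden_size"] == 768
#         "shunts_available": {"shunt_list": ["<adapter-file>.safetensors", ...]},
#     },
#     "clip_g": {...},       # same layout, with hidden_size == 1280
# }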
# ─── Loader ───────────────────────────────────────────────────
def load_adapter(repo, filename, config):
    """Download a shunt adapter checkpoint from the Hub and load its weights."""
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    # Adapters stay in float32; inputs are cast to float32 inside infer().
    model.to(device)
    return model
# ─── Inference ────────────────────────────────────────────────
def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
    # Load both adapters together with their configs.
    adapter_list = [
        {"adapter": load_adapter(repo_l, adapter_l_file, config_l), "config": config_l},
        {"adapter": load_adapter(repo_g, adapter_g_file, config_g), "config": config_g},
    ]
    # Encode the prompt with T5.
    t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
    t5_seq = t5_mod(t5_ids).last_hidden_state  # (B, L, 768), float32

    # Encode the prompt through SDXL's own text encoders to get the concatenated
    # CLIP-L / CLIP-G sequence embeddings and the pooled embedding.
    prompt_embeds, _, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=False,
    )

    total_dim = prompt_embeds.shape[-1]  # 2048 for SDXL base
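    # SDXL's prompt_embeds concatenate the two text-encoder streams along the
    # channel dim: [:, :, :768] is CLIP-L (ViT-L/14) and [:, :, 768:2048] is
    # CLIP-G (ViT-bigG/14). Each adapter below edits only its own slice.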
    cond_tensor = prompt_embeds.clone()

    # Apply each adapter to its slice of the conditioning tensor.
    for adapter_info in adapter_list:
        adapter_model = adapter_info["adapter"]
        adapter_config = adapter_info["config"]
        clip_dim = adapter_config["clip"]["hidden_size"]

        if clip_dim == 768:        # CLIP-L stream
            slice_start, slice_end = 0, 768
        elif clip_dim == 1280:     # CLIP-G stream
            slice_start, slice_end = 768, min(2048, total_dim)
        else:
            continue

        # Adapters run in float32; cast the (possibly fp16) slice up for the forward pass.
        clip_slice = cond_tensor[:, :, slice_start:slice_end].float()

        anchor, delta_mean_adapter, log_sigma_adapter, _, _, _, g_pred_adapter, gate_adapter = adapter_model(t5_seq, clip_slice)

        gate = gate_adapter * gate_prob
        delta = delta_mean_adapter * strength * gate

        # Match the delta's sequence length to the CLIP slice if they differ.
        if delta.shape[1] != clip_slice.shape[1]:
            delta = torch.nn.functional.interpolate(
                delta.transpose(1, 2),
                size=clip_slice.size(1),
                mode="nearest",
            ).transpose(1, 2)

        if use_anchor:
            clip_slice = clip_slice * (1 - gate) + anchor * gate
        if noise > 0:
            clip_slice = clip_slice + torch.randn_like(clip_slice) * noise

        cond_tensor[:, :, slice_start:slice_end] = (clip_slice + delta).type_as(cond_tensor)
    # SDXL expects the 1280-dim pooled CLIP-G embedding here; reuse the pooled
    # vector from encode_prompt rather than deriving one from the edited sequence.
    pooled_embed = pooled_prompt_embeds

    image = pipe(
        prompt_embeds=cond_tensor,
        pooled_prompt_embeds=pooled_embed,
        negative_prompt_embeds=torch.zeros_like(cond_tensor),
        negative_pooled_prompt_embeds=torch.zeros_like(pooled_embed),
        num_inference_steps=20,
        guidance_scale=5.0,
    ).images[0]

    return image
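# A quick smoke test, assuming each adapter repo exposes at least one file in
# its shunt_list (the filenames come from adapter_config.py):
#
#   img = infer("a futuristic control station",
#               clip_l_opts[0], clip_g_opts[0],
#               strength=1.0, noise=0.0, gate_prob=1.0, use_anchor=True)
#   img.save("smoke_test.png")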
# ─── Gradio App ───────────────────────────────────────────────
with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
    gr.Markdown("# Dual Shunt Adapter • SDXL Inference")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a futuristic control station")
            adapter_l = gr.Dropdown(choices=clip_l_opts, label="CLIP-L (768d) Adapter")
            adapter_g = gr.Dropdown(choices=clip_g_opts, label="CLIP-G (1280d) Adapter")
            strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
            noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
            gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
            use_anchor = gr.Checkbox(label="Use Anchor", value=True)
            run_btn = gr.Button("Run")
        with gr.Column():
            out_img = gr.Image(label="Generated Image")
    run_btn.click(
        fn=infer,
        inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
        outputs=out_img,
    )

if __name__ == "__main__":
    demo.launch(share=True)