import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import DiffusionPipeline
from safetensors.torch import safe_open
from huggingface_hub import hf_hub_download
from two_stream_shunt_adapter import TwoStreamShuntAdapter
from adapter_config import T5_SHUNT_REPOS
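
# `two_stream_shunt_adapter` and `adapter_config` are expected to be local modules
# bundled with this Space (the adapter class and the shunt repo/config registry).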
# ─── Device & Model Setup ─────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_mod = T5EncoderModel.from_pretrained("google/flan-t5-base").to(device).eval()
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None
).to(device)
# ─── Adapter Configs ──────────────────────────────────────────
clip_l_opts = T5_SHUNT_REPOS["clip_l"]["shunts_available"]["shunt_list"]
clip_g_opts = T5_SHUNT_REPOS["clip_g"]["shunts_available"]["shunt_list"]
repo_l = T5_SHUNT_REPOS["clip_l"]["repo"]
repo_g = T5_SHUNT_REPOS["clip_g"]["repo"]
config_l = T5_SHUNT_REPOS["clip_l"]["config"]
config_g = T5_SHUNT_REPOS["clip_g"]["config"]
# ─── Loader ───────────────────────────────────────────────────
def load_adapter(repo, filename, config):
    path = hf_hub_download(repo_id=repo, filename=filename)
    model = TwoStreamShuntAdapter(config).eval()
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    model.load_state_dict(tensors)
    model.to(device)
    return model
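
# Illustrative call (filenames come from the dropdowns defined below):
#   adapter = load_adapter(repo_l, clip_l_opts[0], config_l)
# The returned TwoStreamShuntAdapter lives on `device` and is in eval mode.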
# ─── Inference ────────────────────────────────────────────────
@torch.no_grad()
def infer(prompt, adapter_l_file, adapter_g_file, strength, noise, gate_prob, use_anchor):
    adapter_list = []

    # Load adapters with their configs
    adapter_list.append({
        "adapter": load_adapter(repo_l, adapter_l_file, config_l),
        "config": config_l
    })
    adapter_list.append({
        "adapter": load_adapter(repo_g, adapter_g_file, config_g),
        "config": config_g
    })

    # Encode prompt via T5
    t5_ids = t5_tok(prompt, return_tensors="pt").input_ids.to(device)
    t5_seq = t5_mod(t5_ids).last_hidden_state  # (B, L, 768)

    # Encode prompt via SDXL's own text encoders to get the CLIP-L/CLIP-G embeddings
    prompt_embeds, _, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=False,
    )
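
    # SDXL's `prompt_embeds` concatenates the per-token hidden states of CLIP-L (768-d)
    # and CLIP-G (1280-d) into 2048 features per token; each adapter below patches only
    # the slice that belongs to its own text encoder.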
    total_dim = prompt_embeds.shape[-1]
    cond_tensor = prompt_embeds.clone()

    for adapter_info in adapter_list:
        adapter_model = adapter_info["adapter"]
        adapter_config = adapter_info["config"]
        clip_dim = adapter_config["clip"]["hidden_size"]

        if clip_dim == 768:
            clip_slice = cond_tensor[:, :, :768]
            slice_start, slice_end = 0, 768
        elif clip_dim == 1280:
            clip_slice = cond_tensor[:, :, 768:2048] if total_dim >= 2048 else cond_tensor[:, :, 768:]
            slice_start, slice_end = 768, min(2048, total_dim)
        else:
            continue

        # The adapter runs in its stored precision (assumed fp32 here); feed it an fp32
        # copy of the slice so dtypes line up, and cast back when writing below.
        clip_slice = clip_slice.float()
        anchor, delta_mean_adapter, log_sigma_adapter, _, _, _, g_pred_adapter, gate_adapter = adapter_model(t5_seq, clip_slice)

        gate = gate_adapter * gate_prob
        delta = delta_mean_adapter * strength * gate

        # Resample the delta along the token axis if the adapter's sequence length differs
        if delta.shape[1] != clip_slice.shape[1]:
            delta = torch.nn.functional.interpolate(
                delta.transpose(1, 2),
                size=clip_slice.size(1),
                mode="nearest"
            ).transpose(1, 2)

        if use_anchor:
            clip_slice = clip_slice * (1 - gate) + anchor * gate
        if noise > 0:
            clip_slice = clip_slice + torch.randn_like(clip_slice) * noise

        cond_tensor[:, :, slice_start:slice_end] = (clip_slice + delta).type_as(cond_tensor)

    image = pipe(
        prompt_embeds=cond_tensor,
        pooled_prompt_embeds=pooled_prompt_embeds,  # 1280-d pooled embeds from the standard SDXL encoders
        negative_prompt_embeds=torch.zeros_like(cond_tensor),
        negative_pooled_prompt_embeds=torch.zeros_like(pooled_prompt_embeds),
        num_inference_steps=20,
        guidance_scale=5.0
    ).images[0]

    return image
# ─── Gradio App ───────────────────────────────────────────────
with gr.Blocks(title="Dual Adapter T5→CLIP") as demo:
    gr.Markdown("# 🧠 Dual Shunt Adapter • SDXL Inference")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a futuristic control station")
            adapter_l = gr.Dropdown(choices=clip_l_opts, label="CLIP-L (768d) Adapter")
            adapter_g = gr.Dropdown(choices=clip_g_opts, label="CLIP-G (1280d) Adapter")
            strength = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="Adapter Strength")
            noise = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Noise Injection")
            gate_prob = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Gate Probability")
            use_anchor = gr.Checkbox(label="Use Anchor", value=True)
            run_btn = gr.Button("Run")
        with gr.Column():
            out_img = gr.Image(label="Generated Image")

    run_btn.click(
        fn=infer,
        inputs=[prompt, adapter_l, adapter_g, strength, noise, gate_prob, use_anchor],
        outputs=out_img
    )

if __name__ == "__main__":
    demo.launch(share=True)
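
# Note: `share=True` asks Gradio for a temporary public link in addition to the
# local server; on Hugging Face Spaces the app is already publicly served.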