import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel, LoraConfig
import os
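
# Gradio demo: text-to-image generation with Stable Diffusion and LoRA adapters
# (UNet and, optionally, text encoder) loaded from the local ./lora_logos checkpoint via PEFT.
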
def get_lora_sd_pipeline(
    ckpt_dir='./lora_logos',
    base_model_name_or_path=None,
    dtype=torch.float16,
    adapter_name="default"
):
    # Load a Stable Diffusion pipeline and attach the LoRA adapters stored in ckpt_dir.
    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")

    # If the checkpoint contains a text-encoder adapter, its config records the base model.
    if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
        config = LoraConfig.from_pretrained(text_encoder_sub_dir)
        base_model_name_or_path = config.base_model_name_or_path

    if base_model_name_or_path is None:
        raise ValueError("Please specify the base model name or path")

    pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)

    if os.path.exists(text_encoder_sub_dir):
        pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)

    # Cast the freshly loaded adapter weights to the reduced-precision dtype as well.
    if dtype in (torch.float16, torch.bfloat16):
        pipe.unet.to(dtype)
        pipe.text_encoder.to(dtype)

    return pipe

def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
    # Encode prompts of arbitrary length: split the token sequence into chunks of at most
    # 77 tokens (CLIP's limit), encode each chunk, and concatenate the embeddings.
    tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"]
    chunks = [tokens[:, i:i + max_length] for i in range(0, tokens.shape[1], max_length)]
    with torch.no_grad():
        embeds = [text_encoder(chunk.to(text_encoder.device))[0] for chunk in chunks]
    return torch.cat(embeds, dim=1)

def align_embeddings(prompt_embeds, negative_prompt_embeds):
    # Zero-pad the shorter of the two embeddings so both share the same sequence length.
    max_length = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
    return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
           torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
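
# Module-level setup: pick the device, default base model and dtype once, and
# preload the default LoRA pipeline so it is not rebuilt on every request.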
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id_default = "CompVis/stable-diffusion-v1-4"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
pipe_default = get_lora_sd_pipeline(ckpt_dir='./lora_logos', base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)
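
# Single Gradio callback: reuses the preloaded LoRA pipeline for the default model ID,
# or loads the requested checkpoint from scratch (without LoRA) otherwise.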
def infer(
    prompt,
    negative_prompt,
    width=512,
    height=512,
    num_inference_steps=20,
    model_id='CompVis/stable-diffusion-v1-4',
    seed=42,
    guidance_scale=7.0,
    lora_scale=0.5
):
    generator = torch.Generator(device).manual_seed(seed)

    if model_id != model_id_default:
        # A non-default checkpoint was requested: load it without LoRA.
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
        prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
        negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
        prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
    else:
        # Reuse the preloaded pipeline and fuse its LoRA weights at the requested scale.
        pipe = pipe_default
        prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
        negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
        prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
        pipe.fuse_lora(lora_scale=lora_scale)

    params = {
        'prompt_embeds': prompt_embeds,
        'negative_prompt_embeds': negative_prompt_embeds,
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'width': width,
        'height': height,
        'generator': generator,
    }

    return pipe(**params).images[0]
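
# UI layout: prompt fields and sampler controls, a collapsed accordion for image size,
# and a Run button wired to infer().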
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# DEMO Text-to-Image")
        model_id = gr.Textbox(label="Model ID", value=model_id_default)
        prompt = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative prompt")
        seed = gr.Number(label="Seed", value=42)
        guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, value=7.0)
        lora_scale = gr.Slider(label="LoRA scale", minimum=0.0, maximum=1.0, value=0.5)
        num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, value=20)

        with gr.Accordion("Optional Settings", open=False):
            width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=32)
            height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=32)

        run_button = gr.Button("Run")
        result = gr.Image(label="Result")

    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            width,
            height,
            num_inference_steps,
            model_id,
            seed,
            guidance_scale,
            lora_scale
        ],
        outputs=result,
    )

if __name__ == "__main__":
    demo.launch()