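"""Gradio demo: subject-driven inpainting with BLIP-Diffusion and the SD 1.5
inpaint ControlNet. The user masks a region of a source image, supplies a
target image of a subject plus a text prompt, and the pipeline regenerates
the masked region with that subject."""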
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import ControlNetModel

from pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the BLIP-Diffusion ControlNet pipeline, then swap in the SD 1.5
# inpaint ControlNet so generation can be conditioned on the drawn mask.
blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
    "Salesforce/blipdiffusion-controlnet"
)
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint")
blip_diffusion_pipe.controlnet = controlnet
blip_diffusion_pipe.to(device)
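# Optional (assumption, not part of the original app): on a CUDA device the
# weights could be loaded in half precision to roughly halve GPU memory, e.g.
#   BlipDiffusionControlNetPipeline.from_pretrained(
#       "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16)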
def make_inpaint_condition(image, image_mask):
    """Build the conditioning tensor for the inpaint ControlNet."""
    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
    # Require matching height and width.
    assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
    image[image_mask > 0.5] = -1.0  # mark masked pixels with the sentinel value -1
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)  # HWC -> NCHW
    image = torch.from_numpy(image)
    return image
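# A minimal sanity-check sketch (file names are hypothetical; assumes 512x512
# inputs): the conditioning tensor is NCHW float32 with masked pixels at -1.0.
#   cond = make_inpaint_condition(Image.open("src.png"), Image.open("mask.png"))
#   assert cond.shape == (1, 3, 512, 512) and cond.dtype == torch.float32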
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
.image_upload{min-height:500px}
.image_upload [data-testid="image"], .image_upload [data-testid="image"] > div{min-height: 500px}
.image_upload [data-testid="target"], .image_upload [data-testid="target"] > div{min-height: 500px}
.image_upload .touch-none{display: flex}
#output_image{min-height:500px;max-height:500px;}
'''
def create_demo():
    # UI for collecting the user's source image + mask, target image, and prompts.
    HEIGHT, WIDTH = 512, 512
    with gr.Blocks(theme=gr.themes.Default(font=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"],
                                           primary_hue="lime",
                                           secondary_hue="emerald",
                                           neutral_hue="slate",
                                           ), css=css) as demo:
        gr.Markdown('# BLIP-Diffusion')
        with gr.Accordion('Instructions', open=False):
            gr.Markdown('1. Upload a source image and draw a mask over the region to replace')
            gr.Markdown('2. Upload a target image of the subject')
            gr.Markdown('3. Enter the name of the target object and a text prompt')
            gr.Markdown('4. Click `Generate` when everything is ready!')
        with gr.Group():
            with gr.Box():
                with gr.Column():
                    with gr.Row() as main_blocks:
                        # Left column: source image, mask drawing, and prompt.
                        with gr.Column() as step_1:
                            gr.Markdown('### Source Input and Add Mask')
                            image = gr.Image(source='upload',
                                             shape=[HEIGHT, WIDTH],
                                             type='pil',
                                             elem_classes="image_upload",
                                             label='Source Image',
                                             tool='sketch',
                                             brush_radius=60).style(height=500)
                            src_input = image
                            text_prompt = gr.Textbox(label='Prompt')
                            run_button = gr.Button(value='Generate', variant="primary")
                        # Right column: target (subject) image.
                        with gr.Column() as step_2:
                            gr.Markdown('### Target Input')
                            target = gr.Image(source='upload',
                                              shape=[HEIGHT, WIDTH],
                                              type='pil',
                                              elem_classes="image_upload",
                                              label='Target Image'
                                              ).style(height=500)
                            tgt_input = target
                            style_subject = gr.Textbox(label='Target Object')
                    with gr.Row() as output_blocks:
                        with gr.Column() as output_step:
                            gr.Markdown('### Output')
                            output_image = gr.Gallery(
                                label="Generated images",
                                show_label=False,
                                elem_id="output_image",
                            ).style(height=500, container=True)
        with gr.Accordion('Advanced options', open=False):
            num_inference_steps = gr.Slider(label='Steps',
                                            minimum=1,
                                            maximum=100,
                                            value=50,
                                            step=1)
            guidance_scale = gr.Slider(label='Text Guidance Scale',
                                       minimum=0.1,
                                       maximum=30.0,
                                       value=7.5,
                                       step=0.1)
            seed = gr.Slider(label='Seed',
                             minimum=-1,
                             maximum=2147483647,
                             step=1,
                             randomize=True)
        # Hook the UI components up to the inference function.
        inputs = [
            src_input,
            tgt_input,
            text_prompt,
            style_subject,
            num_inference_steps,
            guidance_scale,
            seed,
        ]
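        # The order of `inputs` above must match generate's positional parameters.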
        def generate(src_input,
                     tgt_input,
                     text_prompt,
                     style_subject,
                     num_inference_steps,
                     guidance_scale,
                     seed,
                     ):
            if src_input is None or tgt_input is None:
                # gr.Error has to be raised (not just constructed) to show in the UI.
                raise gr.Error("You must upload both a source and a target image first.")
            # Assemble the model inputs from the UI values.
            tgt_subject = style_subject
            generator = torch.Generator(device="cpu").manual_seed(int(seed))
            init_image = src_input['image']      # source photo
            cldm_cond_image = src_input['mask']  # user-drawn inpainting mask
            control_image = make_inpaint_condition(init_image, cldm_cond_image)
            style_image = tgt_input
            negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, ugly, bad anatomy, bad proportions, deformed, blurry"
            # The custom pipeline takes the text prompt, the subject (style)
            # image, the ControlNet conditioning tensor, and the source/target
            # subject names, plus the original image and mask for inpainting.
            output = blip_diffusion_pipe(
                text_prompt,
                style_image,
                control_image,
                style_subject,
                tgt_subject,
                generator=generator,
                image=init_image,
                mask_image=cldm_cond_image,
                guidance_scale=guidance_scale,
                num_inference_steps=int(num_inference_steps),
                neg_prompt=negative_prompt,
                height=HEIGHT,
                width=WIDTH,
            ).images
            return {output_image: output}
        run_button.click(fn=generate, inputs=inputs, outputs=[output_image])
    return demo
if __name__ == '__main__':
    demo = create_demo()
    demo.queue().launch()
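# Deployment note (assumption): to reach the demo from other machines, gradio's
# launch() also accepts e.g. demo.queue().launch(server_name="0.0.0.0").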