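# NOTE: the lines below appear to be residue from the hosting site's file
# viewer (author badge, last commit message, commit hash, links, file size).
# They are not part of the program and are kept here only as comments:
#   OzzyGT's picture / OzzyGT HF Staff / link to fast inpaint / 432e607 /
#   raw / history blame / 4.47 kB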
import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, ControlNetUnionModel, DiffusionPipeline, TCDScheduler
def callback_cfg_cutoff(pipeline, step_index, timestep, callback_kwargs):
    """Step-end callback that turns guidance off partway through sampling.

    On the one step whose index equals ``int(pipeline.num_timesteps * 0.2)``,
    each tensor input named below is cut down to its last entry along the
    first axis (``[-1:]``) and ``pipeline._guidance_scale`` is set to ``0.0``.
    On every other step the kwargs are returned untouched.

    NOTE(review): keeping only the last entry presumably discards the
    unconditional half of a classifier-free-guidance batch -- confirm against
    the pipeline that invokes this callback.

    Args:
        pipeline: Object exposing ``num_timesteps``; its ``_guidance_scale``
            attribute is overwritten at the cutoff step.
        step_index: Step index supplied by the caller.
        timestep: Unused; accepted to match the callback signature.
        callback_kwargs: Mapping holding ``prompt_embeds``,
            ``add_text_embeds``, ``add_time_ids``, ``control_image`` and
            ``control_type``.

    Returns:
        The same ``callback_kwargs`` object, modified in place at the cutoff
        step.
    """
    cutoff_step = int(pipeline.num_timesteps * 0.2)
    if step_index != cutoff_step:
        return callback_kwargs

    # Read and trim every input before touching the pipeline, so a missing
    # key fails before the guidance scale is changed.
    trimmed = {
        key: callback_kwargs[key][-1:]
        for key in ("prompt_embeds", "add_text_embeds", "add_time_ids")
    }

    # control_image is a list: only its first element is trimmed, and the
    # list object itself is updated in place.
    control_images = callback_kwargs["control_image"]
    control_images[0] = control_images[0][-1:]
    trimmed["control_image"] = control_images

    trimmed["control_type"] = callback_kwargs["control_type"][-1:]

    pipeline._guidance_scale = 0.0
    callback_kwargs.update(trimmed)
    return callback_kwargs
# Display name -> checkpoint identifier; the keys populate the "Model"
# dropdown in the UI.
# NOTE(review): the pipeline below is loaded from a hard-coded identifier that
# matches the single entry here; nothing in this file looks a checkpoint up
# through this mapping.
MODELS = {
    "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}

# Load the ControlNet weights (fp16 variant) and place them on the GPU.
controlnet_model = ControlNetUnionModel.from_pretrained(
    "OzzyGT/controlnet-union-promax-sdxl-1.0", variant="fp16", torch_dtype=torch.float16
)
controlnet_model.to(device="cuda", dtype=torch.float16)

# The VAE is loaded separately in fp16 and handed to the pipeline below.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")

# Assemble the pipeline from the base checkpoint, the VAE and the ControlNet
# loaded above, using the named custom pipeline, then move it to the GPU.
pipe = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=controlnet_model,
    custom_pipeline="OzzyGT/custom_sdxl_cnet_union",
).to("cuda")

# Swap in a TCDScheduler built from the config of the scheduler the pipeline
# was loaded with.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

# The prompt is fixed, so it is encoded once at import time; fill_image passes
# these four embeddings to the pipeline on every request.
# NOTE(review): the second positional argument is presumably the device --
# confirm against the custom pipeline's encode_prompt signature.
prompt = "high quality"
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(prompt, "cuda")
@spaces.GPU(duration=16)
def fill_image(image, model_selection):
    """Regenerate the masked region of an image and yield a before/after pair.

    Args:
        image: Editor payload. ``image["background"]`` is the picture and
            ``image["layers"][0]`` is the drawn mask layer; the fourth band
            of that layer is used as the mask (presumably the alpha channel
            of an RGBA image -- TODO confirm against the editor's output).
        model_selection: Value of the model dropdown. It is accepted but not
            used in this function.

    Yields:
        One ``(background, result)`` tuple: the untouched background and a
        copy of it whose masked pixels were replaced by the pipeline output.
    """
    original = image["background"]
    mask_layer = image["layers"][0]

    # Every pixel with a non-zero value in the mask band counts as masked.
    alpha = mask_layer.split()[3]
    binary_mask = alpha.point(lambda level: 255 if level > 0 else 0)

    # Blank out the masked region on a copy; that copy is the control image.
    composite = original.copy()
    composite.paste(0, (0, 0), binary_mask)

    pipe_kwargs = dict(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        control_image=[composite],
        controlnet_conditioning_scale=[1.0],
        control_mode=[7],
        num_inference_steps=8,
        guidance_scale=1.5,
        callback_on_step_end=callback_cfg_cutoff,
        callback_on_step_end_tensor_inputs=[
            "prompt_embeds",
            "add_text_embeds",
            "add_time_ids",
            "control_image",
            "control_type",
        ],
    )
    generated = pipe(**pipe_kwargs).images[0]

    # Paste the generated pixels back only where the mask is set, so the
    # rest of the picture keeps its original content.
    generated = generated.convert("RGBA")
    composite.paste(generated, (0, 0), binary_mask)

    yield original, composite
def clear_result():
    """Return a Gradio update that sets the bound output's value to ``None``."""
    reset = gr.update(value=None)
    return reset
# HTML header shown at the top of the page; it is passed to gr.HTML when the
# layout is built.
title = """<h2 align="center">Diffusers Image Fill</h2>
<div align="center">Draw the mask over the subject you want to erase or change.</div>
<div align="center">
This space is a PoC made for the guide <a href='https://huggingface.co/blog/OzzyGT/diffusers-image-fill'>Diffusers Image Fill</a>.
If you need a space where you can use prompts, please go to the <a href='https://huggingface.co/spaces/OzzyGT/diffusers-fast-inpaint'>Diffusers Fast Inpaint</a> space.
</div>
"""
# Page layout: header, a "Generate" button, the mask editor next to the
# result slider, and a model dropdown.
# NOTE(review): the nesting below assumes the Row holds the input editor and
# the result slider, with the dropdown outside it -- confirm against the
# intended layout.
with gr.Blocks() as demo:
    gr.HTML(title)

    run_button = gr.Button("Generate")

    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            sources=["upload"],
            height=512,
        )
        result = gr.ImageSlider(
            interactive=False,
            label="Generated Image",
        )

    model_selection = gr.Dropdown(
        choices=list(MODELS),
        value="RealVisXL V5.0 Lightning",
        label="Model",
    )

    # First clear whatever the slider is showing, then run the fill and
    # send its output to the same component.
    cleared = run_button.click(fn=clear_result, inputs=None, outputs=result)
    cleared.then(
        fn=fill_image,
        inputs=[input_image, model_selection],
        outputs=result,
    )

demo.launch(share=False)