RefEdit-SD3 / app.py
bimsarapathiraja
Change desc
00e1a69
raw
history blame
6.25 kB
import spaces
import torch
from diffusers import StableDiffusion3InstructPix2PixPipeline
import gradio as gr
import PIL.Image
import numpy as np
from PIL import Image, ImageOps
import os
# import transformers
# from transformers.utils.hub import move_cache
# transformers.utils.move_cache()
# move_cache()
# Load the RefEdit-SD3 checkpoint once at import time, in float16, and move it
# to the GPU; generate() reuses this single module-level pipeline.
pipe = StableDiffusion3InstructPix2PixPipeline.from_pretrained(
    "bpathir1/RefEdit-SD3",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
@spaces.GPU(duration=120)
def generate(image, prompt, num_inference_steps=50, image_guidance_scale=1.5, guidance_scale=7.5, seed=255):
    """Apply the editing instruction ``prompt`` to ``image`` with RefEdit-SD3.

    The input is converted to RGB and center-cropped/resized to 512x512
    (``ImageOps.fit`` with Lanczos resampling) before being passed to the
    module-level ``pipe``. The ``spaces.GPU(duration=120)`` decorator requests
    a GPU allocation of up to 120 seconds for each call.

    Args:
        image: PIL image from the Gradio image input (``type="pil"``), or
            ``None`` when nothing has been uploaded.
        prompt: Text editing instruction.
        num_inference_steps: Number of denoising steps. It is coerced to
            ``int`` because Gradio sliders may deliver float values.
        image_guidance_scale: Forwarded to the pipeline's
            ``image_guidance_scale``.
        guidance_scale: Forwarded to the pipeline's ``guidance_scale``.
        seed: Random seed, as an int or a numeric string (the UI collects it
            from a textbox).

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no image was provided or ``seed`` is not an integer.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        raise gr.Error(f"Random seed must be an integer, got {seed!r}.") from None

    # A private CPU generator gives the same seeded sequence as
    # torch.manual_seed() (which returns the default CPU generator). It does
    # not reseed torch's process-wide RNG, so concurrent requests cannot
    # disturb each other's randomness.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    img = ImageOps.fit(
        image.convert("RGB"),
        (512, 512),
        method=Image.LANCZOS,
        centering=(0.5, 0.5),
    )
    result = pipe(
        prompt,
        image=img,
        # This demo performs mask-free editing, so no mask is supplied.
        mask_img=None,
        num_inference_steps=int(num_inference_steps),
        image_guidance_scale=image_guidance_scale,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]
    return result
# Clickable examples for the demo. Each row follows the order of `inputs`:
# [image path, prompt, num_inference_steps, image_guidance_scale,
#  guidance_scale, seed].
_EXAMPLE_IMAGE_DIR = "UltraEdit/images/example_images"
# Every example uses the same sampling settings, equal to generate()'s
# defaults: 50 steps, image guidance 1.5, text guidance 7.5, seed 255.
_EXAMPLE_SETTINGS = (50, 1.5, 7.5, 255)
_EXAMPLE_PROMPTS = [
    ("4ppl2.jpg", "Add a flower on the t-shirt of the guy in the middle with dark jeans"),
    ("cat2.jpg", "Add a green scarf to the right cat"),
    ("3ppl2.jpg", "Add a flower bunch to the person with a red jacket"),
    ("4ppl1.jpg", "Let the rightmost person wear a golden dress"),
    ("bowls1.jpg", "Remove the bowl with some leaves in the middle"),
    ("cat1.jpg", "Can we have a dog instead of the cat looking at the camera?"),
]
example_lists = [
    [f"{_EXAMPLE_IMAGE_DIR}/{file_name}", prompt, *_EXAMPLE_SETTINGS]
    for file_name, prompt in _EXAMPLE_PROMPTS
]
# Examples handed to gr.Interface: an independent list copy of every row.
# The "mask" in the name appears to be a leftover from a mask-based variant;
# the rows carry no mask.
mask_ex_list = [list(row) for row in example_lists]
# Gradio input components, listed in the same order as generate()'s parameters.
image_input = gr.Image(type="pil", label="Input Image")
prompt_input = gr.Textbox(label="Prompt")
# Zero steps would skip denoising entirely, so the slider starts at 1;
# step=1 keeps the selectable values whole numbers.
num_inference_steps_input = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Number of Inference Steps")
image_guidance_scale_input = gr.Slider(minimum=0.0, maximum=2.5, value=1.5, label="Image Guidance Scale")
# NOTE(review): this default (12.5) differs from generate()'s default and the
# examples (7.5); confirm which value is intended.
guidance_scale_input = gr.Slider(minimum=0.0, maximum=17.5, value=12.5, label="Guidance Scale")
# A free-text seed; generate() converts it to int and rejects non-integers.
seed_input = gr.Textbox(value="255", label="Random Seed")
inputs = [image_input, prompt_input, num_inference_steps_input, image_guidance_scale_input, guidance_scale_input, seed_input]
outputs = [gr.Image(label="Generated Image")]
article_html = """
<div style="text-align: center; max-width: 1000px; margin: 20px auto; font-family: Arial, sans-serif;">
<h2 style="font-weight: 900; font-size: 2.5rem; margin-bottom: 0.5rem;">
RefEdit-SD3 for Fine-Grained Image Editing
</h2>
<div style="margin-bottom: 1rem;">
<h3 style="font-weight: 500; font-size: 1.25rem; margin: 0;"></h3>
<p style="font-weight: 400; font-size: 1rem; margin: 0.5rem 0;">
Bimsara Pathiraja<sup>1*</sup>, Maitreya Patel<sup>2*</sup>, Shivam Singh, Yezhou Yang, Chitta Baral
</p>
<p style="font-weight: 400; font-size: 1rem; margin: 0;">
Arizona State University
</p>
</div>
<div style="margin: 1rem 0; display: flex; justify-content: center; gap: 1.5rem; flex-wrap: wrap;">
<a href="https://huggingface.co/datasets/bpathir1/RefEdit" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Dataset_4M" style="height: 20px; vertical-align: middle;"> Dataset
</a>
<a href="https://refedit.vercel.app/" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<span style="font-size: 20px; vertical-align: middle;">🔗</span> Page
</a>
<a href="https://huggingface.co/bpathir1/RefEdit-SD3" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Checkpoint" style="height: 20px; vertical-align: middle;"> Checkpoint
</a>
<a href="https://github.com/bimsarapathiraja/refedit_private" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" alt="GitHub" style="height: 20px; vertical-align: middle;"> GitHub
</a>
</div>
<div style="text-align: left; margin: 0 auto; font-size: 1rem; line-height: 1.5;">
<p>
<b>RefEdit</b> is a benchmark and method for improving instruction-based image editing models for referring expressions.
</p>
</div>
</div>
"""
# html = '''
# <div style="text-align: left; margin-top: 2rem; font-size: 0.85rem; color: gray;">
# <b>Limitations:</b>
# <ul>
# <li>We have not conducted any NSFW checks;</li>
# <li>Due to the bias of the generated models, the model performance is still weak when dealing with high-frequency information such as <b>human facial expressions or text in the images</b>;</li>
# <li>We unified the free-form and region-based image editing by adding an extra channel of the mask image to the dataset. When doing free-form image editing, the network receives a blank mask.</li>
# <li>The generation result is sensitive to the guidance scale. For text guidance, based on experience, free-form image editing will perform better with a relatively low guidance score (7.5 or lower), while region-based image editing will perform better with a higher guidance score.</li>
# </ul>
# </div>
# '''
# Wire generate() into a single-function Gradio UI. cache_examples=True has
# Gradio cache the outputs for the example rows; live=False means predictions
# run only when the user submits.
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    description=article_html,
    examples=mask_ex_list,
    cache_examples=True,
    live=False,
)
# Enable Gradio's request queue on the app, then start serving it.
demo.queue()
demo.launch()