# RefEdit-SD3: Gradio demo (app.py)
import spaces
import torch
from diffusers import StableDiffusion3InstructPix2PixPipeline
import gradio as gr
from PIL import Image, ImageOps
import os
# Load the RefEdit-SD3 pipeline (an InstructPix2Pix-style SD3 editor) in fp16 on the GPU.
pipe = StableDiffusion3InstructPix2PixPipeline.from_pretrained(
    "bpathir1/RefEdit-SD3", torch_dtype=torch.float16
).to("cuda")
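
# A minimal sketch of device-aware loading for running this script locally (an
# assumption, not part of this Space, which always runs on CUDA hardware):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   dtype = torch.float16 if device == "cuda" else torch.float32
#   pipe = StableDiffusion3InstructPix2PixPipeline.from_pretrained(
#       "bpathir1/RefEdit-SD3", torch_dtype=dtype
#   ).to(device)
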
@spaces.GPU(duration=120)
def generate(image, prompt, num_inference_steps=50, image_guidance_scale=1.5, guidance_scale=7.5, seed=255):
    """Edit `image` according to `prompt` and return the edited PIL image."""
    # The seed arrives as a string from the Textbox component.
    seed = int(seed)
    generator = torch.manual_seed(seed)
    # Normalize the input: RGB, center-cropped and resized to 512x512.
    img = image.convert("RGB")
    desired_size = (512, 512)
    img = ImageOps.fit(img, desired_size, method=Image.LANCZOS, centering=(0.5, 0.5))
    image = pipe(
        prompt,
        image=img,
        mask_img=None,  # free-form editing: no region mask
        num_inference_steps=num_inference_steps,
        image_guidance_scale=image_guidance_scale,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]
    return image
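
# A minimal sketch of calling `generate` directly, outside Gradio (assumes a
# local file "input.jpg"; illustrative only):
#   edited = generate(Image.open("input.jpg"),
#                     "Add a red scarf to the person on the left",
#                     num_inference_steps=50, image_guidance_scale=1.5,
#                     guidance_scale=7.5, seed=255)
#   edited.save("edited.jpg")
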
# Examples (mask-related entries removed): [image path, prompt, steps, image guidance scale, guidance scale, seed]
example_lists = [
['UltraEdit/images/example_images/4ppl2.jpg', "Add a flower on the t-shirt of the guy in the middle with dark jeans", 50, 1.5, 7.5, 3345],
['UltraEdit/images/example_images/cat2.jpg', "Add a green scarf to the right cat", 50, 1.5, 7.5, 3345],
['UltraEdit/images/example_images/3ppl2.jpg', "Add a flower bunch to the person with a red jacket", 50, 1.5, 7.5, 3345],
['UltraEdit/images/example_images/4ppl1.jpg', "Let the rightmost person wear a golden dress", 50, 1.5, 7.5, 123456],
['UltraEdit/images/example_images/bowls1.jpg', "Remove the bowl with some leaves in the middle", 50, 1.5, 7.5, 3345],
['UltraEdit/images/example_images/cat1.jpg', "Can we have a dog instead of the cat looking at the camera?", 50, 1.5, 7.5, 3345],
]
# Copy the examples into the structure Gradio's `examples` argument expects.
mask_ex_list = [list(exp) for exp in example_lists]
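
# The example image paths assume the `UltraEdit/images` directory ships with
# this repo. A defensive filter (an assumption, not in the original) would drop
# examples whose files are missing:
#   mask_ex_list = [ex for ex in mask_ex_list if os.path.exists(ex[0])]
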
# Gradio input components (no mask upload; editing is free-form).
image_input = gr.Image(type="pil", label="Input Image")
prompt_input = gr.Textbox(label="Prompt")
num_inference_steps_input = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Number of Inference Steps")
image_guidance_scale_input = gr.Slider(minimum=0.0, maximum=2.5, value=1.5, label="Image Guidance Scale")
guidance_scale_input = gr.Slider(minimum=0.0, maximum=17.5, value=7.5, label="Guidance Scale")  # default matches the function signature and the examples
seed_input = gr.Textbox(value="255", label="Random Seed")
inputs = [image_input, prompt_input, num_inference_steps_input, image_guidance_scale_input, guidance_scale_input, seed_input]
outputs = [gr.Image(label="Generated Image")]
article_html = """
<div style="text-align: center; max-width: 1000px; margin: 20px auto; font-family: Arial, sans-serif;">
<h2 style="font-weight: 900; font-size: 2.5rem; margin-bottom: 0.5rem;">
RefEdit-SD3: An Instruction-Based Image Editing Model for Referring Expressions
</h2>
<div style="margin-bottom: 1rem;">
<p style="font-weight: 400; font-size: 1rem; margin: 0.5rem 0;">
Bimsara Pathiraja<sup>*</sup>, Maitreya Patel<sup>*</sup>, Shivam Singh, Yezhou Yang, Chitta Baral
</p>
<p style="font-weight: 400; font-size: 1rem; margin: 0;">
Arizona State University
</p>
</div>
<div style="margin: 1rem 0; display: flex; justify-content: center; gap: 1.5rem; flex-wrap: wrap;">
<a href="https://huggingface.co/datasets/bpathir1/RefEdit" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Dataset_4M" style="height: 20px; vertical-align: middle;"> Dataset
</a>
<a href="https://refedit.vercel.app/" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<span style="font-size: 20px; vertical-align: middle;">🔗</span> Page
</a>
<a href="https://huggingface.co/bpathir1/RefEdit-SD3" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Checkpoint" style="height: 20px; vertical-align: middle;"> Checkpoint
</a>
<a href="https://github.com/bimsarapathiraja/refedit_private" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" alt="GitHub" style="height: 20px; vertical-align: middle;"> GitHub
</a>
</div>
<div style="text-align: left; margin: 0 auto; font-size: 1rem; line-height: 1.5;">
<p>
<b>RefEdit</b> is a benchmark and method for improving instruction-based image editing models on referring expressions.
</p>
</div>
</div>
"""
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    description=article_html,
    examples=mask_ex_list,
    cache_examples=True,
    live=False,
)
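
# `queue()` enables request queuing so concurrent users share the GPU worker.
# For local runs, Gradio's standard `share=True` option (not used by this
# Space) would expose a temporary public URL:
#   demo.queue().launch(share=True)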
demo.queue().launch()