# Spaces: Running on Zero
import spaces | |
import torch | |
from diffusers import StableDiffusion3InstructPix2PixPipeline | |
import gradio as gr | |
import PIL.Image | |
import numpy as np | |
from PIL import Image, ImageOps | |
import os | |
# import transformers | |
# from transformers.utils.hub import move_cache | |
# transformers.utils.move_cache() | |
# move_cache() | |
# Load the RefEdit-SD3 checkpoint in half precision, then move it to the GPU.
pipe = StableDiffusion3InstructPix2PixPipeline.from_pretrained(
    "bpathir1/RefEdit-SD3",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
# NOTE(review): this file imports `spaces` but never applies `@spaces.GPU`;
# confirm whether the ZeroGPU runtime needs this function to be decorated.
def generate(image, prompt, num_inference_steps=50, image_guidance_scale=1.5, guidance_scale=7.5, seed=255):
    """Edit ``image`` according to the text instruction ``prompt``.

    The input is converted to RGB and center-cropped/resized to 512x512
    before it is handed to the module-level ``pipe``.

    Args:
        image: PIL image to edit.
        prompt: Editing instruction passed to the pipeline.
        num_inference_steps: Number of denoising steps. Coerced to ``int``
            because the value may arrive as a float from a UI slider.
        image_guidance_scale: Forwarded unchanged to the pipeline.
        guidance_scale: Forwarded unchanged to the pipeline.
        seed: Any value ``int()`` accepts; used to seed the generator.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If no image was provided, or the seed or step count is
            not usable. The message is intended for display in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an input image before generating.")

    try:
        seed = int(seed)
    except (TypeError, ValueError) as err:
        raise gr.Error(f"Random seed must be an integer, got {seed!r}.") from err

    try:
        steps = int(num_inference_steps)
    except (TypeError, ValueError) as err:
        raise gr.Error(
            f"Number of inference steps must be an integer, got {num_inference_steps!r}."
        ) from err
    if steps < 1:
        raise gr.Error("Number of inference steps must be at least 1.")

    # torch.manual_seed seeds the default generator and returns it.
    generator = torch.manual_seed(seed)

    img = image.convert("RGB")
    desired_size = (512, 512)
    img = ImageOps.fit(img, desired_size, method=Image.LANCZOS, centering=(0.5, 0.5))

    # No mask is supplied: the edit is driven by the prompt alone.
    result = pipe(
        prompt,
        image=img,
        mask_img=None,
        num_inference_steps=steps,
        image_guidance_scale=image_guidance_scale,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    return result.images[0]
# Bundled demo examples. Each row follows the order of the UI inputs:
# [image path, prompt, inference steps, image guidance scale, guidance scale, seed].
example_lists = [
    ['UltraEdit/images/example_images/4ppl2.jpg', "Add a flower on the t-shirt of the guy in the middle with dark jeans", 50, 1.5, 7.5, 3345],
    ['UltraEdit/images/example_images/cat2.jpg', "Add a green scarf to the right cat", 50, 1.5, 7.5, 3345],
    ['UltraEdit/images/example_images/3ppl2.jpg', "Add a flower bunch to the person with a red jacket", 50, 1.5, 7.5, 3345],
    ['UltraEdit/images/example_images/4ppl1.jpg', "Let the rightmost person wear a golden dress", 50, 1.5, 7.5, 123456],
    ['UltraEdit/images/example_images/bowls1.jpg', "Remove the bowl with some leaves in the middle", 50, 1.5, 7.5, 3345],
    ['UltraEdit/images/example_images/cat1.jpg', "Can we have a dog instead of the cat looking at the camera?", 50, 1.5, 7.5, 3345],
]

# Per-row copies handed to the interface, so the rows above are never
# aliased by whatever consumes `mask_ex_list`.
mask_ex_list = [list(example) for example in example_lists]
# Input widgets, listed in the same order as the parameters of `generate`.
image_input = gr.Image(type="pil", label="Input Image")
prompt_input = gr.Textbox(label="Prompt")
# At least one whole denoising step is needed, so the slider starts at 1
# and moves in integer increments.
num_inference_steps_input = gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Number of Inference Steps")
image_guidance_scale_input = gr.Slider(minimum=0.0, maximum=2.5, value=1.5, label="Image Guidance Scale")
# NOTE(review): the default here is 12.5 while `generate` and the bundled
# examples use 7.5 — confirm which value is intended.
guidance_scale_input = gr.Slider(minimum=0.0, maximum=17.5, value=12.5, label="Guidance Scale")
# The seed is collected as text; `generate` converts it with int().
seed_input = gr.Textbox(value="255", label="Random Seed")

inputs = [
    image_input,
    prompt_input,
    num_inference_steps_input,
    image_guidance_scale_input,
    guidance_scale_input,
    seed_input,
]
outputs = [gr.Image(label="Generated Image")]
# HTML header shown as the interface description: title, authors,
# affiliation, resource links, and a one-sentence summary.
article_html = """
<div style="text-align: center; max-width: 1000px; margin: 20px auto; font-family: Arial, sans-serif;">
<h2 style="font-weight: 900; font-size: 2.5rem; margin-bottom: 0.5rem;">
RefEdit-SD3 for Instruction-based Image Editing Model on Referring Expressions
</h2>
<div style="margin-bottom: 1rem;">
<h3 style="font-weight: 500; font-size: 1.25rem; margin: 0;"></h3>
<p style="font-weight: 400; font-size: 1rem; margin: 0.5rem 0;">
Bimsara Pathiraja<sup>*</sup>, Maitreya Patel<sup>*</sup>, Shivam Singh, Yezhou Yang, Chitta Baral
</p>
<p style="font-weight: 400; font-size: 1rem; margin: 0;">
Arizona State University
</p>
</div>
<div style="margin: 1rem 0; display: flex; justify-content: center; gap: 1.5rem; flex-wrap: wrap;">
<a href="https://huggingface.co/datasets/bpathir1/RefEdit" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Dataset" style="height: 20px; vertical-align: middle;"> Dataset
</a>
<a href="https://refedit.vercel.app/" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<span style="font-size: 20px; vertical-align: middle;">🔗</span> Page
</a>
<a href="https://huggingface.co/bpathir1/RefEdit-SD3" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Checkpoint" style="height: 20px; vertical-align: middle;"> Checkpoint
</a>
<a href="https://github.com/bimsarapathiraja/refedit_private" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
<img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" alt="GitHub" style="height: 20px; vertical-align: middle;"> GitHub
</a>
</div>
<div style="text-align: left; margin: 0 auto; font-size: 1rem; line-height: 1.5;">
<p>
<b>RefEdit</b> is a benchmark and method for improving instruction-based image editing models for referring expressions.
</p>
</div>
</div>
"""
# html = ''' | |
# <div style="text-align: left; margin-top: 2rem; font-size: 0.85rem; color: gray;"> | |
# <b>Limitations:</b> | |
# <ul> | |
# <li>We have not conducted any NSFW checks;</li> | |
# <li>Due to the bias of the generated models, the model performance is still weak when dealing with high-frequency information such as <b>human facial expressions or text in the images</b>;</li> | |
# <li>We unified the free-form and region-based image editing by adding an extra channel of the mask image to the dataset. When doing free-form image editing, the network receives a blank mask.</li> | |
# <li>The generation result is sensitive to the guidance scale. For text guidance, based on experience, free-form image editing will perform better with a relatively low guidance score (7.5 or lower), while region-based image editing will perform better with a higher guidance score.</li> | |
# </ul> | |
# </div> | |
# ''' | |
# Assemble the Gradio interface around `generate`. Example caching is
# enabled and live mode is off; the optional `article` section is
# currently not passed.
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    examples=mask_ex_list,
    cache_examples=True,
    description=article_html,
    live=False,
)

# Enable request queuing, then start the server.
demo.queue()
demo.launch()