import spaces
import torch
from diffusers import StableDiffusion3InstructPix2PixPipeline
import gradio as gr
from PIL import Image, ImageOps

# Load the RefEdit-SD3 instruction-based editing pipeline in half precision and move it to the GPU.
pipe = StableDiffusion3InstructPix2PixPipeline.from_pretrained(
    "bpathir1/RefEdit-SD3", torch_dtype=torch.float16
).to("cuda")
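# A minimal fallback sketch (an assumption, not part of the original Space): if no CUDA
# device is available, the pipeline could instead be loaded on CPU in full precision:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   dtype = torch.float16 if device == "cuda" else torch.float32
#   pipe = StableDiffusion3InstructPix2PixPipeline.from_pretrained(
#       "bpathir1/RefEdit-SD3", torch_dtype=dtype
#   ).to(device)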

@spaces.GPU(duration=120)
def generate(image, prompt, num_inference_steps=50, image_guidance_scale=1.5, guidance_scale=7.5, seed=255):
    # Seed the RNG so edits are reproducible for a given seed value.
    seed = int(seed)
    generator = torch.manual_seed(seed)

    # Convert to RGB and center-crop/resize to the 512x512 resolution used for editing.
    img = image.convert("RGB")
    desired_size = (512, 512)
    img = ImageOps.fit(img, desired_size, method=Image.LANCZOS, centering=(0.5, 0.5))

    # Run the edit; mask_img=None corresponds to free-form (whole-image) editing.
    edited = pipe(
        prompt,
        image=img,
        mask_img=None,
        num_inference_steps=num_inference_steps,
        image_guidance_scale=image_guidance_scale,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]

    return edited


# Example inputs: [image path, prompt, inference steps, image guidance scale, guidance scale, seed]
example_lists = [
    ['UltraEdit/images/example_images/4ppl2.jpg', "Add a flower on the t-shirt of the guy in the middle with dark jeans", 50, 1.5, 7.5, 3345],
    ['UltraEdit/images/example_images/cat2.jpg', "Add a green scarf to the right cat", 50, 1.5, 7.5, 3345],
    ['UltraEdit/images/example_images/3ppl2.jpg', "Add a flower bunch to the person with a red jacket", 50, 1.5, 7.5, 3345],
    ['UltraEdit/images/example_images/4ppl1.jpg', "Let the rightmost person wear a golden dress", 50, 1.5, 7.5, 123456],
    ['UltraEdit/images/example_images/bowls1.jpg', "Remove the bowl with some leaves in the middle", 50, 1.5, 7.5, 3345],
    ['UltraEdit/images/example_images/cat1.jpg', "Can we have a dog instead of the cat looking at the camera?", 50, 1.5, 7.5, 3345],
]
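# Illustrative usage of generate() outside the UI (a sketch, assuming a local copy of
# one of the example images; not part of the original Space):
#   img = Image.open("UltraEdit/images/example_images/cat2.jpg")
#   result = generate(img, "Add a green scarf to the right cat",
#                     num_inference_steps=50, image_guidance_scale=1.5,
#                     guidance_scale=7.5, seed=3345)
#   result.save("edited.jpg")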

# The examples contain no mask entries, so they can be passed to the demo unchanged.
mask_ex_list = [list(exp) for exp in example_lists]


# Gradio input components (no mask input; editing is free-form).
image_input = gr.Image(type="pil", label="Input Image")
prompt_input = gr.Textbox(label="Prompt")
num_inference_steps_input = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Number of Inference Steps")
image_guidance_scale_input = gr.Slider(minimum=0.0, maximum=2.5, value=1.5, label="Image Guidance Scale")
guidance_scale_input = gr.Slider(minimum=0.0, maximum=17.5, value=12.5, label="Guidance Scale")
seed_input = gr.Textbox(value="255", label="Random Seed")

inputs = [image_input, prompt_input, num_inference_steps_input, image_guidance_scale_input, guidance_scale_input, seed_input]
outputs = [gr.Image(label="Generated Image")]

article_html = """
<div style="text-align: center; max-width: 1000px; margin: 20px auto; font-family: Arial, sans-serif;">
  <h2 style="font-weight: 900; font-size: 2.5rem; margin-bottom: 0.5rem;">
    RefEdit-SD3: An Instruction-Based Image Editing Model for Referring Expressions
  </h2>
  <div style="margin-bottom: 1rem;">
    <p style="font-weight: 400; font-size: 1rem; margin: 0.5rem 0;">
      Bimsara Pathiraja<sup>*</sup>, Maitreya Patel<sup>*</sup>, Shivam Singh, Yezhou Yang, Chitta Baral
    </p>
    <p style="font-weight: 400; font-size: 1rem; margin: 0;">
      Arizona State University
    </p>
  </div>
  <div style="margin: 1rem 0; display: flex; justify-content: center; gap: 1.5rem; flex-wrap: wrap;">
    <a href="https://huggingface.co/datasets/bpathir1/RefEdit" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
      <img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Dataset_4M" style="height: 20px; vertical-align: middle;"> Dataset
    </a>
    <a href="https://refedit.vercel.app/" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
      <span style="font-size: 20px; vertical-align: middle;">🔗</span> Page
    </a>
    <a href="https://huggingface.co/bpathir1/RefEdit-SD3" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
      <img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Checkpoint" style="height: 20px; vertical-align: middle;"> Checkpoint
    </a>
    <a href="https://github.com/bimsarapathiraja/refedit_private" style="display: flex; align-items: center; text-decoration: none; color: blue; font-weight: bold; gap: 0.5rem;">
      <img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" alt="GitHub" style="height: 20px; vertical-align: middle;"> GitHub
    </a>
  </div>
  <div style="text-align: left; margin: 0 auto; font-size: 1rem; line-height: 1.5;">
    <p>
      <b>RefEdit</b> is a benchmark and method for improving instruction-based image editing models for referring expressions.
    </p>
  </div>
</div>
"""
# html = '''
#   <div style="text-align: left; margin-top: 2rem; font-size: 0.85rem; color: gray;">
#     <b>Limitations:</b>
#     <ul>
#       <li>We have not conducted any NSFW checks;</li>
#       <li>Due to the bias of the generated models, the model performance is still weak when dealing with high-frequency information such as <b>human facial expressions or text in the images</b>;</li>
#       <li>We unified the free-form and region-based image editing by adding an extra channel of the mask image to the dataset. When doing free-form image editing, the network receives a blank mask.</li>
#       <li>The generation result is sensitive to the guidance scale. For text guidance, based on experience, free-form image editing will perform better with a relatively low guidance score (7.5 or lower), while region-based image editing will perform better with a higher guidance score.</li>
#     </ul>
#   </div>
# '''

demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    description=article_html,
    # article=html,
    examples=mask_ex_list,
    cache_examples=True,
    live=False,
)

demo.queue().launch()
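# For local testing, the demo could also be launched with a temporary public URL
# (a hedged suggestion using Gradio's share option; not used by the original Space):
#   demo.queue().launch(share=True)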