Spaces:
Sleeping
Sleeping
import base64 | |
import os | |
import pdb | |
import random | |
import sys | |
import time | |
from io import BytesIO | |
import gradio as gr | |
import numpy as np | |
import spaces | |
import torch | |
import torchvision.transforms.functional as TF | |
from PIL import Image | |
from torchvision import transforms | |
from src.img2skt import image_to_sketch_gif | |
from src.model import make_1step_sched | |
from src.pix2pix_turbo import Pix2Pix_Turbo | |
# Sketch-to-image model, constructed once at import time and shared by all requests.
model = Pix2Pix_Turbo("sketch_to_image_stochastic")

# Prompt style templates: each "prompt" contains a "{prompt}" placeholder that
# is substituted with the user's text before generation.
style_list = [
    dict(name="No Style", prompt="{prompt}"),
    dict(
        name="Cinematic",
        prompt="cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    ),
    dict(
        name="3D Model",
        prompt="professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
    ),
    dict(
        name="Anime",
        prompt="anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
    ),
    dict(
        name="Digital Art",
        prompt="concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
    ),
    dict(
        name="Photographic",
        prompt="cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    ),
    dict(
        name="Pixel art",
        prompt="pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
    ),
    dict(
        name="Fantasy art",
        prompt="ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    ),
    dict(
        name="Neonpunk",
        prompt="neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    ),
    dict(
        name="Manga",
        prompt="manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
    ),
]

# Style name -> template lookup (insertion order follows style_list).
styles = {entry["name"]: entry["prompt"] for entry in style_list}
STYLE_NAMES = [*styles]
DEFAULT_STYLE_NAME = "Manga"

# Upper bound for randomly drawn seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max

# Display size of the canvas and the output images.
HEIGHT, WIDTH = 512, 512
# Size the sketch is resized to before it is passed to the model.
PROC_WIDTH, PROC_HEIGHT = 512, 512
ITER_DELAY = 1.0  # NOTE(review): not referenced anywhere visible here — confirm before removing
def create_white_background(width, height):
    """Return a new solid-white RGB PIL image of size ``width`` x ``height``."""
    canvas_size = (width, height)
    return Image.new(mode="RGB", size=canvas_size, color="white")


# Blank white canvas shared as the sketchpad background.
white_background = create_white_background(WIDTH, HEIGHT)
def _button_and_slider_updates(interactive):
    """Build a (Button, Slider) pair of updates that set both to ``interactive``."""
    button_update = gr.Button(interactive=interactive)
    slider_update = gr.Slider(interactive=interactive)
    return button_update, slider_update


def make_button_and_slider_unclickable():
    """Return updates that disable a button and a slider (in that order)."""
    return _button_and_slider_updates(False)


def make_button_and_slider_clickable():
    """Return updates that re-enable a button and a slider (in that order)."""
    return _button_and_slider_updates(True)
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Generate an image from the current sketchpad contents.

    Args:
        image: Sketchpad value dict; only its ``"composite"`` PIL image is used.
        prompt: The user's text prompt.
        prompt_template: Style template containing a ``{prompt}`` placeholder.
        style_name: Selected style name. Not used here (the template already
            encodes the style); kept so the UI ``inputs`` list still matches.
        seed: Random seed. It may arrive as a string because it comes from a
            Textbox, so it is converted with ``int()``.
        val_r: Sketch-guidance strength, forwarded to the model as ``r``.

    Returns:
        The generated PIL image, resized to ``(WIDTH, HEIGHT)`` if needed.

    Raises:
        gr.Error: If the sketchpad provides no composite image.
    """
    if image is None or image["composite"] is None:
        # Fail with a readable UI message instead of an AttributeError/TypeError.
        raise gr.Error("The sketchpad is empty; please draw a sketch first.")
    sketch = image["composite"]
    if sketch.size != (PROC_WIDTH, PROC_HEIGHT):
        sketch = sketch.resize((PROC_WIDTH, PROC_HEIGHT))
    prompt = prompt_template.replace("{prompt}", prompt)

    # Invert so dark strokes on white become bright, then binarize into a
    # {0, 1} conditioning map.
    sketch = sketch.convert("RGB")
    inverted = Image.fromarray(255 - np.array(sketch))
    sketch_mask = TF.to_tensor(inverted) > 0.5

    with torch.no_grad():
        c_t = sketch_mask.unsqueeze(0).cuda().float()
        torch.manual_seed(int(seed))
        _, _, height, width = c_t.shape
        # 4-channel noise at 1/8 spatial resolution, passed as the model's noise_map.
        noise = torch.randn((1, 4, height // 8, width // 8), device=c_t.device)
        output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)

    # Rescale the output (presumably in [-1, 1]) to [0, 1]. Clamp because
    # to_pil_image casts float tensors to uint8 and out-of-range values would
    # otherwise wrap around.
    output = (output_image[0].cpu() * 0.5 + 0.5).clamp(0, 1)
    output_pil = TF.to_pil_image(output)
    if output_pil.size != (WIDTH, HEIGHT):
        output_pil = output_pil.resize((WIDTH, HEIGHT))
    return output_pil
def clear_image_editor():
    """Reset the sketchpad and every generation output to an empty state.

    Returns, in order:
        1. a sketchpad value with a white background and no layers,
        2. an update clearing the final image,
        3. an update clearing the sketch-output image,
        4. an empty frames list,
        5. a slider update (max 1, value 0, non-interactive),
        6. a button update making it non-interactive.
    """
    return (
        {"background": white_background, "layers": None, "composite": None},
        gr.Image(value=None),
        gr.Image(value=None),
        # A State output takes the raw new value. Returning a gr.State(...)
        # component does not reliably reset the stored frames.
        [],
        gr.Slider(maximum=1, value=0, interactive=False),
        gr.Button(interactive=False),
    )
def apply_func_click(frames, frame_selector):
    """Load the selected sketch frame into the sketchpad as its only layer.

    Args:
        frames: List of frames held in state.
        frame_selector: Slider value used as an index into ``frames``.

    Returns:
        A sketchpad value dict (white background, the selected frame as the
        single layer), or ``gr.update()`` to leave the sketchpad unchanged when
        the frame cannot be selected.
    """
    try:
        selected_frame = frames[int(frame_selector)]
    except (IndexError, TypeError, ValueError):
        # e.g. no frames yet (empty/None list) or an unset slider value.
        # Returning None would wipe the sketchpad, so keep it as-is instead.
        return gr.update()
    return {
        "background": white_background,
        "layers": [selected_frame],
        "composite": None,
    }
def frame_selector_change(frame_idx, frames):
    """Return the frame at ``frame_idx`` for display.

    Args:
        frame_idx: Slider value; converted with ``int()`` before indexing.
        frames: List of frames held in state.

    Returns:
        ``frames[int(frame_idx)]``. If the index is unusable or there are no
        frames, returns ``gr.update()`` (no change) instead of None, which
        would clear the displayed image.
    """
    try:
        return frames[int(frame_idx)]
    except (IndexError, TypeError, ValueError):
        return gr.update()
with gr.Blocks() as demo:
    gr.Markdown("# Sketch to Image Demo Augmented with Sketch Generation")
    with gr.Row():
        # Left column: sketch input, prompt, and usage instructions.
        with gr.Column(scale=1):
            image = gr.Sketchpad(
                value={
                    "background": white_background,
                    "layers": None,
                    "composite": white_background,
                },
                image_mode="L",
                type="pil",
                sources=None,
                label="Sketchpad",
                show_label=True,
                show_download_button=True,
                interactive=True,
                layers=False,
                canvas_size=(WIDTH, HEIGHT),
                show_fullscreen_button=False,
                brush=gr.Brush(
                    colors=["#000000"],
                    color_mode="fixed",
                    default_size=4,
                ),
            )
            prompt = gr.Textbox(label="Prompt", value="", show_label=True)
            with gr.Row():
                run_button = gr.Button("Run", scale=1)
                randomize_seed = gr.Button("Random", scale=1, visible=False)
            gr.Markdown(
                """
### Instructions
1. Enter a text prompt (e.g. cat).
2. Draw some sketches on the Sketchpad.
3. Click on the **Run** button to generate image and sketches in the Final Image and Sketch Outputs, respectively.
4. You may then select a frame by the Frame Selector and click on **Apply** to apply the selected frame to the Sketchpad.
5. You may then modify the sketches based on the applied frame and click on **Run** again to generate new images and sketches.
6. Repeat steps 4 and 5 to generate new images and sketches until you are satisfied with the result.
7. To restart from scratch, click on the **Bin Icon** on the top right corner of the Sketchpad.
**Thanks to the [paper](https://arxiv.org/abs/2403.12036) and their open-sourced models!**
                """
            )
        # Middle column: generated sketch frames and the controls to apply one.
        with gr.Column(scale=1):
            frame_result = gr.Image(
                height=HEIGHT,
                width=WIDTH,
                label="Sketch Outputs",
                type="pil",
                show_label=True,
                show_download_button=True,
                interactive=False,
                visible=True,
            )
            apply_button = gr.Button("Apply", scale=1, visible=True, interactive=False)
            frame_selector = gr.Slider(
                minimum=0,
                maximum=1,
                value=0,
                step=1,
                visible=True,
                interactive=False,
                scale=4,
                label="Frame Selector",
            )
        # Right column: the generated image.
        with gr.Column(scale=1):
            result = gr.Image(
                height=HEIGHT,
                width=WIDTH,
                label="Final Image",
                type="pil",
                show_label=True,
                show_download_button=True,
                interactive=False,
                visible=True,
            )
    # Hidden controls (visible=False): generation parameters stay at their defaults.
    style = gr.Dropdown(
        label="Style",
        choices=STYLE_NAMES,
        value=DEFAULT_STYLE_NAME,
        scale=1,
        visible=False,
    )
    prompt_temp = gr.Textbox(
        label="Prompt Style Template",
        value=styles[DEFAULT_STYLE_NAME],
        max_lines=1,
        scale=2,
        visible=False,
    )
    val_r = gr.Slider(
        label="Sketch guidance: ",
        show_label=True,
        minimum=0,
        maximum=1,
        value=0.5,
        step=0.01,
        scale=4,
        visible=False,
    )
    seed = gr.Textbox(label="Seed", value=42, scale=4, visible=False)
    frames = gr.State([])
    sketches = gr.Image(
        height=HEIGHT,
        width=WIDTH,
        show_label=False,
        show_download_button=True,
        type="pil",
        visible=False,
    )
    one_frame = gr.Image(
        height=HEIGHT,
        width=WIDTH,
        show_label=False,
        show_download_button=True,
        type="pil",
        interactive=False,
        visible=False,
    )
    inputs = [image, prompt, prompt_temp, style, seed, val_r]
    outputs = [result]

    def _chain_generation(*leading_events):
        """Attach the shared generate -> sketch-GIF -> unlock steps to a chain.

        ``leading_events`` are the already-registered steps of one trigger's
        chain, in order. The new steps are attached after the last of them.
        Returns every event of the chain. A ``.then()`` dependency identifies
        only its own step, so ``cancels=`` must list each step; cancelling just
        the final one would leave ``run`` and ``image_to_sketch_gif`` running.
        """
        generate = leading_events[-1].then(fn=run, inputs=inputs, outputs=outputs)
        sketch = generate.then(
            image_to_sketch_gif,
            inputs=[result],
            outputs=[frame_result, frames, frame_selector, apply_button],
        )
        unlock = sketch.then(
            fn=make_button_and_slider_clickable,
            inputs=None,
            outputs=[apply_button, frame_selector],
        )
        return [*leading_events, generate, sketch, unlock]

    # Random (hidden): draw a new seed, lock the controls, then regenerate.
    seed_event = randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
    )
    randomize_seed_events = _chain_generation(
        seed_event,
        seed_event.then(
            fn=make_button_and_slider_unclickable,
            inputs=None,
            outputs=[apply_button, frame_selector],
        ),
    )

    # Style change (hidden): lock, refresh the prompt template, then regenerate.
    style_lock = style.change(
        fn=make_button_and_slider_unclickable,
        inputs=None,
        outputs=[apply_button, frame_selector],
    )
    style_change_events = _chain_generation(
        style_lock,
        style_lock.then(lambda x: styles[x], inputs=[style], outputs=[prompt_temp]),
    )

    # Guidance change (hidden): lock, then regenerate.
    val_r_change_events = _chain_generation(
        val_r.change(
            fn=make_button_and_slider_unclickable,
            inputs=None,
            outputs=[apply_button, frame_selector],
        )
    )

    # Run button: lock, then regenerate.
    run_button_click_events = _chain_generation(
        run_button.click(
            fn=make_button_and_slider_unclickable,
            inputs=None,
            outputs=[apply_button, frame_selector],
        )
    )

    # prompt.submit and image.apply are not wired to generation; they can be
    # added with _chain_generation in the same way if needed.

    # Final step of each chain, kept under these names for any code that uses them.
    randomize_seed_click = randomize_seed_events[-1]
    style_change = style_change_events[-1]
    val_r_change = val_r_change_events[-1]
    run_button_click = run_button_click_events[-1]

    # Every step of every generation chain, so cancellation stops in-flight work.
    generation_events = [
        *run_button_click_events,
        *randomize_seed_events,
        *style_change_events,
        *val_r_change_events,
    ]

    apply_button.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=generation_events,
    )
    apply_button.click(
        fn=apply_func_click,
        inputs=[frames, frame_selector],
        outputs=[image],
    )
    frame_selector.change(
        fn=frame_selector_change,
        inputs=[frame_selector, frames],
        outputs=[frame_result],
    )
    image.clear(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=generation_events,
    )
    image.clear(
        fn=clear_image_editor,
        inputs=None,
        outputs=[
            image,
            result,
            frame_result,
            frames,
            frame_selector,
            apply_button,
        ],
    )
if __name__ == "__main__":
    # Enable request queuing on the app, then start the server.
    queued_app = demo.queue()
    queued_app.launch()