# S2IA / app.py
# zzhao-swansea's picture
# Update app.py
# cba80be verified
import base64
import os
import pdb
import random
import sys
import time
from io import BytesIO
import gradio as gr
import numpy as np
import spaces
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms
from src.img2skt import image_to_sketch_gif
from src.model import make_1step_sched
from src.pix2pix_turbo import Pix2Pix_Turbo
# Build the Pix2Pix_Turbo model once at import time; `run` reuses this
# module-level instance for every request.
model = Pix2Pix_Turbo("sketch_to_image_stochastic")
# Style presets. Each entry pairs a display name with a prompt template in
# which the literal "{prompt}" marks where the user's text is substituted.
style_list = [
    dict(name=style_name, prompt=template)
    for style_name, template in (
        ("No Style", "{prompt}"),
        (
            "Cinematic",
            "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        ),
        (
            "3D Model",
            "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        ),
        (
            "Anime",
            "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        ),
        (
            "Digital Art",
            "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        ),
        (
            "Photographic",
            "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        ),
        (
            "Pixel art",
            "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        ),
        (
            "Fantasy art",
            "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        ),
        (
            "Neonpunk",
            "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        ),
        (
            "Manga",
            "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        ),
    )
]

# Lookup from style name to its template, plus the names in declaration order.
styles = {entry["name"]: entry["prompt"] for entry in style_list}
STYLE_NAMES = list(styles)
DEFAULT_STYLE_NAME = "Manga"

# Largest value representable as a signed 32-bit integer; upper bound for random seeds.
MAX_SEED = np.iinfo(np.int32).max

HEIGHT = 512  # Display height
WIDTH = 512  # Display width
PROC_WIDTH = 512  # Processing width
PROC_HEIGHT = 512  # Processing height

# NOTE(review): not referenced anywhere else in the visible part of this file.
ITER_DELAY = 1.0
def create_white_background(width, height):
    """Return a new solid-white RGB PIL image that is ``width`` x ``height`` pixels."""
    canvas_size = (width, height)
    return Image.new(mode="RGB", size=canvas_size, color="white")
# Shared blank canvas at display size; used as the Sketchpad background for the
# initial value, when the editor is cleared, and when a frame is applied.
white_background = create_white_background(WIDTH, HEIGHT)
def make_button_and_slider_unclickable():
    """Return a (Button, Slider) pair, both constructed with ``interactive=False``.

    In this file the pair is sent to the Apply button and the frame-selector
    slider at the start of every generation chain.
    """
    locked_button = gr.Button(interactive=False)
    locked_slider = gr.Slider(interactive=False)
    return locked_button, locked_slider
def make_button_and_slider_clickable():
    """Return a (Button, Slider) pair, both constructed with ``interactive=True``.

    In this file the pair is sent to the Apply button and the frame-selector
    slider as the last step of every generation chain.
    """
    unlocked_button = gr.Button(interactive=True)
    unlocked_slider = gr.Slider(interactive=True)
    return unlocked_button, unlocked_slider
@spaces.GPU(duration=45)
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Generate an image from the current sketch and text prompt.

    Args:
        image: Sketchpad value; its ``"composite"`` entry is the flattened
            PIL sketch that gets fed to the model.
        prompt: User text substituted for ``"{prompt}"`` in the template.
        prompt_template: Style template containing the ``"{prompt}"`` marker.
        style_name: Selected style name. Unused here (the template already
            carries the style); kept so the event wiring's input list matches.
        seed: Seed for the noise map. It is supplied by a Textbox, so it may
            arrive as text and is converted with ``int()``.
        val_r: Forwarded to the model as ``r`` (labelled "Sketch guidance"
            in the UI).

    Returns:
        The generated PIL image, resized to ``(WIDTH, HEIGHT)`` if needed.

    Raises:
        gr.Error: If the Sketchpad supplied no composite image.
        ValueError: If ``seed`` cannot be converted to an integer.
    """
    sketch = None if image is None else image["composite"]
    if sketch is None:
        # Previously this surfaced as an opaque TypeError/AttributeError.
        raise gr.Error("No sketch found. Please draw on the Sketchpad before running.")

    if sketch.size != (PROC_WIDTH, PROC_HEIGHT):
        sketch = sketch.resize((PROC_WIDTH, PROC_HEIGHT))
    full_prompt = prompt_template.replace("{prompt}", prompt)

    # Invert the drawing (dark strokes on white become bright strokes on black)
    # and binarise it at 0.5.
    inverted = Image.fromarray(255 - np.array(sketch.convert("RGB")))
    stroke_mask = TF.to_tensor(inverted) > 0.5

    with torch.no_grad():
        c_t = stroke_mask.unsqueeze(0).cuda().float()
        torch.manual_seed(int(seed))
        height, width = c_t.shape[-2:]
        # 4-channel noise map at 1/8 of the input's spatial resolution.
        noise = torch.randn((1, 4, height // 8, width // 8), device=c_t.device)
        output_image = model(
            c_t, full_prompt, deterministic=False, r=val_r, noise_map=noise
        )

    # NOTE(review): the * 0.5 + 0.5 rescale presumably maps a [-1, 1] model
    # output into [0, 1] for PIL conversion -- confirm against the model.
    output_pil = TF.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    if output_pil.size != (WIDTH, HEIGHT):
        output_pil = output_pil.resize((WIDTH, HEIGHT))
    return output_pil
def clear_image_editor():
    """Return reset values for the editor and everything derived from it.

    The tuple is ordered to match the outputs this handler is wired to:
    Sketchpad, final image, sketch-frame preview, stored frames, frame
    selector slider, Apply button.

    Returns:
        A 6-tuple: a blank Sketchpad value on the white background, two
        emptied images, an empty frame list, a slider reset to 0 with
        maximum 1 and interaction disabled, and a disabled button.
    """
    blank_sketchpad = {
        "background": white_background,
        "layers": None,
        "composite": None,
    }
    return (
        blank_sketchpad,
        gr.Image(value=None),
        gr.Image(value=None),
        # A State output takes its new value directly. Returning a plain list
        # (instead of a gr.State component) guarantees `frames` stays a list
        # that the frame handlers can index.
        [],
        gr.Slider(maximum=1, value=0, interactive=False),
        gr.Button(interactive=False),
    )
def apply_func_click(frames, frame_selector):
    """Build a Sketchpad value that shows the selected frame as its only layer.

    Args:
        frames: Indexable collection of stored sketch frames.
        frame_selector: Position of the frame to apply; converted with ``int()``.

    Returns:
        A Sketchpad value dict (white background, the chosen frame as the
        single layer, no composite), or ``None`` when no frame can be
        selected (nothing stored yet, index out of range, or a selector
        that is not a number).
    """
    try:
        selected_frame = frames[int(frame_selector)]
    except (TypeError, ValueError, OverflowError, LookupError):
        # Best effort: an unusable selection yields None, as before, instead of
        # raising. Only conversion/lookup failures are absorbed so that
        # unrelated bugs are no longer hidden.
        return None
    return {
        "background": white_background,
        "layers": [selected_frame],
        "composite": None,
    }
def frame_selector_change(frame_idx, frames):
    """Return the stored frame at ``frame_idx`` for display.

    Args:
        frame_idx: Slider position; converted with ``int()``.
        frames: Indexable collection of stored sketch frames.

    Returns:
        The selected frame, or ``None`` when it cannot be selected (nothing
        stored yet, index out of range, or an index that is not a number).
    """
    try:
        return frames[int(frame_idx)]
    except (TypeError, ValueError, OverflowError, LookupError):
        # Best effort: an unusable selection yields None, as before, instead of
        # raising. Only conversion/lookup failures are absorbed so that
        # unrelated bugs are no longer hidden.
        return None
# Build the demo UI and wire its events.
# NOTE(review): this copy of the file has lost its indentation, so the exact
# Row/Column nesting of the components below cannot be confirmed from here.
# The statements are kept verbatim; only comments have been added.
with gr.Blocks() as demo:
gr.Markdown("# Sketch to Image Demo Augmented with Sketch Generation")
with gr.Row():
with gr.Column(scale=1):
# Drawing surface: starts from the shared white background and offers a single
# fixed black brush on a WIDTH x HEIGHT canvas. Its value is the first input
# to `run`, which reads the "composite" entry.
image = gr.Sketchpad(
value={
"background": white_background,
"layers": None,
"composite": white_background,
},
image_mode="L",
type="pil",
sources=None,
# container=True,
label="Sketchpad",
show_label=True,
show_download_button=True,
# show_share_button=True,
interactive=True,
layers=False,
# height="80vw",
canvas_size=(WIDTH, HEIGHT),
show_fullscreen_button=False,
brush=gr.Brush(
colors=["#000000"],
color_mode="fixed",
default_size=4,
),
)
# Free-text prompt; `run` substitutes it into the style template.
prompt = gr.Textbox(label="Prompt", value="", show_label=True)
with gr.Row():
run_button = gr.Button("Run", scale=1)
# Created with visible=False, so its click chain below is wired but the button
# is not shown.
randomize_seed = gr.Button("Random", scale=1, visible=False)
# User-facing instructions (runtime text; rendered as Markdown).
gr.Markdown(
"""
### Instructions
1. Enter a text prompt (e.g. cat).
2. Draw some sketches on the Sketchpad.
3. Click on the **Run** button to generate image and sketches in the Final Image and Sketch Outputs, respectively.
4. You may then select a frame by the Frame Selector and click on **Apply** to apply the selected frame to the Sketchpad.
5. You may then modify the sketches based on the applied frame and click on **Run** again to generate new images and sketches.
6. Repeat steps 4 and 5 to generate new images and sketches until you are satisfied with the result.
7. To restart from scratch, click on the **Bin Icon** on the top right corner of the Sketchpad.
**Thanks to the [paper](https://arxiv.org/abs/2403.12036) and their open-sourced models!**
"""
)
with gr.Column(scale=1):
# Preview of the currently selected sketch frame; updated by the generation
# chains and by `frame_selector_change`.
frame_result = gr.Image(
height=HEIGHT,
width=WIDTH,
label="Sketch Outputs",
type="pil",
show_label=True,
show_download_button=True,
interactive=False,
visible=True,
)
# Apply and the frame slider start disabled; the generation chains below
# disable them while running and re-enable them as their last step.
apply_button = gr.Button("Apply", scale=1, visible=True, interactive=False)
frame_selector = gr.Slider(
minimum=0,
maximum=1,
value=0,
step=1,
visible=True,
interactive=False,
scale=4,
label="Frame Selector",
)
with gr.Column(scale=1):
# Output of `run`; also the input handed to image_to_sketch_gif.
result = gr.Image(
height=HEIGHT,
width=WIDTH,
label="Final Image",
type="pil",
show_label=True,
show_download_button=True,
interactive=False,
visible=True,
)
# invisible elements
# All components from here to `one_frame` are created with visible=False.
# style, prompt_temp, val_r and seed still feed `run` through `inputs` below.
style = gr.Dropdown(
label="Style",
choices=STYLE_NAMES,
value=DEFAULT_STYLE_NAME,
scale=1,
visible=False,
)
prompt_temp = gr.Textbox(
label="Prompt Style Template",
value=styles[DEFAULT_STYLE_NAME],
max_lines=1,
scale=2,
visible=False,
)
val_r = gr.Slider(
label="Sketch guidance: ",
show_label=True,
minimum=0,
maximum=1,
value=0.5,
step=0.01,
scale=4,
visible=False,
)
# The seed lives in a Textbox, so `run` may receive it as text.
seed = gr.Textbox(label="Seed", value=42, scale=4, visible=False)
# Per-session storage for the sketch frames; written by image_to_sketch_gif and
# read by the Apply and frame-selector handlers.
frames = gr.State([])
# NOTE(review): `sketches` and `one_frame` are not referenced by any event in
# the visible part of this file.
sketches = gr.Image(
height=HEIGHT,
width=WIDTH,
show_label=False,
show_download_button=True,
type="pil",
visible=False,
)
one_frame = gr.Image(
height=HEIGHT,
width=WIDTH,
show_label=False,
show_download_button=True,
type="pil",
interactive=False,
visible=False,
)
# Arguments for `run`, listed in the order of its parameters, and its single output.
inputs = [image, prompt, prompt_temp, style, seed, val_r]
outputs = [result]
# "Random" button chain: write a fresh random seed, disable Apply/slider, call
# `run`, pass the result to image_to_sketch_gif (which updates the frame
# preview, stored frames, slider and Apply button), then re-enable
# Apply/slider. The final event is kept so other listeners can cancel it.
randomize_seed_click = (
randomize_seed.click(
lambda: random.randint(0, MAX_SEED),
inputs=[],
outputs=seed,
)
.then(
fn=make_button_and_slider_unclickable,
inputs=None,
outputs=[apply_button, frame_selector],
)
.then(
fn=run,
inputs=inputs,
outputs=outputs,
)
.then(
image_to_sketch_gif,
inputs=[result],
outputs=[frame_result, frames, frame_selector, apply_button],
)
.then(
fn=make_button_and_slider_clickable,
inputs=None,
outputs=[apply_button, frame_selector],
)
)
# Disabled: the same pipeline triggered by submitting the prompt textbox.
# prompt_submit = (
# prompt.submit(
# make_button_and_slider_unclickable,
# inputs=None,
# outputs=[apply_button, frame_selector],
# )
# .then(fn=run, inputs=inputs, outputs=outputs)
# .then(
# image_to_sketch_gif,
# inputs=[result],
# outputs=[frame_result, frames, frame_selector, apply_button],
# )
# .then(
# fn=make_button_and_slider_clickable,
# inputs=None,
# outputs=[apply_button, frame_selector],
# )
# )
# Style change chain: disable Apply/slider, copy the chosen style's template
# into prompt_temp, then run the same generate -> sketch frames -> re-enable
# steps.
style_change = (
style.change(
fn=make_button_and_slider_unclickable,
inputs=None,
outputs=[apply_button, frame_selector],
)
.then(lambda x: styles[x], inputs=[style], outputs=[prompt_temp])
.then(
fn=run,
inputs=inputs,
outputs=outputs,
)
.then(
image_to_sketch_gif,
inputs=[result],
outputs=[frame_result, frames, frame_selector, apply_button],
)
.then(
fn=make_button_and_slider_clickable,
inputs=None,
outputs=[apply_button, frame_selector],
)
)
# Sketch-guidance slider chain: disable, generate, sketch frames, re-enable.
val_r_change = (
val_r.change(
fn=make_button_and_slider_unclickable,
inputs=None,
outputs=[apply_button, frame_selector],
)
.then(run, inputs=inputs, outputs=outputs)
.then(
image_to_sketch_gif,
inputs=[result],
outputs=[frame_result, frames, frame_selector, apply_button],
)
.then(
fn=make_button_and_slider_clickable,
inputs=None,
outputs=[apply_button, frame_selector],
)
)
# "Run" button chain: disable, generate, sketch frames, re-enable.
run_button_click = (
run_button.click(
fn=make_button_and_slider_unclickable,
inputs=None,
outputs=[apply_button, frame_selector],
)
.then(fn=run, inputs=inputs, outputs=outputs)
.then(
image_to_sketch_gif,
inputs=[result],
outputs=[frame_result, frames, frame_selector, apply_button],
)
.then(
fn=make_button_and_slider_clickable,
inputs=None,
outputs=[apply_button, frame_selector],
)
)
# Disabled: the same pipeline triggered by the Sketchpad's apply event.
# image_apply = (
# image.apply(
# fn=make_button_and_slider_unclickable,
# inputs=None,
# outputs=[apply_button, frame_selector],
# )
# .then(
# run,
# inputs=inputs,
# outputs=outputs,
# )
# .then(
# image_to_sketch_gif,
# inputs=[result],
# outputs=[frame_result, frames, frame_selector, apply_button],
# )
# .then(
# fn=make_button_and_slider_clickable,
# inputs=None,
# outputs=[apply_button, frame_selector],
# )
# )
#
# Clicking Apply registers two listeners: this one has no handler and only
# cancels the four generation chains listed in `cancels`.
apply_button.click(
fn=None,
inputs=None,
outputs=None,
cancels=[
run_button_click,
randomize_seed_click,
style_change,
val_r_change,
],
)
# The second Apply listener loads the selected stored frame into the Sketchpad.
apply_button.click(
fn=apply_func_click,
inputs=[frames, frame_selector],
outputs=[image],
)
# Moving the slider shows the corresponding stored frame in the preview.
frame_selector.change(
fn=frame_selector_change,
inputs=[frame_selector, frames],
outputs=[frame_result],
)
# Clearing the Sketchpad also registers two listeners: this one only cancels
# the four generation chains listed in `cancels`.
image.clear(
fn=None,
inputs=None,
outputs=None,
cancels=[
run_button_click,
randomize_seed_click,
style_change,
val_r_change,
],
)
# The second clear listener resets the Sketchpad, both images, the stored
# frames, the slider and the Apply button via `clear_image_editor`; the outputs
# are listed in the order of that function's return tuple.
image.clear(
fn=clear_image_editor,
inputs=None,
outputs=[
image,
result,
frame_result,
frames,
frame_selector,
apply_button,
],
)
if __name__ == "__main__":
    # Only when executed as a script: enable the queue on the app, then launch it.
    queued_demo = demo.queue()
    queued_demo.launch()