Chain-of-Zoom / app.py
alexnasa's picture
Update app.py
ba815e8 verified
raw
history blame
7.31 kB
import gradio as gr
import subprocess
import os
import shutil
from pathlib import Path
from inference_coz_single import recursive_multiscale_sr
from PIL import Image, ImageDraw
import spaces
# ------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# ------------------------------------------------------------------
# NOTE(review): neither constant is referenced in the visible part of this
# file. Presumably they name the sample-input and result directories used
# by the inference code; confirm they are still needed before relying on them.
INPUT_DIR = "samples"
OUTPUT_DIR = "inference_results/coz_vlmprompt"
# ------------------------------------------------------------------
# HELPER: Resize & center-crop to 512, preserving aspect ratio
# ------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """
    Resize the input PIL image so that its shorter side == `size`,
    then center-crop to exactly (size x size).

    Args:
        img: Source image; only `.size`, `.resize()` and `.crop()` are used.
        size: Edge length, in pixels, of the square result.

    Returns:
        A new image of exactly (size x size) pixels taken from the center
        of the resized input.
    """
    w, h = img.size
    scale = size / min(w, h)

    # `side * scale` is a float product and can land a hair below the exact
    # value (e.g. 49 * (512 / 49) == 511.99999999999994), which int() would
    # truncate to size - 1. The crop window would then extend past the
    # resized image. Clamp both dimensions so they are never below `size`.
    new_w = max(size, int(w * scale))
    new_h = max(size, int(h * scale))
    img = img.resize((new_w, new_h), Image.LANCZOS)

    # Center the (size x size) window inside the resized image.
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))
# ------------------------------------------------------------------
# HELPER: Draw four concentric, centered rectangles on a 512×512 image
# ------------------------------------------------------------------
def make_preview_with_boxes(image_path: str, scale_option: str) -> Image.Image:
    """
    Build a 512×512 preview of the uploaded image with four centered boxes.

    1) Open the uploaded image from disk.
    2) Resize & center-crop it to exactly 512×512.
    3) Depending on scale_option ("1x","2x","4x"), compute four rectangle sizes:
       - "1x": [512, 512, 512, 512]
       - "2x": [256, 128, 64, 32]
       - "4x": [128, 64, 32, 16]
    4) Draw each of those four rectangles (outline only), all centered.
    5) Return the modified PIL image.

    Args:
        image_path: Path of the image file to preview.
        scale_option: Scale label such as "2x"; the "x" is stripped and the
            remainder parsed with int().

    Returns:
        The annotated 512×512 RGB image, or a gray 512×512 placeholder with
        the error text drawn on it if the file cannot be opened/decoded.

    Raises:
        ValueError: If `scale_option` does not parse as an integer.
    """
    try:
        # The context manager releases the file handle as soon as the pixel
        # data has been converted, instead of leaving it open until GC.
        with Image.open(image_path) as opened:
            orig = opened.convert("RGB")
    except Exception as e:
        # Deliberate best-effort for the UI: show the problem in the preview
        # pane as a plain 512×512 gray image rather than failing the callback.
        fallback = Image.new("RGB", (512, 512), (200, 200, 200))
        draw = ImageDraw.Draw(fallback)
        draw.text((20, 20), f"Error:\n{e}", fill="red")
        return fallback

    # 1. Resize & center-crop to 512×512
    base = resize_and_center_crop(orig, 512)

    # 2. Determine the four box sizes
    scale_int = int(scale_option.replace("x", ""))  # e.g. "2x" -> 2
    if scale_int == 1:
        sizes = [512, 512, 512, 512]
    else:
        # For scale=2: [512//2, 512//4, 512//8, 512//16]  -> [256, 128, 64, 32]
        # For scale=4: [512//4, 512//8, 512//16, 512//32] -> [128, 64, 32, 16]
        sizes = [512 // (scale_int * (2 ** i)) for i in range(4)]

    draw = ImageDraw.Draw(base)

    # 3. Outline color cycle (one color per box)
    colors = ["red", "lime", "cyan", "yellow"]
    width = 3  # thickness of each rectangle's outline

    for idx, s in enumerate(sizes):
        # Top-left corner so that the box is centered in 512×512.
        x0 = y0 = (512 - s) // 2
        # The bounding box includes both corner points, so the far corner of
        # an s-pixel box is x0 + s - 1. Using x0 + s would make the box one
        # pixel too large and, for s == 512, place the edge at pixel 512,
        # one past the last valid pixel (511).
        x1 = y1 = x0 + s - 1
        draw.rectangle(
            [(x0, y0), (x1, y1)],
            outline=colors[idx % len(colors)],
            width=width,
        )

    return base
@spaces.GPU(duration=120)
def run_with_upload(uploaded_image_path, upscale_option):
    """
    Run Chain-of-Zoom inference on the uploaded image.

    Args:
        uploaded_image_path: Filesystem path of the uploaded image, or None
            when nothing has been uploaded yet.
        upscale_option: Scale label such as "2x"; the "x" is stripped and the
            remainder is passed to the model as an int.

    Returns:
        The first element of whatever `recursive_multiscale_sr` returns.
        NOTE(review): this value is wired to the output gallery, so it is
        presumably a list of images; confirm against `inference_coz_single`.

    Raises:
        gr.Error: If no image has been uploaded.
    """
    # NOTE(review): the decorator presumably requests a GPU for up to 120 s
    # per call; confirm against the Hugging Face `spaces` package docs.

    # Mirror the None check in the preview callback: clicking the button with
    # no upload should produce a clear message rather than an opaque failure
    # from inside the inference code.
    if uploaded_image_path is None:
        raise gr.Error("Please upload an image before running Chain-of-Zoom.")

    upscale_value = int(upscale_option.replace("x", ""))  # e.g. "2x" → 2
    return recursive_multiscale_sr(uploaded_image_path, upscale_value)[0]
# ------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE (with updated callbacks)
# ------------------------------------------------------------------
# Page-level stylesheet passed to gr.Blocks(css=...). The `#col-container`
# selector matches the elem_id given to the outer gr.Column below and
# centers it with a maximum width of 1024px.
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""
# Declare the whole UI: a title banner, a two-column layout (inputs and
# preview on the left, results gallery on the right) and the event wiring
# that connects the components to the callbacks.
with gr.Blocks(css=css) as demo:
    # Static header: title, tagline and a badge linking to the GitHub repo.
    # The literal is kept flush-left so the emitted markup is unchanged.
    gr.HTML(
        """
<div style="text-align: center;">
<h1>Chain-of-Zoom</h1>
<p style="font-size:16px;">Extreme Super-Resolution via Scale Autoregression and Preference Alignment</p>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/bryanswkim/Chain-of-Zoom">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
</div>
"""
    )
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                # 1) Image upload component
                # type="filepath" hands the callbacks a path string (or None).
                upload_image = gr.Image(
                    label="Upload your input image",
                    type="filepath"
                )
                # 2) Radio for choosing 1× / 2× / 4× upscaling
                upscale_radio = gr.Radio(
                    choices=["1x", "2x", "4x"],
                    value="2x",
                    show_label=False
                )
                # 3) Button to launch inference
                run_button = gr.Button("Chain-of-Zoom it")
                # 4) Show the 512×512 preview with four centered rectangles
                preview_with_box = gr.Image(
                    label="Preview (512×512 with centered boxes)",
                    type="pil",  # we’ll return a PIL.Image from our function
                    interactive=False
                )
            with gr.Column():
                # 5) Gallery to display multiple output images
                # NOTE(review): columns/rows are passed as one-element lists;
                # confirm the installed Gradio version accepts lists here.
                output_gallery = gr.Gallery(
                    label="Inference Results",
                    show_label=True,
                    elem_id="gallery",
                    columns=[2], rows=[2]
                )

        # ------------------------------------------------------------------
        # CALLBACK #1: Whenever the user uploads or changes the radio, update preview
        # ------------------------------------------------------------------
        def update_preview(img_path, scale_opt):
            """
            If there's no image uploaded yet, return None (Gradio will show blank).
            Otherwise, draw the resized 512×512 + four boxes and return it.
            """
            if img_path is None:
                return None
            return make_preview_with_boxes(img_path, scale_opt)

        # When the user uploads a new file:
        upload_image.change(
            fn=update_preview,
            inputs=[upload_image, upscale_radio],
            outputs=[preview_with_box]
        )
        # Also trigger preview redraw if they switch 1×/2×/4× after uploading:
        upscale_radio.change(
            fn=update_preview,
            inputs=[upload_image, upscale_radio],
            outputs=[preview_with_box]
        )

        # ------------------------------------------------------------------
        # CALLBACK #2: When “Chain-of-Zoom it” is clicked, run inference
        # ------------------------------------------------------------------
        # The same two inputs as the preview are passed to run_with_upload;
        # its return value populates the results gallery.
        run_button.click(
            fn=run_with_upload,
            inputs=[upload_image, upscale_radio],
            outputs=[output_gallery]
        )
# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
# Launch the app. This runs at module import time: there is no
# `if __name__ == "__main__":` guard, so importing this file starts the server.
# NOTE(review): share=True presumably requests a public share link (see the
# Gradio `launch` docs); confirm it is wanted in the deployment environment.
demo.launch(share=True)