danhtran2mind's picture
Update app.py
b0e70fe verified
raw
history blame
3.51 kB
import gradio as gr
from PIL import Image
import os
import numpy as np
import tensorflow as tf
import requests
from models.auto_encoder_gray2color import SpatialAttention
# Spatial size fed to the autoencoder; every input is resized to this before inference.
WIDTH, HEIGHT = 512, 512

# Load the saved model once at startup, downloading the checkpoint on first run.
load_model_path = "./ckpts/best_model.h5"
if not os.path.exists(load_model_path):
    os.makedirs(os.path.dirname(load_model_path), exist_ok=True)
    url = "https://huggingface.co/danhtran2mind/autoencoder-grayscale-to-color-landscape/resolve/main/ckpts/best_model.h5"
    print(f"Downloading model from {url}...")
    # Stream into a temporary ".part" file and only rename it to the final path once
    # the whole body has been written. Otherwise an interrupted download would leave a
    # truncated best_model.h5 that the os.path.exists check above never re-downloads.
    tmp_path = load_model_path + ".part"
    try:
        # timeout bounds the connect and each individual read, so a stalled
        # connection raises instead of hanging startup indefinitely.
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(tmp_path, load_model_path)
    finally:
        # On success tmp_path was already renamed away; on failure drop the partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download complete.")
print(f"Loading model from {load_model_path}...")
# SpatialAttention is a custom layer, so Keras needs it registered to deserialize the model.
loaded_autoencoder = tf.keras.models.load_model(
    load_model_path,
    custom_objects={'SpatialAttention': SpatialAttention}
)
def process_image(input_img):
    """Colorize a grayscale landscape photo with the loaded autoencoder.

    The image is converted to single-channel luminance, resized to
    ``WIDTH`` x ``HEIGHT``, scaled to [0, 1] and passed to
    ``loaded_autoencoder``. The model's two output channels are used as the
    chroma planes, combined with the input luminance as a YCbCr image, and
    converted to RGB.

    Args:
        input_img: PIL image from the Gradio ``Image`` input, or ``None`` when
            the user submits without uploading anything.

    Returns:
        PIL.Image.Image: RGB image of size ``WIDTH`` x ``HEIGHT``.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_img is None:
        # Show a readable message in the UI instead of an AttributeError traceback.
        raise gr.Error("Please upload a grayscale image first.")

    img = input_img.convert("L").resize((WIDTH, HEIGHT))
    img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0
    # Add a batch axis and keep a single channel -> (1, HEIGHT, WIDTH, 1).
    img_array = img_array[None, ..., 0:1]

    # verbose=0 stops Keras from printing a progress bar on every request.
    output_array = loaded_autoencoder.predict(img_array, verbose=0)
    print("output_array shape: ", output_array.shape)

    # NOTE(review): assumes the output is (1, HEIGHT, WIDTH, 2) holding the two
    # chroma channels scaled to [0, 1] - confirm against the training pipeline.
    y_channel = img_array[0, :, :, 0]  # luminance comes straight from the input
    uv_channels = output_array[0]
    yuv_image = np.stack([y_channel, uv_channels[:, :, 0], uv_channels[:, :, 1]], axis=-1)

    # Round and clip before the uint8 cast: predictions slightly outside [0, 1]
    # would otherwise wrap around (e.g. a small negative becomes ~255) and
    # produce speckled colors.
    yuv_uint8 = np.clip(np.rint(yuv_image * 255.0), 0, 255).astype(np.uint8)

    # Assemble the YCbCr image plane by plane; this avoids the `mode=` argument of
    # Image.fromarray, which recent Pillow releases deprecate.
    planes = [Image.fromarray(yuv_uint8[:, :, i]) for i in range(3)]
    return Image.merge("YCbCr", planes).convert("RGB")
custom_css = """
body {background: linear-gradient(135deg, #232526 0%, #414345 100%) !important;}
.gradio-container {background: transparent !important;}
h1, .gr-title {color: #00e6d3 !important; font-family: 'Segoe UI', sans-serif;}
.gr-description {color: #e0e0e0 !important; font-size: 1.1em;}
.gr-input, .gr-output {border-radius: 18px !important; box-shadow: 0 4px 24px rgba(0,0,0,0.18);}
.gr-button {background: linear-gradient(90deg, #00e6d3 0%, #0072ff 100%) !important; color: #fff !important; border: none !important; border-radius: 12px !important;}
"""
# Gradio UI: one uploaded image in (converted to grayscale by image_mode="L"),
# one colorized image out, both exchanged with process_image as PIL images.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Grayscale Landscape", image_mode="L"),
    outputs=gr.Image(type="pil", label="Colorized Output"),
    title="🌄 Gray2Color Landscape Autoencoder",
    description=(
        "<div style='font-size:1.15em;line-height:1.6em;'>"
        "Transform your <b>grayscale landscape</b> photos into vivid color with a state-of-the-art autoencoder.<br>"
        "Simply upload a grayscale image and see the magic happen!"
        "</div>"
    ),
    theme="soft",
    css=custom_css,
    # NOTE(review): newer Gradio releases replace allow_flagging with
    # flagging_mode="never"; switch once the pinned Gradio version supports it.
    allow_flagging="never",
    # Example paths are relative to the process working directory.
    examples=[
        ["examples/example_1.jpg"],
        ["examples/example_2.jpg"]
    ]
)

if __name__ == "__main__":
    demo.launch()