Update app.py
Browse files
app.py
CHANGED
@@ -1,71 +1,71 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from PIL import Image
|
3 |
-
import os
|
4 |
-
import tensorflow as tf
|
5 |
-
import requests
|
6 |
-
|
7 |
-
from
|
8 |
-
|
9 |
-
# Load the saved model once at startup
|
10 |
-
load_model_path = "./ckpts/best_model.h5"
|
11 |
-
if not os.path.exists(load_model_path):
|
12 |
-
os.makedirs(os.path.dirname(load_model_path), exist_ok=True)
|
13 |
-
url = "https://huggingface.co/danhtran2mind/autoencoder-grayscale-to-color-landscape/resolve/main/ckpts/best_model.h5"
|
14 |
-
print(f"Downloading model from {url}...")
|
15 |
-
with requests.get(url, stream=True) as r:
|
16 |
-
r.raise_for_status()
|
17 |
-
with open(load_model_path, "wb") as f:
|
18 |
-
for chunk in r.iter_content(chunk_size=8192):
|
19 |
-
f.write(chunk)
|
20 |
-
print("Download complete.")
|
21 |
-
|
22 |
-
print(f"Loading model from {load_model_path}...")
|
23 |
-
loaded_autoencoder = tf.keras.models.load_model(
|
24 |
-
load_model_path,
|
25 |
-
custom_objects={'SpatialAttention': SpatialAttention}
|
26 |
-
)
|
27 |
-
|
28 |
-
def process_image(input_img):
|
29 |
-
# Convert PIL Image to numpy array and normalize
|
30 |
-
img = input_img.convert("RGB")
|
31 |
-
img = img.resize((256, 256)) # adjust size as needed
|
32 |
-
img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0
|
33 |
-
img_array = img_array[None, ...] # add batch dimension
|
34 |
-
|
35 |
-
# Run inference
|
36 |
-
output_array = loaded_autoencoder.predict(img_array)
|
37 |
-
output_img = tf.keras.preprocessing.image.array_to_img(output_array[0])
|
38 |
-
|
39 |
-
return output_img
|
40 |
-
|
41 |
-
custom_css = """
|
42 |
-
body {background: linear-gradient(135deg, #232526 0%, #414345 100%) !important;}
|
43 |
-
.gradio-container {background: transparent !important;}
|
44 |
-
h1, .gr-title {color: #00e6d3 !important; font-family: 'Segoe UI', sans-serif;}
|
45 |
-
.gr-description {color: #e0e0e0 !important; font-size: 1.1em;}
|
46 |
-
.gr-input, .gr-output {border-radius: 18px !important; box-shadow: 0 4px 24px rgba(0,0,0,0.18);}
|
47 |
-
.gr-button {background: linear-gradient(90deg, #00e6d3 0%, #0072ff 100%) !important; color: #fff !important; border: none !important; border-radius: 12px !important;}
|
48 |
-
"""
|
49 |
-
|
50 |
-
demo = gr.Interface(
|
51 |
-
fn=process_image,
|
52 |
-
inputs=gr.Image(type="pil", label="Upload Grayscale Landscape", image_mode="L", shape=(256, 256)),
|
53 |
-
outputs=gr.Image(type="pil", label="Colorized Output"),
|
54 |
-
title="🌄 Gray2Color Landscape Autoencoder",
|
55 |
-
description=(
|
56 |
-
"<div style='font-size:1.15em;line-height:1.6em;'>"
|
57 |
-
"Transform your <b>grayscale landscape</b> photos into vivid color with a state-of-the-art autoencoder.<br>"
|
58 |
-
"Simply upload a grayscale image and see the magic happen!"
|
59 |
-
"</div>"
|
60 |
-
),
|
61 |
-
theme="soft",
|
62 |
-
css=custom_css,
|
63 |
-
allow_flagging="never",
|
64 |
-
examples=[
|
65 |
-
["examples/grayscale_landscape1.jpg"],
|
66 |
-
["examples/grayscale_landscape2.jpg"]
|
67 |
-
]
|
68 |
-
)
|
69 |
-
|
70 |
-
if __name__ == "__main__":
|
71 |
demo.launch()
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
import os
|
4 |
+
import tensorflow as tf
|
5 |
+
import requests
|
6 |
+
|
7 |
+
from models.auto_encoder_gray2color import SpatialAttention
|
8 |
+
|
9 |
+
# Load the saved model once at startup
|
10 |
+
load_model_path = "./ckpts/best_model.h5"
|
11 |
+
if not os.path.exists(load_model_path):
|
12 |
+
os.makedirs(os.path.dirname(load_model_path), exist_ok=True)
|
13 |
+
url = "https://huggingface.co/danhtran2mind/autoencoder-grayscale-to-color-landscape/resolve/main/ckpts/best_model.h5"
|
14 |
+
print(f"Downloading model from {url}...")
|
15 |
+
with requests.get(url, stream=True) as r:
|
16 |
+
r.raise_for_status()
|
17 |
+
with open(load_model_path, "wb") as f:
|
18 |
+
for chunk in r.iter_content(chunk_size=8192):
|
19 |
+
f.write(chunk)
|
20 |
+
print("Download complete.")
|
21 |
+
|
22 |
+
print(f"Loading model from {load_model_path}...")
|
23 |
+
loaded_autoencoder = tf.keras.models.load_model(
|
24 |
+
load_model_path,
|
25 |
+
custom_objects={'SpatialAttention': SpatialAttention}
|
26 |
+
)
|
27 |
+
|
28 |
+
def process_image(input_img):
|
29 |
+
# Convert PIL Image to numpy array and normalize
|
30 |
+
img = input_img.convert("RGB")
|
31 |
+
img = img.resize((256, 256)) # adjust size as needed
|
32 |
+
img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0
|
33 |
+
img_array = img_array[None, ...] # add batch dimension
|
34 |
+
|
35 |
+
# Run inference
|
36 |
+
output_array = loaded_autoencoder.predict(img_array)
|
37 |
+
output_img = tf.keras.preprocessing.image.array_to_img(output_array[0])
|
38 |
+
|
39 |
+
return output_img
|
40 |
+
|
41 |
+
custom_css = """
|
42 |
+
body {background: linear-gradient(135deg, #232526 0%, #414345 100%) !important;}
|
43 |
+
.gradio-container {background: transparent !important;}
|
44 |
+
h1, .gr-title {color: #00e6d3 !important; font-family: 'Segoe UI', sans-serif;}
|
45 |
+
.gr-description {color: #e0e0e0 !important; font-size: 1.1em;}
|
46 |
+
.gr-input, .gr-output {border-radius: 18px !important; box-shadow: 0 4px 24px rgba(0,0,0,0.18);}
|
47 |
+
.gr-button {background: linear-gradient(90deg, #00e6d3 0%, #0072ff 100%) !important; color: #fff !important; border: none !important; border-radius: 12px !important;}
|
48 |
+
"""
|
49 |
+
|
50 |
+
demo = gr.Interface(
|
51 |
+
fn=process_image,
|
52 |
+
inputs=gr.Image(type="pil", label="Upload Grayscale Landscape", image_mode="L", shape=(256, 256)),
|
53 |
+
outputs=gr.Image(type="pil", label="Colorized Output"),
|
54 |
+
title="🌄 Gray2Color Landscape Autoencoder",
|
55 |
+
description=(
|
56 |
+
"<div style='font-size:1.15em;line-height:1.6em;'>"
|
57 |
+
"Transform your <b>grayscale landscape</b> photos into vivid color with a state-of-the-art autoencoder.<br>"
|
58 |
+
"Simply upload a grayscale image and see the magic happen!"
|
59 |
+
"</div>"
|
60 |
+
),
|
61 |
+
theme="soft",
|
62 |
+
css=custom_css,
|
63 |
+
allow_flagging="never",
|
64 |
+
examples=[
|
65 |
+
["examples/grayscale_landscape1.jpg"],
|
66 |
+
["examples/grayscale_landscape2.jpg"]
|
67 |
+
]
|
68 |
+
)
|
69 |
+
|
70 |
+
if __name__ == "__main__":
|
71 |
demo.launch()
|