Upload 6 files
Browse files- app.py +71 -0
- ckpts/best_model.h5 +3 -0
- dataset/README.md +42 -0
- models/auto_encoder_gray2color.py +92 -0
- notebooks/autoencoder-grayscale-to-color-landscape.ipynb +0 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
import os
|
4 |
+
import tensorflow as tf
|
5 |
+
import requests
|
6 |
+
|
7 |
+
# When this file is executed directly (``python app.py``) it has no parent
# package, so the relative import raises ImportError; fall back to the
# absolute form in that case.
try:
    from .models.auto_encoder_gray2color import SpatialAttention
except ImportError:
    from models.auto_encoder_gray2color import SpatialAttention

# Load the saved model once at startup
load_model_path = "./ckpts/best_model.h5"
if not os.path.exists(load_model_path):
    os.makedirs(os.path.dirname(load_model_path), exist_ok=True)
    url = "https://huggingface.co/danhtran2mind/autoencoder-grayscale-to-color-landscape/resolve/main/ckpts/best_model.h5"
    print(f"Downloading model from {url}...")
    # Stream into a temporary sibling file and rename it only after the whole
    # body has been written, so an interrupted download never leaves a
    # truncated checkpoint that the os.path.exists() check would accept on
    # the next start.
    partial_path = load_model_path + ".part"
    try:
        # (connect, read) timeouts keep startup from hanging indefinitely.
        with requests.get(url, stream=True, timeout=(10, 60)) as r:
            r.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial_path, load_model_path)
    finally:
        # Only present here if the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download complete.")

print(f"Loading model from {load_model_path}...")
# SpatialAttention is a custom layer, so it must be supplied for deserialization.
loaded_autoencoder = tf.keras.models.load_model(
    load_model_path,
    custom_objects={'SpatialAttention': SpatialAttention}
)
|
27 |
+
|
28 |
+
def process_image(input_img):
    """Run the loaded autoencoder on one uploaded image and return the result.

    The PIL image is converted to RGB, resized to 256x256, turned into an
    array divided by 255.0, given a leading batch axis and passed to
    ``loaded_autoencoder.predict``. The first prediction is converted back
    to a PIL image.

    NOTE(review): the network defined in models/auto_encoder_gray2color.py
    takes a 1-channel input and ends in a 2-filter convolution. If the
    checkpoint follows that architecture, the 3-channel RGB batch built here
    and the direct ``array_to_img`` conversion of a 2-channel prediction
    would not match it -- confirm against the training notebook.
    """
    resized = input_img.convert("RGB").resize((256, 256))  # adjust size as needed
    pixels = tf.keras.preprocessing.image.img_to_array(resized)
    batch = (pixels / 255.0)[None, ...]  # add batch dimension

    # Run inference and convert the single prediction back to an image
    prediction = loaded_autoencoder.predict(batch)
    return tf.keras.preprocessing.image.array_to_img(prediction[0])
|
40 |
+
|
41 |
+
# Extra CSS handed to gr.Interface below through its ``css=`` argument.
custom_css = """
body {background: linear-gradient(135deg, #232526 0%, #414345 100%) !important;}
.gradio-container {background: transparent !important;}
h1, .gr-title {color: #00e6d3 !important; font-family: 'Segoe UI', sans-serif;}
.gr-description {color: #e0e0e0 !important; font-size: 1.1em;}
.gr-input, .gr-output {border-radius: 18px !important; box-shadow: 0 4px 24px rgba(0,0,0,0.18);}
.gr-button {background: linear-gradient(90deg, #00e6d3 0%, #0072ff 100%) !important; color: #fff !important; border: none !important; border-radius: 12px !important;}
"""

# Gradio UI: one image input, one image output, both routed through process_image.
demo = gr.Interface(
    fn=process_image,
    # NOTE(review): the `shape` keyword of gr.Image is believed to have been
    # removed in newer Gradio releases -- confirm against the Gradio version
    # this app is deployed with.
    inputs=gr.Image(type="pil", label="Upload Grayscale Landscape", image_mode="L", shape=(256, 256)),
    outputs=gr.Image(type="pil", label="Colorized Output"),
    title="🌄 Gray2Color Landscape Autoencoder",
    # The description is raw HTML assembled from adjacent string literals.
    description=(
        "<div style='font-size:1.15em;line-height:1.6em;'>"
        "Transform your <b>grayscale landscape</b> photos into vivid color with a state-of-the-art autoencoder.<br>"
        "Simply upload a grayscale image and see the magic happen!"
        "</div>"
    ),
    theme="soft",
    css=custom_css,
    allow_flagging="never",
    # TODO(review): confirm these example files exist in the repository;
    # the paths are relative to the working directory.
    examples=[
        ["examples/grayscale_landscape1.jpg"],
        ["examples/grayscale_landscape2.jpg"]
    ]
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
ckpts/best_model.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a6d0361fa140c1dc3b279bcce8107c28b6e10a4e1bc31f770e5b071a44f5f76d
|
3 |
+
size 20800096
|
dataset/README.md
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
```markdown
|
2 |
+
# Landscape Pictures Dataset Processing
|
3 |
+
|
4 |
+
[Landscape Pictures dataset on Kaggle](https://www.kaggle.com/datasets/arnaud58/landscape-pictures)
|
5 |
+
|
6 |
+
This README provides instructions for downloading, extracting, and processing the landscape pictures dataset from Kaggle.
|
7 |
+
|
8 |
+
## Dataset Source
|
9 |
+
|
10 |
+
The dataset is sourced from Kaggle: Landscape Pictures by Arnaud58. Follow this link: [Kaggle Dataset](https://www.kaggle.com/datasets/arnaud58/landscape-pictures)
|
11 |
+
|
12 |
+
## Setup
|
13 |
+
|
14 |
+
1. **Create a Dataset Directory**: Create a directory to store the dataset:
|
15 |
+
|
16 |
+
```python
|
17 |
+
import os
|
18 |
+
|
19 |
+
ds_path = "./dataset/landscape-pictures"
|
20 |
+
os.makedirs(ds_path, exist_ok=True)
|
21 |
+
```
|
22 |
+
|
23 |
+
2. **Download the Dataset**: Use the following command to download the dataset from Kaggle:
|
24 |
+
|
25 |
+
```bash
|
26 |
+
curl -L https://www.kaggle.com/api/v1/datasets/download/arnaud58/landscape-pictures -o ./dataset/landscape-pictures.zip
|
27 |
+
```
|
28 |
+
|
29 |
+
Note: You may need a Kaggle API token for authentication. Ensure you have the `kaggle.json` file configured in `~/.kaggle/` or set up the Kaggle API as per Kaggle's API documentation.
|
30 |
+
|
31 |
+
3. **Extract the Dataset**: Run the following Python code to extract the downloaded zip file:
|
32 |
+
|
33 |
+
```python
|
34 |
+
import zipfile
|
35 |
+
import os
|
36 |
+
|
37 |
+
with zipfile.ZipFile('dataset/landscape-pictures.zip', 'r') as zip_ref:
|
38 |
+
zip_ref.extractall(ds_path)
|
39 |
+
```
|
40 |
+
|
41 |
+
This will extract the dataset into the `./dataset/landscape-pictures` directory (the `ds_path` created in step 1).
|
42 |
+
```
|
models/auto_encoder_gray2color.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import tensorflow as tf
|
3 |
+
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Add, Concatenate, Multiply
|
4 |
+
from tensorflow.keras.models import Model
|
5 |
+
from tensorflow.keras.optimizers import Adam
|
6 |
+
|
7 |
+
# Spatial Attention Layer
|
8 |
+
# Define SpatialAttention layer
|
9 |
+
class SpatialAttention(tf.keras.layers.Layer):
    """Keras layer that rescales its input by a learned attention map.

    The input is reduced along its last axis twice (mean and max, both with
    ``keepdims=True``). The two results are concatenated and passed through
    a single-filter, sigmoid-activated ``Conv2D``, and the input is
    multiplied by that map.

    Args:
        kernel_size: Kernel size of the internal convolution (default 7).
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        self.conv = Conv2D(
            filters=1,
            kernel_size=kernel_size,
            padding='same',
            activation='sigmoid',
        )

    def call(self, inputs):
        """Return ``inputs`` multiplied by the attention map computed from it."""
        channel_mean = tf.reduce_mean(inputs, axis=-1, keepdims=True)
        channel_max = tf.reduce_max(inputs, axis=-1, keepdims=True)
        pooled = Concatenate()([channel_mean, channel_max])
        return Multiply()([inputs, self.conv(pooled)])

    def get_config(self):
        """Return the base layer config extended with ``kernel_size``."""
        base_config = super().get_config()
        return {**base_config, 'kernel_size': self.kernel_size}
|
26 |
+
|
27 |
+
# Build Autoencoder
|
28 |
+
def _conv_attention_stage(tensor, filters):
    """Apply 3x3 ReLU convolution, batch normalization, then spatial attention."""
    tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
    tensor = BatchNormalization()(tensor)
    return SpatialAttention()(tensor)


def _residual_attention_stage(tensor, filters):
    """Apply a two-convolution residual block followed by spatial attention.

    The 1x1 shortcut convolution is created first so that layers are
    instantiated in the same order as in the hand-unrolled version.
    """
    shortcut = Conv2D(filters, (1, 1), padding='same')(tensor)
    for _ in range(2):
        tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
    tensor = Add()([tensor, shortcut])
    return SpatialAttention()(tensor)


def build_autoencoder(height, width):
    """Build the (uncompiled) gray-to-color autoencoder.

    Args:
        height: Height of the single-channel input image.
        width: Width of the single-channel input image.

    Returns:
        A Keras ``Model`` taking an input of shape ``(height, width, 1)`` and
        ending in a 2-filter 3x3 convolution with no activation.
    """
    input_img = Input(shape=(height, width, 1))

    # Encoder: three stages, each followed by 2x2 max pooling
    features = _conv_attention_stage(input_img, 96)
    features = MaxPooling2D((2, 2), padding='same')(features)

    features = _residual_attention_stage(features, 192)  # Residual Block 1
    features = MaxPooling2D((2, 2), padding='same')(features)

    features = _residual_attention_stage(features, 384)  # Residual Block 2
    encoded = MaxPooling2D((2, 2), padding='same')(features)

    # Decoder: three stages, each followed by 2x upsampling
    features = _conv_attention_stage(encoded, 384)
    features = UpSampling2D((2, 2))(features)

    features = _residual_attention_stage(features, 192)  # Residual Block 3
    features = UpSampling2D((2, 2))(features)

    features = _conv_attention_stage(features, 96)
    features = UpSampling2D((2, 2))(features)

    decoded = Conv2D(2, (3, 3), activation=None, padding='same')(features)

    return Model(input_img, decoded)
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == "__main__":
    # Define constants
    HEIGHT, WIDTH = 512, 512
    # build_autoencoder requires both input dimensions; calling it with no
    # arguments raises TypeError.
    autoencoder = build_autoencoder(HEIGHT, WIDTH)
    autoencoder.summary()
    # Compile model
    autoencoder.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
|
notebooks/autoencoder-grayscale-to-color-landscape.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.26.4
|
2 |
+
tensorflow==2.18.0
|
3 |
+
opencv-python==4.11.0.86
|
4 |
+
scikit-image==0.25.2
|
5 |
+
matplotlib==3.7.2
|