# File size: 3,340 bytes (revision 9d59b6c)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Add, Concatenate, Multiply
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# Spatial attention layer
class SpatialAttention(tf.keras.layers.Layer):
    """Spatial attention gate, in the style of the CBAM spatial attention module.

    The input is pooled across its last axis (assumed channels-last, matching
    the ``(height, width, channels)`` inputs used in this file) with both mean
    and max. The resulting 2-channel map goes through a single-filter
    sigmoid-activated convolution. The input is then multiplied by that
    per-position weight map, so the output has the same shape as the input.

    Args:
        kernel_size: Spatial size of the attention convolution (default 7).
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        # A single sigmoid-activated filter yields one weight in (0, 1) per spatial position.
        self.conv = Conv2D(filters=1, kernel_size=kernel_size, padding='same', activation='sigmoid')

    def call(self, inputs):
        """Return ``inputs`` scaled by the learned spatial attention map."""
        # Channel-wise statistics; keepdims leaves a size-1 channel axis.
        avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
        max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
        # Use plain tensor ops rather than instantiating Concatenate()/Multiply()
        # layers here: creating layers inside call() builds fresh, untracked
        # layer objects on every invocation.
        concat = tf.concat([avg_pool, max_pool], axis=-1)
        attention = self.conv(concat)
        # The (..., 1) attention map broadcasts across every channel of `inputs`.
        return inputs * attention

    def get_config(self):
        """Return the layer config, including ``kernel_size``, for serialization."""
        config = super().get_config()
        config.update({'kernel_size': self.kernel_size})
        return config
        
# Build Autoencoder
def _conv_bn_attention(x, filters):
    """Apply Conv2D(3x3, ReLU) -> BatchNormalization -> SpatialAttention to ``x``."""
    x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    return SpatialAttention()(x)


def _residual_attention_block(x, filters):
    """Apply two Conv-BN layers with a 1x1 projection shortcut, then SpatialAttention."""
    # The 1x1 conv projects the shortcut to `filters` channels so Add() shapes match.
    # It is created first to keep the original layer-creation (and auto-naming) order.
    residual = Conv2D(filters, (1, 1), padding='same')(x)
    x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = Add()([x, residual])
    return SpatialAttention()(x)


def build_autoencoder(height, width, input_channels=1, output_channels=2):
    """Build the attention-augmented convolutional autoencoder.

    The encoder has three stages (96, 192 and 384 filters). Each stage is
    followed by a 2x2 max-pool, for 8x total downsampling. The decoder
    mirrors this with three 2x upsamplings. Each stage ends in a
    ``SpatialAttention`` gate.

    Args:
        height: Input height in pixels. Must be a multiple of 8, or None for a
            variable-size input, so the decoder output matches the input size.
        width: Input width in pixels; same constraint as ``height``.
        input_channels: Number of input channels (default 1).
        output_channels: Number of channels produced by the final linear
            convolution (default 2).

    Returns:
        An uncompiled ``tf.keras.Model`` mapping
        ``(batch, height, width, input_channels)`` to
        ``(batch, height, width, output_channels)``.

    Raises:
        ValueError: If ``height`` or ``width`` is not None and not a multiple of 8.
    """
    # Three stride-2 pools with padding='same' give ceil(dim / 8) at the
    # bottleneck, and three 2x upsamples give 8 * ceil(dim / 8). That only
    # equals `dim` when dim is a multiple of 8.
    downsample_factor = 2 ** 3
    for dim_name, dim in (("height", height), ("width", width)):
        if dim is not None and dim % downsample_factor != 0:
            raise ValueError(
                f"{dim_name} must be a multiple of {downsample_factor} "
                f"(or None), got {dim}"
            )

    input_img = Input(shape=(height, width, input_channels))

    # Encoder: spatial size halves after each stage.
    x = _conv_bn_attention(input_img, 96)
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = _residual_attention_block(x, 192)          # Residual Block 1
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = _residual_attention_block(x, 384)          # Residual Block 2
    encoded = MaxPooling2D((2, 2), padding='same')(x)

    # Decoder: spatial size doubles after each stage.
    x = _conv_bn_attention(encoded, 384)
    x = UpSampling2D((2, 2))(x)
    x = _residual_attention_block(x, 192)          # Residual Block 3
    x = UpSampling2D((2, 2))(x)
    x = _conv_bn_attention(x, 96)
    x = UpSampling2D((2, 2))(x)

    # Linear (activation=None) output head.
    # NOTE(review): by default the output has 2 channels while the input has 1;
    # if this is meant to reconstruct its own input, output_channels should
    # probably equal input_channels — confirm intent.
    decoded = Conv2D(output_channels, (3, 3), activation=None, padding='same')(x)

    return Model(input_img, decoded)





if __name__ == "__main__":
    # Define constants: input spatial size. It should be a multiple of 8 so the
    # decoder's output size matches the input.
    HEIGHT, WIDTH = 512, 512

    # build_autoencoder requires height and width; pass them explicitly.
    autoencoder = build_autoencoder(HEIGHT, WIDTH)
    autoencoder.summary()

    # Compile model
    autoencoder.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())