autoencoder-grayscale2color-landscape / models /auto_encoder_gray2color.py
danhtran2mind's picture
Upload 4 files
9d59b6c verified
raw
history blame
3.34 kB
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Add, Concatenate, Multiply
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
# Spatial Attention Layer
# Define SpatialAttention layer
class SpatialAttention(tf.keras.layers.Layer):
def __init__(self, kernel_size=7, **kwargs):
super(SpatialAttention, self).__init__(**kwargs)
self.kernel_size = kernel_size
self.conv = Conv2D(filters=1, kernel_size=kernel_size, padding='same', activation='sigmoid')
def call(self, inputs):
avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
concat = Concatenate()([avg_pool, max_pool])
attention = self.conv(concat)
return Multiply()([inputs, attention])
def get_config(self):
config = super(SpatialAttention, self).get_config()
config.update({'kernel_size': self.kernel_size})
return config
# Build Autoencoder
def build_autoencoder(height, width,):
input_img = Input(shape=(height, width, 1))
# Encoder
x = Conv2D(96, (3, 3), activation='relu', padding='same')(input_img)
x = BatchNormalization()(x)
x = SpatialAttention()(x)
x = MaxPooling2D((2, 2), padding='same')(x)
# Residual Block 1
residual = Conv2D(192, (1, 1), padding='same')(x)
x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Add()([x, residual])
x = SpatialAttention()(x)
x = MaxPooling2D((2, 2), padding='same')(x)
# Residual Block 2
residual = Conv2D(384, (1, 1), padding='same')(x)
x = Conv2D(384, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(384, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Add()([x, residual])
x = SpatialAttention()(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)
# Decoder
x = Conv2D(384, (3, 3), activation='relu', padding='same')(encoded)
x = BatchNormalization()(x)
x = SpatialAttention()(x)
x = UpSampling2D((2, 2))(x)
# Residual Block 3
residual = Conv2D(192, (1, 1), padding='same')(x)
x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Add()([x, residual])
x = SpatialAttention()(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(96, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = SpatialAttention()(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(2, (3, 3), activation=None, padding='same')(x)
return Model(input_img, decoded)
if __name__ == "__main__":
# Define constants
HEIGHT, WIDTH = 512, 512
# Compile model
autoencoder = build_autoencoder()
autoencoder.summary()
autoencoder.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())