danhtran2mind commited on
Commit
9d59b6c
·
verified ·
1 Parent(s): 8a6c301

Upload 4 files

Browse files
ckpts/best_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6d0361fa140c1dc3b279bcce8107c28b6e10a4e1bc31f770e5b071a44f5f76d
3
+ size 20800096
models/auto_encoder_gray2color.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import tensorflow as tf
3
+ from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Add, Concatenate, Multiply
4
+ from tensorflow.keras.models import Model
5
+ from tensorflow.keras.optimizers import Adam
6
+
7
+ # Spatial Attention Layer
8
+ # Define SpatialAttention layer
9
+ class SpatialAttention(tf.keras.layers.Layer):
10
+ def __init__(self, kernel_size=7, **kwargs):
11
+ super(SpatialAttention, self).__init__(**kwargs)
12
+ self.kernel_size = kernel_size
13
+ self.conv = Conv2D(filters=1, kernel_size=kernel_size, padding='same', activation='sigmoid')
14
+
15
+ def call(self, inputs):
16
+ avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
17
+ max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
18
+ concat = Concatenate()([avg_pool, max_pool])
19
+ attention = self.conv(concat)
20
+ return Multiply()([inputs, attention])
21
+
22
+ def get_config(self):
23
+ config = super(SpatialAttention, self).get_config()
24
+ config.update({'kernel_size': self.kernel_size})
25
+ return config
26
+
27
+ # Build Autoencoder
28
+ def build_autoencoder(height, width,):
29
+ input_img = Input(shape=(height, width, 1))
30
+
31
+ # Encoder
32
+ x = Conv2D(96, (3, 3), activation='relu', padding='same')(input_img)
33
+ x = BatchNormalization()(x)
34
+ x = SpatialAttention()(x)
35
+ x = MaxPooling2D((2, 2), padding='same')(x)
36
+
37
+ # Residual Block 1
38
+ residual = Conv2D(192, (1, 1), padding='same')(x)
39
+ x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
40
+ x = BatchNormalization()(x)
41
+ x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
42
+ x = BatchNormalization()(x)
43
+ x = Add()([x, residual])
44
+ x = SpatialAttention()(x)
45
+ x = MaxPooling2D((2, 2), padding='same')(x)
46
+
47
+ # Residual Block 2
48
+ residual = Conv2D(384, (1, 1), padding='same')(x)
49
+ x = Conv2D(384, (3, 3), activation='relu', padding='same')(x)
50
+ x = BatchNormalization()(x)
51
+ x = Conv2D(384, (3, 3), activation='relu', padding='same')(x)
52
+ x = BatchNormalization()(x)
53
+ x = Add()([x, residual])
54
+ x = SpatialAttention()(x)
55
+ encoded = MaxPooling2D((2, 2), padding='same')(x)
56
+
57
+ # Decoder
58
+ x = Conv2D(384, (3, 3), activation='relu', padding='same')(encoded)
59
+ x = BatchNormalization()(x)
60
+ x = SpatialAttention()(x)
61
+ x = UpSampling2D((2, 2))(x)
62
+
63
+ # Residual Block 3
64
+ residual = Conv2D(192, (1, 1), padding='same')(x)
65
+ x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
66
+ x = BatchNormalization()(x)
67
+ x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
68
+ x = BatchNormalization()(x)
69
+ x = Add()([x, residual])
70
+ x = SpatialAttention()(x)
71
+ x = UpSampling2D((2, 2))(x)
72
+
73
+ x = Conv2D(96, (3, 3), activation='relu', padding='same')(x)
74
+ x = BatchNormalization()(x)
75
+ x = SpatialAttention()(x)
76
+ x = UpSampling2D((2, 2))(x)
77
+
78
+ decoded = Conv2D(2, (3, 3), activation=None, padding='same')(x)
79
+
80
+ return Model(input_img, decoded)
81
+
82
+
83
+
84
+
85
+
86
+ if __name__ == "__main__":
87
+ # Define constants
88
+ HEIGHT, WIDTH = 512, 512
89
+ # Compile model
90
+ autoencoder = build_autoencoder()
91
+ autoencoder.summary()
92
+ autoencoder.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
notebooks/autoencoder-grayscale-to-color-landscape.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy==1.26.4
2
+ tensorflow==2.18.0
3
+ opencv-python==4.11.0.86
4
+ scikit-image==0.25.2
5
+ matplotlib==3.7.2