danhtran2mind committed on
Commit
3ebaa82
·
verified ·
1 Parent(s): 0a220d7

Update models/autoencoder_gray2color.py

Browse files
Files changed (1) hide show
  1. models/autoencoder_gray2color.py +86 -91
models/autoencoder_gray2color.py CHANGED
@@ -1,92 +1,87 @@
1
-
2
- import tensorflow as tf
3
- from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Add, Concatenate, Multiply
4
- from tensorflow.keras.models import Model
5
- from tensorflow.keras.optimizers import Adam
6
-
7
- # Spatial Attention Layer
8
- # Define SpatialAttention layer
9
- class SpatialAttention(tf.keras.layers.Layer):
10
- def __init__(self, kernel_size=7, **kwargs):
11
- super(SpatialAttention, self).__init__(**kwargs)
12
- self.kernel_size = kernel_size
13
- self.conv = Conv2D(filters=1, kernel_size=kernel_size, padding='same', activation='sigmoid')
14
-
15
- def call(self, inputs):
16
- avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
17
- max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
18
- concat = Concatenate()([avg_pool, max_pool])
19
- attention = self.conv(concat)
20
- return Multiply()([inputs, attention])
21
-
22
- def get_config(self):
23
- config = super(SpatialAttention, self).get_config()
24
- config.update({'kernel_size': self.kernel_size})
25
- return config
26
-
27
- # Build Autoencoder
28
- def build_autoencoder(height, width,):
29
- input_img = Input(shape=(height, width, 1))
30
-
31
- # Encoder
32
- x = Conv2D(96, (3, 3), activation='relu', padding='same')(input_img)
33
- x = BatchNormalization()(x)
34
- x = SpatialAttention()(x)
35
- x = MaxPooling2D((2, 2), padding='same')(x)
36
-
37
- # Residual Block 1
38
- residual = Conv2D(192, (1, 1), padding='same')(x)
39
- x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
40
- x = BatchNormalization()(x)
41
- x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
42
- x = BatchNormalization()(x)
43
- x = Add()([x, residual])
44
- x = SpatialAttention()(x)
45
- x = MaxPooling2D((2, 2), padding='same')(x)
46
-
47
- # Residual Block 2
48
- residual = Conv2D(384, (1, 1), padding='same')(x)
49
- x = Conv2D(384, (3, 3), activation='relu', padding='same')(x)
50
- x = BatchNormalization()(x)
51
- x = Conv2D(384, (3, 3), activation='relu', padding='same')(x)
52
- x = BatchNormalization()(x)
53
- x = Add()([x, residual])
54
- x = SpatialAttention()(x)
55
- encoded = MaxPooling2D((2, 2), padding='same')(x)
56
-
57
- # Decoder
58
- x = Conv2D(384, (3, 3), activation='relu', padding='same')(encoded)
59
- x = BatchNormalization()(x)
60
- x = SpatialAttention()(x)
61
- x = UpSampling2D((2, 2))(x)
62
-
63
- # Residual Block 3
64
- residual = Conv2D(192, (1, 1), padding='same')(x)
65
- x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
66
- x = BatchNormalization()(x)
67
- x = Conv2D(192, (3, 3), activation='relu', padding='same')(x)
68
- x = BatchNormalization()(x)
69
- x = Add()([x, residual])
70
- x = SpatialAttention()(x)
71
- x = UpSampling2D((2, 2))(x)
72
-
73
- x = Conv2D(96, (3, 3), activation='relu', padding='same')(x)
74
- x = BatchNormalization()(x)
75
- x = SpatialAttention()(x)
76
- x = UpSampling2D((2, 2))(x)
77
-
78
- decoded = Conv2D(2, (3, 3), activation=None, padding='same')(x)
79
-
80
- return Model(input_img, decoded)
81
-
82
-
83
-
84
-
85
-
86
- if __name__ == "__main__":
87
- # Define constants
88
- HEIGHT, WIDTH = 512, 512
89
- # Compile model
90
- autoencoder = build_autoencoder()
91
- autoencoder.summary()
92
  autoencoder.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Add, Concatenate, Multiply
3
+ from tensorflow.keras.models import Model
4
+ from tensorflow.keras.optimizers import Adam
5
+
6
# Pin the global Keras dtype policy to float32 so that layers created
# after this point compute and store their variables in full precision
# (i.e. mixed precision is explicitly not used).
# NOTE(review): this runs at import time and affects the whole process,
# not only this module — confirm that is intended.
tf.keras.mixed_precision.set_global_policy('float32')
8
+
9
+ # Spatial Attention Layer
10
class SpatialAttention(tf.keras.layers.Layer):
    """Re-weight a feature map with a learned per-pixel attention mask.

    The channel-wise mean and channel-wise max of the input are stacked
    into a two-channel map, passed through a single-filter convolution
    with a sigmoid activation, and the resulting one-channel mask is
    multiplied into the input (broadcast over the channel axis).

    Args:
        kernel_size: Window size of the convolution that produces the mask.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        # Single output filter + sigmoid: one mask value in (0, 1) per pixel.
        self.conv = Conv2D(
            filters=1,
            kernel_size=kernel_size,
            padding='same',
            activation='sigmoid',
        )

    def call(self, inputs):
        """Return ``inputs`` scaled element-wise by the attention mask."""
        avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
        max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
        # Use plain tensor ops here instead of instantiating Concatenate()
        # and Multiply() layers, which would build throwaway layer objects
        # on every invocation of call().
        pooled = tf.concat([avg_pool, max_pool], axis=-1)
        attention = self.conv(pooled)
        return inputs * attention

    def get_config(self):
        """Return the layer config, including ``kernel_size``, for serialization."""
        config = super().get_config()
        config.update({'kernel_size': self.kernel_size})
        return config
27
+
28
+ # Build Autoencoder
29
def build_autoencoder(height, width):
    """Build the single-channel-in, two-channel-out convolutional autoencoder.

    The encoder halves the spatial resolution three times (2x2 max pooling
    with ``same`` padding) and the decoder doubles it three times (2x2
    upsampling). Because the poolings round up, the output only has the
    same spatial size as the input when ``height`` and ``width`` are
    multiples of 8.

    Args:
        height: Input image height in pixels.
        width: Input image width in pixels.

    Returns:
        An uncompiled ``Model`` mapping a ``(height, width, 1)`` input to a
        two-channel output with no final activation.
        NOTE(review): presumably the two predicted channels are the colour
        components for a grayscale input — confirm against the training code.
    """
    input_img = Input(shape=(height, width, 1))

    # Encoder
    x = _conv_bn_attention(input_img, 96)
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = _residual_attention_block(x, 192)  # Residual Block 1
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = _residual_attention_block(x, 384)  # Residual Block 2
    encoded = MaxPooling2D((2, 2), padding='same')(x)

    # Decoder
    x = _conv_bn_attention(encoded, 384)
    x = UpSampling2D((2, 2))(x)
    x = _residual_attention_block(x, 192)  # Residual Block 3
    x = UpSampling2D((2, 2))(x)
    x = _conv_bn_attention(x, 96)
    x = UpSampling2D((2, 2))(x)

    # Linear output head: two channels, no activation.
    decoded = Conv2D(2, (3, 3), activation=None, padding='same')(x)

    return Model(input_img, decoded)


def _conv_bn_attention(x, filters):
    """Apply a 3x3 ReLU convolution, batch normalization, then spatial attention."""
    x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    return SpatialAttention()(x)


def _residual_attention_block(x, filters):
    """Apply two 3x3 conv+BN stages with a 1x1-projected skip, then spatial attention."""
    # The 1x1 projection is created first so that layer creation order (and
    # therefore Keras' auto-generated layer names) matches the unrolled form.
    residual = Conv2D(filters, (1, 1), padding='same')(x)
    for _ in range(2):
        x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
    x = Add()([x, residual])
    return SpatialAttention()(x)
80
+
81
if __name__ == "__main__":
    # Script entry point: build the model at a fixed 512x512 resolution,
    # print its layer summary, then attach the optimizer and loss.
    HEIGHT, WIDTH = 512, 512

    autoencoder = build_autoencoder(HEIGHT, WIDTH)
    autoencoder.summary()

    autoencoder.compile(
        optimizer=Adam(learning_rate=7e-5),
        loss=tf.keras.losses.MeanSquaredError(),
    )