danhtran2mind committed on
Commit
9de60a1
·
verified ·
1 Parent(s): 4e42141

Upload 2 files

Browse files
models/transformer-gray2color.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow.keras import layers, models, optimizers, callbacks
4
+ from tensorflow.keras.layers import (
5
+ Input, Dense, LayerNormalization, Dropout,
6
+ MultiHeadAttention, Add, Conv2D, Reshape, UpSampling2D
7
+ )
8
+ from tensorflow.keras.models import Model
9
+ from tensorflow.keras.optimizers import Adam
10
+ from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
11
+ from tensorflow.keras.mixed_precision import set_global_policy
12
+ import cv2
13
+ from skimage.color import rgb2lab, lab2rgb
14
+ from skimage.metrics import peak_signal_noise_ratio
15
+ import matplotlib.pyplot as plt
16
+ import glob
17
+ import os
18
+
19
+ # Define Transformer model
20
def transformer_model(input_shape=(1024, 1024, 1), patch_size=8,
                      d_model=32, num_heads=4, ff_dim=64,
                      num_layers=2, dropout_rate=0.1):
    """Build a patch-based Transformer that maps a 1-channel image to 2 channels.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, the patch embeddings are processed by
    ``num_layers`` self-attention / feed-forward blocks, and the result is
    projected to 2 channels per patch, upsampled bilinearly back by a factor
    of ``patch_size`` and smoothed by a 3x3 ``tanh`` convolution.

    Args:
        input_shape: ``(height, width, channels)`` of the input image.
        patch_size: Side length of each square patch (kernel size and stride
            of the embedding convolution).
        d_model: Embedding width of each patch token.
        num_heads: Number of attention heads; each head uses
            ``d_model // num_heads`` as its key dimension.
        ff_dim: Hidden width of the feed-forward sub-layer.
        num_layers: Number of encoder blocks.
        dropout_rate: Dropout rate applied after attention and after the
            feed-forward sub-layer.

    Returns:
        An uncompiled ``Model`` whose output has 2 channels in the ``tanh``
        range (presumably the a/b chroma channels of a Lab image, judging by
        the file name -- confirm against the training pipeline).

    Notes:
        * No positional encoding is added to the patch tokens.
        * Height and width are floor-divided by ``patch_size``; if they are
          not multiples of it, the output is spatially smaller than the input.
    """
    HEIGHT, WIDTH, _ = input_shape
    grid_rows = HEIGHT // patch_size
    grid_cols = WIDTH // patch_size
    num_patches = grid_rows * grid_cols

    def encoder_block(tokens):
        # Self-attention sub-layer: dropout, residual connection, then norm.
        attended = MultiHeadAttention(
            num_heads=num_heads,
            key_dim=d_model // num_heads,
        )(tokens, tokens)
        attended = Dropout(dropout_rate)(attended)
        tokens = Add()([tokens, attended])
        tokens = LayerNormalization(epsilon=1e-6)(tokens)

        # Position-wise feed-forward sub-layer with the same residual scheme.
        hidden = Dense(ff_dim, activation='relu')(tokens)
        hidden = Dense(d_model)(hidden)
        hidden = Dropout(dropout_rate)(hidden)
        tokens = Add()([tokens, hidden])
        return LayerNormalization(epsilon=1e-6)(tokens)

    inputs = Input(shape=input_shape)

    # Patch embedding: kernel size == stride, so patches do not overlap.
    patch_dims = (patch_size, patch_size)
    x = Conv2D(d_model, patch_dims, strides=patch_dims, padding='valid')(inputs)
    x = Reshape((num_patches, d_model))(x)

    # Encoder stack.
    for _layer_index in range(num_layers):
        x = encoder_block(x)

    # Decoder: 2 values per patch, laid back out on the patch grid, then
    # upsampled to pixel resolution and refined by a 3x3 convolution.
    x = Dense(2)(x)
    x = Reshape((grid_rows, grid_cols, 2))(x)
    x = UpSampling2D(size=patch_dims, interpolation='bilinear')(x)
    outputs = Conv2D(2, (3, 3), activation='tanh', padding='same')(x)

    return Model(inputs, outputs)
52
+
53
if __name__ == "__main__":
    # Input resolution used for this standalone build of the model.
    HEIGHT, WIDTH = 1024, 1024

    # Build the network; with patch_size=8 this yields a 128 x 128 patch grid.
    model = transformer_model(
        input_shape=(HEIGHT, WIDTH, 1),
        patch_size=8,
        d_model=32,
        num_heads=4,
        ff_dim=64,
        num_layers=2,
    )
    model.summary()

    # Compile with Adam and a mean-squared-error loss.
    model.compile(
        optimizer=Adam(learning_rate=7e-5),
        loss=tf.keras.losses.MeanSquaredError(),
    )
models/unet-gray2color.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow.keras.layers import (
4
+ Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Concatenate,
5
+ BatchNormalization, LayerNormalization, Dropout, MultiHeadAttention, Add, Reshape
6
+ )
7
+ from tensorflow.keras.models import Model
8
+ from tensorflow.keras.optimizers import Adam
9
+ from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
10
+ from tensorflow.keras.mixed_precision import set_global_policy
11
+ import cv2
12
+ import glob
13
+ import os
14
+ from skimage.color import rgb2lab, lab2rgb
15
+ from skimage.metrics import peak_signal_noise_ratio
16
+ import matplotlib.pyplot as plt
17
+
18
+ # Custom self-attention layer with serialization support
19
@tf.keras.utils.register_keras_serializable()
class SelfAttentionLayer(tf.keras.layers.Layer):
    """Multi-head self-attention over the spatial positions of a feature map.

    The 4-D input ``(batch, height, width, channels)`` is flattened to a
    sequence of ``height * width`` tokens, passed through multi-head
    self-attention, reshaped back to the input layout, added to the input as
    a residual and layer-normalized.

    The base class is referenced as ``tf.keras.layers.Layer`` because a bare
    ``Layer`` name is not among this file's imports.

    Args:
        num_heads: Number of attention heads.
        key_dim: Size of each attention head for query and key.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.
    """

    def __init__(self, num_heads, key_dim, **kwargs):
        super().__init__(**kwargs)
        # Kept as attributes so get_config() can serialize them.
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
        self.ln = LayerNormalization()

    def call(self, x):
        """Apply self-attention with a residual connection and layer norm.

        Args:
            x: 4-D tensor; height, width and channels are read from the
                static shape, so they are assumed to be known at graph
                construction time. Only the batch size is taken dynamically.

        Returns:
            A tensor with the same shape as ``x``.
        """
        b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
        # Flatten the spatial grid into a token sequence for attention.
        attention_input = tf.reshape(x, [b, h * w, c])
        attention_output = self.mha(attention_input, attention_input)
        # Restore the spatial layout before the residual addition.
        attention_output = tf.reshape(attention_output, [b, h, w, c])
        return self.ln(x + attention_output)

    def get_config(self):
        """Return the layer config including ``num_heads`` and ``key_dim``."""
        config = super().get_config()
        config.update({
            'num_heads': self.num_heads,
            'key_dim': self.key_dim
        })
        return config
42
+
43
def attention_unet_model(input_shape=(256, 256, 1)):
    """Build a three-level U-Net with attention gates and bottleneck attention.

    The encoder applies double 3x3 convolutions with 16, 32 and 64 filters,
    each followed by 2x2 max pooling. The bottleneck uses 128 filters and a
    ``SelfAttentionLayer``. Each decoder stage upsamples by 2, gates the
    matching encoder feature map with an attention gate, concatenates the two
    and applies another double convolution. A final 1x1 ``tanh`` convolution
    produces 2 output channels.

    Args:
        input_shape: ``(height, width, channels)`` of the input image. Height
            and width are pooled three times by 2, so they presumably need to
            be multiples of 8 for the skip connections to line up.

    Returns:
        An uncompiled ``Model`` with a 2-channel output at input resolution.
    """

    def double_conv(tensor, filters):
        # Two 3x3 ReLU convolutions, each followed by batch normalization.
        for _repeat in range(2):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
        return tensor

    def attention_gate(g, s, num_filters):
        # Project the gating signal and the skip features to a common width,
        # combine them, and squash to a single-channel mask in (0, 1).
        gate_proj = Conv2D(num_filters, (1, 1), padding='same')(g)
        skip_proj = Conv2D(num_filters, (1, 1), padding='same')(s)
        combined = tf.keras.layers.add([gate_proj, skip_proj])
        combined = tf.keras.layers.Activation('relu')(combined)
        mask = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(combined)
        # The mask broadcasts over the channels of the skip features.
        return s * mask

    def decoder_stage(tensor, skip, filters):
        # Upsample, gate the skip connection, merge, then refine.
        upsampled = UpSampling2D((2, 2))(tensor)
        gated_skip = attention_gate(upsampled, skip, filters)
        merged = Concatenate()([upsampled, gated_skip])
        return double_conv(merged, filters)

    inputs = Input(input_shape)

    # Encoder: remember each pre-pooling feature map for the skip connections.
    skip_features = []
    x = inputs
    for filters in (16, 32, 64):
        x = double_conv(x, filters)
        skip_features.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck with self-attention over the spatial positions.
    x = double_conv(x, 128)
    x = SelfAttentionLayer(num_heads=2, key_dim=32)(x)

    # Decoder: deepest skip connection first.
    for filters, skip in zip((64, 32, 16), reversed(skip_features)):
        x = decoder_stage(x, skip, filters)

    # Output layer
    outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(x)

    return Model(inputs, outputs)
111
+
112
+ # Instantiate and compile the model
113
+ model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
114
+ model.summary()
115
+
116
+ if __name__ == "__main__":
117
+ # Define constants
118
+ HEIGHT, WIDTH = 1024, 1024
119
+ # Compile model
120
+ model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
121
+ model.summary()
122
+ model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())