Upload 2 files
Browse files
- models/transformer-gray2color.py +62 -0
- models/unet-gray2color.py +122 -0
models/transformer-gray2color.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
from tensorflow.keras import layers, models, optimizers, callbacks
|
4 |
+
from tensorflow.keras.layers import (
|
5 |
+
Input, Dense, LayerNormalization, Dropout,
|
6 |
+
MultiHeadAttention, Add, Conv2D, Reshape, UpSampling2D
|
7 |
+
)
|
8 |
+
from tensorflow.keras.models import Model
|
9 |
+
from tensorflow.keras.optimizers import Adam
|
10 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
|
11 |
+
from tensorflow.keras.mixed_precision import set_global_policy
|
12 |
+
import cv2
|
13 |
+
from skimage.color import rgb2lab, lab2rgb
|
14 |
+
from skimage.metrics import peak_signal_noise_ratio
|
15 |
+
import matplotlib.pyplot as plt
|
16 |
+
import glob
|
17 |
+
import os
|
18 |
+
|
19 |
+
# Define Transformer model
|
20 |
+
def transformer_model(input_shape=(1024, 1024, 1), patch_size=8,
                      d_model=32, num_heads=4, ff_dim=64,
                      num_layers=2, dropout_rate=0.1):
    """Build a patch-based Transformer mapping an image to a 2-channel map.

    A strided convolution cuts the input into non-overlapping
    ``patch_size`` x ``patch_size`` patches and embeds each one as a
    ``d_model``-dimensional token. The tokens pass through ``num_layers``
    encoder layers (self-attention and a feed-forward sub-block, each with a
    residual connection followed by layer normalization). A dense projection
    to 2 channels, bilinear upsampling by ``patch_size`` and a final 3x3
    ``tanh`` convolution bring the result back to the input resolution.

    NOTE(review): the 2 output channels are presumably the Lab a/b chroma
    channels given the file name -- confirm against the training pipeline.
    NOTE(review): no positional embedding is added to the patch tokens.

    Args:
        input_shape: ``(height, width, channels)`` of the input image.
        patch_size: Side length of the square patches; also the stride of
            the patch-embedding convolution and the upsampling factor.
        d_model: Embedding size of each patch token.
        num_heads: Number of attention heads; ``key_dim`` per head is
            ``d_model // num_heads``.
        ff_dim: Hidden width of the feed-forward sub-block.
        num_layers: Number of encoder layers.
        dropout_rate: Dropout applied to the attention and feed-forward
            outputs before each residual addition.

    Returns:
        An uncompiled Keras ``Model`` with output shape
        ``(height, width, 2)``.

    Raises:
        ValueError: If height or width is not a multiple of ``patch_size``.
            Without this check the valid-padded patch convolution floors the
            patch grid and the output would be smaller than the input.
    """
    HEIGHT, WIDTH, _ = input_shape

    # Fail early: a partial patch would be dropped silently and the decoder
    # would reconstruct an image that no longer matches the input size.
    if HEIGHT % patch_size or WIDTH % patch_size:
        raise ValueError(
            f"input height and width must be multiples of patch_size="
            f"{patch_size}, got {HEIGHT}x{WIDTH}"
        )

    grid_h = HEIGHT // patch_size
    grid_w = WIDTH // patch_size
    num_patches = grid_h * grid_w

    inputs = Input(shape=input_shape)

    # Patch extraction: one conv step per patch, then flatten the grid
    # into a token sequence of length num_patches.
    x = Conv2D(d_model, (patch_size, patch_size),
               strides=(patch_size, patch_size), padding='valid')(inputs)
    x = Reshape((num_patches, d_model))(x)

    # Transformer layers (residual add first, normalization afterwards)
    for _ in range(num_layers):
        attn_output = MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model // num_heads)(x, x)
        attn_output = Dropout(dropout_rate)(attn_output)
        x = Add()([x, attn_output])
        x = LayerNormalization(epsilon=1e-6)(x)

        ff_output = Dense(ff_dim, activation='relu')(x)
        ff_output = Dense(d_model)(ff_output)
        ff_output = Dropout(dropout_rate)(ff_output)
        x = Add()([x, ff_output])
        x = LayerNormalization(epsilon=1e-6)(x)

    # Decoder: project each token to 2 channels, restore the patch grid,
    # and upsample back to the input resolution.
    x = Dense(2)(x)
    x = Reshape((grid_h, grid_w, 2))(x)
    x = UpSampling2D(size=(patch_size, patch_size), interpolation='bilinear')(x)
    outputs = Conv2D(2, (3, 3), activation='tanh', padding='same')(x)

    return Model(inputs, outputs)
|
52 |
+
|
53 |
+
if __name__ == "__main__":
    # Resolution of the single-channel input image.
    HEIGHT, WIDTH = 1024, 1024

    # Build the network with every hyper-parameter spelled out explicitly.
    model = transformer_model(
        input_shape=(HEIGHT, WIDTH, 1),
        patch_size=8,
        d_model=32,
        num_heads=4,
        ff_dim=64,
        num_layers=2,
    )
    model.summary()

    # Attach the optimizer and the regression loss.
    adam = Adam(learning_rate=7e-5)
    mse_loss = tf.keras.losses.MeanSquaredError()
    model.compile(optimizer=adam, loss=mse_loss)
|
models/unet-gray2color.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
from tensorflow.keras.layers import (
|
4 |
+
Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Concatenate,
|
5 |
+
BatchNormalization, LayerNormalization, Dropout, MultiHeadAttention, Add, Reshape
|
6 |
+
)
|
7 |
+
from tensorflow.keras.models import Model
|
8 |
+
from tensorflow.keras.optimizers import Adam
|
9 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
|
10 |
+
from tensorflow.keras.mixed_precision import set_global_policy
|
11 |
+
import cv2
|
12 |
+
import glob
|
13 |
+
import os
|
14 |
+
from skimage.color import rgb2lab, lab2rgb
|
15 |
+
from skimage.metrics import peak_signal_noise_ratio
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
|
18 |
+
# Custom self-attention layer with serialization support
|
19 |
+
@tf.keras.utils.register_keras_serializable()
class SelfAttentionLayer(tf.keras.layers.Layer):
    """Multi-head self-attention over the spatial positions of a feature map.

    The 4-D input ``(batch, h, w, c)`` is flattened to a sequence of
    ``h * w`` tokens, passed through multi-head attention with itself as
    query and value, reshaped back to 4-D, added to the input (residual)
    and layer-normalized.

    The base class is referenced as ``tf.keras.layers.Layer`` because the
    bare name ``Layer`` is not among this file's imports.

    Args:
        num_heads: Number of attention heads.
        key_dim: Size of each attention head for query and key.
        **kwargs: Forwarded to the Keras ``Layer`` constructor.
    """

    def __init__(self, num_heads, key_dim, **kwargs):
        super().__init__(**kwargs)
        # Kept as attributes so get_config() can round-trip them.
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
        self.ln = LayerNormalization()

    def call(self, x):
        """Apply self-attention with a residual connection and layer norm."""
        # Batch size is read dynamically; h, w and c come from the static
        # shape and are multiplied below, so they must be known (not None).
        b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
        attention_input = tf.reshape(x, [b, h * w, c])
        attention_output = self.mha(attention_input, attention_input)
        attention_output = tf.reshape(attention_output, [b, h, w, c])
        return self.ln(x + attention_output)

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update({
            'num_heads': self.num_heads,
            'key_dim': self.key_dim
        })
        return config
|
42 |
+
|
43 |
+
def _double_conv(x, num_filters):
    """Apply two 3x3 ReLU convolutions, each followed by batch normalization."""
    for _ in range(2):
        x = Conv2D(num_filters, (3, 3), activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
    return x


def _attention_gate(g, s, num_filters):
    """Scale skip features ``s`` by a sigmoid mask computed from ``g`` and ``s``."""
    g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
    s_conv = Conv2D(num_filters, (1, 1), padding='same')(s)
    attn = tf.keras.layers.add([g_conv, s_conv])
    attn = tf.keras.layers.Activation('relu')(attn)
    # Single-channel mask in (0, 1), broadcast over the channels of ``s``.
    attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
    return s * attn


def _decoder_stage(x, skip, num_filters):
    """Upsample ``x`` by 2, merge it with the gated skip, then double-conv."""
    up = UpSampling2D((2, 2))(x)
    gated = _attention_gate(up, skip, num_filters)
    merged = Concatenate()([up, gated])
    return _double_conv(merged, num_filters)


def attention_unet_model(input_shape=(256, 256, 1)):
    """Build a U-Net with a self-attention bottleneck and gated skip links.

    The encoder has three double-conv stages (16, 32, 64 filters), each
    followed by 2x2 max pooling. The bottleneck is a 128-filter double-conv
    followed by ``SelfAttentionLayer``. The decoder has three stages that
    upsample by 2, gate the matching encoder features with an attention
    gate, concatenate and double-conv (64, 32, 16 filters). A 1x1 ``tanh``
    convolution produces the 2-channel output.

    Args:
        input_shape: ``(height, width, channels)`` of the input image.

    Returns:
        An uncompiled Keras ``Model`` with output shape
        ``(height, width, 2)``.

    Raises:
        ValueError: If height or width is not a multiple of 8. Three 2x
            poolings followed by three 2x upsamplings only line up with the
            skip connections when both dimensions are divisible by 8.
    """
    # Validate before creating any layer so the caller gets a clear message
    # instead of a shape-mismatch error from inside the decoder.
    for label, dim in (("height", input_shape[0]), ("width", input_shape[1])):
        if dim is not None and dim % 8:
            raise ValueError(
                f"input {label} must be a multiple of 8, got {dim}"
            )

    inputs = Input(input_shape)

    # Encoder with reduced filters
    c1 = _double_conv(inputs, 16)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = _double_conv(p1, 32)
    p2 = MaxPooling2D((2, 2))(c2)

    c3 = _double_conv(p2, 64)
    p3 = MaxPooling2D((2, 2))(c3)

    # Bottleneck with reduced filters and attention
    c4 = _double_conv(p3, 128)
    c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4)  # Reduced heads and key_dim

    # Decoder with reduced filters
    c5 = _decoder_stage(c4, c3, 64)
    c6 = _decoder_stage(c5, c2, 32)
    c7 = _decoder_stage(c6, c1, 16)

    # Output layer
    outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)

    model = Model(inputs, outputs)
    return model
|
111 |
+
|
112 |
+
# Build and compile the model only when run as a script. HEIGHT and WIDTH
# are defined before their first use, and the model is built exactly once,
# so importing this module has no side effects.
if __name__ == "__main__":
    # Define constants
    HEIGHT, WIDTH = 1024, 1024

    # Instantiate the model and print its summary
    model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
    model.summary()

    # Compile model
    model.compile(optimizer=Adam(learning_rate=7e-5),
                  loss=tf.keras.losses.MeanSquaredError())
|