Update models/unet_gray2color.py
Browse files- models/unet_gray2color.py +17 -45
models/unet_gray2color.py
CHANGED
@@ -1,13 +1,12 @@
|
|
1 |
import numpy as np
|
2 |
import tensorflow as tf
|
3 |
from tensorflow.keras.layers import (
|
4 |
-
Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Concatenate,
|
5 |
-
|
6 |
)
|
7 |
from tensorflow.keras.models import Model
|
8 |
from tensorflow.keras.optimizers import Adam
|
9 |
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
|
10 |
-
from tensorflow.keras.mixed_precision import set_global_policy
|
11 |
import cv2
|
12 |
import glob
|
13 |
import os
|
@@ -15,6 +14,9 @@ from skimage.color import rgb2lab, lab2rgb
|
|
15 |
from skimage.metrics import peak_signal_noise_ratio
|
16 |
import matplotlib.pyplot as plt
|
17 |
|
|
|
|
|
|
|
18 |
# Custom self-attention layer with serialization support
|
19 |
@tf.keras.utils.register_keras_serializable()
|
20 |
class SelfAttentionLayer(Layer):
|
@@ -27,7 +29,6 @@ class SelfAttentionLayer(Layer):
|
|
27 |
|
28 |
def build(self, input_shape):
|
29 |
# input_shape: (batch_size, height, width, channels)
|
30 |
-
# For self-attention, query, key, and value have the same shape
|
31 |
batch_size, height, width, channels = input_shape
|
32 |
attention_shape = (batch_size, height * width, channels) # Shape after reshape
|
33 |
# Build MultiHeadAttention with query, key, and value shapes
|
@@ -40,7 +41,7 @@ class SelfAttentionLayer(Layer):
|
|
40 |
b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
|
41 |
attention_input = tf.reshape(x, [b, h * w, c])
|
42 |
attention_output = self.mha(attention_input, attention_input)
|
43 |
-
# Cast attention_output to match x's dtype
|
44 |
attention_output = tf.cast(attention_output, dtype=x.dtype)
|
45 |
attention_output = tf.reshape(attention_output, [b, h, w, c])
|
46 |
return self.ln(x + attention_output)
|
@@ -55,59 +56,35 @@ class SelfAttentionLayer(Layer):
|
|
55 |
'key_dim': self.key_dim
|
56 |
})
|
57 |
return config
|
58 |
-
|
59 |
-
# class SelfAttentionLayer(Layer):
|
60 |
-
# def __init__(self, num_heads, key_dim, **kwargs):
|
61 |
-
# super(SelfAttentionLayer, self).__init__(**kwargs)
|
62 |
-
# self.num_heads = num_heads
|
63 |
-
# self.key_dim = key_dim
|
64 |
-
# self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
|
65 |
-
# self.ln = LayerNormalization()
|
66 |
-
|
67 |
-
# def call(self, x):
|
68 |
-
# b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
|
69 |
-
# attention_input = tf.reshape(x, [b, h * w, c])
|
70 |
-
# attention_output = self.mha(attention_input, attention_input)
|
71 |
-
# attention_output = tf.reshape(attention_output, [b, h, w, c])
|
72 |
-
# return self.ln(x + attention_output)
|
73 |
-
|
74 |
-
# def get_config(self):
|
75 |
-
# config = super(SelfAttentionLayer, self).get_config()
|
76 |
-
# config.update({
|
77 |
-
# 'num_heads': self.num_heads,
|
78 |
-
# 'key_dim': self.key_dim
|
79 |
-
# })
|
80 |
-
# return config
|
81 |
|
82 |
def attention_unet_model(input_shape=(256, 256, 1)):
|
83 |
inputs = Input(input_shape)
|
84 |
-
|
85 |
# Encoder with reduced filters
|
86 |
c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
|
87 |
c1 = BatchNormalization()(c1)
|
88 |
c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
|
89 |
c1 = BatchNormalization()(c1)
|
90 |
p1 = MaxPooling2D((2, 2))(c1)
|
91 |
-
|
92 |
c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(p1)
|
93 |
c2 = BatchNormalization()(c2)
|
94 |
c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(c2)
|
95 |
c2 = BatchNormalization()(c2)
|
96 |
p2 = MaxPooling2D((2, 2))(c2)
|
97 |
-
|
98 |
c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(p2)
|
99 |
c3 = BatchNormalization()(c3)
|
100 |
c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(c3)
|
101 |
c3 = BatchNormalization()(c3)
|
102 |
p3 = MaxPooling2D((2, 2))(c3)
|
103 |
-
|
104 |
# Bottleneck with reduced filters and attention
|
105 |
c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(p3)
|
106 |
c4 = BatchNormalization()(c4)
|
107 |
c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)
|
108 |
c4 = BatchNormalization()(c4)
|
109 |
-
c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4)
|
110 |
-
|
111 |
# Attention gate
|
112 |
def attention_gate(g, s, num_filters):
|
113 |
g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
|
@@ -116,7 +93,7 @@ def attention_unet_model(input_shape=(256, 256, 1)):
|
|
116 |
attn = tf.keras.layers.Activation('relu')(attn)
|
117 |
attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
|
118 |
return s * attn
|
119 |
-
|
120 |
# Decoder with reduced filters
|
121 |
u5 = UpSampling2D((2, 2))(c4)
|
122 |
a5 = attention_gate(u5, c3, 64)
|
@@ -125,7 +102,7 @@ def attention_unet_model(input_shape=(256, 256, 1)):
|
|
125 |
c5 = BatchNormalization()(c5)
|
126 |
c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(c5)
|
127 |
c5 = BatchNormalization()(c5)
|
128 |
-
|
129 |
u6 = UpSampling2D((2, 2))(c5)
|
130 |
a6 = attention_gate(u6, c2, 32)
|
131 |
u6 = Concatenate()([u6, a6])
|
@@ -133,7 +110,7 @@ def attention_unet_model(input_shape=(256, 256, 1)):
|
|
133 |
c6 = BatchNormalization()(c6)
|
134 |
c6 = Conv2D(32, (3, 3), activation='relu', padding='same')(c6)
|
135 |
c6 = BatchNormalization()(c6)
|
136 |
-
|
137 |
u7 = UpSampling2D((2, 2))(c6)
|
138 |
a7 = attention_gate(u7, c1, 16)
|
139 |
u7 = Concatenate()([u7, a7])
|
@@ -141,21 +118,16 @@ def attention_unet_model(input_shape=(256, 256, 1)):
|
|
141 |
c7 = BatchNormalization()(c7)
|
142 |
c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(c7)
|
143 |
c7 = BatchNormalization()(c7)
|
144 |
-
|
145 |
# Output layer
|
146 |
outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)
|
147 |
-
|
148 |
model = Model(inputs, outputs)
|
149 |
return model
|
150 |
|
151 |
-
# # Instantiate and compile the model
|
152 |
-
# model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
|
153 |
-
# model.summary()
|
154 |
-
|
155 |
if __name__ == "__main__":
|
156 |
# Define constants
|
157 |
-
HEIGHT, WIDTH =
|
158 |
# Compile model
|
159 |
model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
|
160 |
model.summary()
|
161 |
-
model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
|
|
|
1 |
import numpy as np
|
2 |
import tensorflow as tf
|
3 |
from tensorflow.keras.layers import (
|
4 |
+
Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, BatchNormalization,
|
5 |
+
LayerNormalization, Dropout, MultiHeadAttention, Add, Reshape, Layer
|
6 |
)
|
7 |
from tensorflow.keras.models import Model
|
8 |
from tensorflow.keras.optimizers import Adam
|
9 |
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
|
|
|
10 |
import cv2
|
11 |
import glob
|
12 |
import os
|
|
|
14 |
from skimage.metrics import peak_signal_noise_ratio
|
15 |
import matplotlib.pyplot as plt
|
16 |
|
17 |
+
# Pin the global Keras dtype policy to float32 (i.e. no mixed precision) to avoid dtype mismatches
|
18 |
+
tf.keras.mixed_precision.set_global_policy('float32')
|
19 |
+
|
20 |
# Custom self-attention layer with serialization support
|
21 |
@tf.keras.utils.register_keras_serializable()
|
22 |
class SelfAttentionLayer(Layer):
|
|
|
29 |
|
30 |
def build(self, input_shape):
|
31 |
# input_shape: (batch_size, height, width, channels)
|
|
|
32 |
batch_size, height, width, channels = input_shape
|
33 |
attention_shape = (batch_size, height * width, channels) # Shape after reshape
|
34 |
# Build MultiHeadAttention with query, key, and value shapes
|
|
|
41 |
b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
|
42 |
attention_input = tf.reshape(x, [b, h * w, c])
|
43 |
attention_output = self.mha(attention_input, attention_input)
|
44 |
+
# Cast attention_output to match x's dtype
|
45 |
attention_output = tf.cast(attention_output, dtype=x.dtype)
|
46 |
attention_output = tf.reshape(attention_output, [b, h, w, c])
|
47 |
return self.ln(x + attention_output)
|
|
|
56 |
'key_dim': self.key_dim
|
57 |
})
|
58 |
return config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
def attention_unet_model(input_shape=(256, 256, 1)):
|
61 |
inputs = Input(input_shape)
|
|
|
62 |
# Encoder with reduced filters
|
63 |
c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
|
64 |
c1 = BatchNormalization()(c1)
|
65 |
c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
|
66 |
c1 = BatchNormalization()(c1)
|
67 |
p1 = MaxPooling2D((2, 2))(c1)
|
68 |
+
|
69 |
c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(p1)
|
70 |
c2 = BatchNormalization()(c2)
|
71 |
c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(c2)
|
72 |
c2 = BatchNormalization()(c2)
|
73 |
p2 = MaxPooling2D((2, 2))(c2)
|
74 |
+
|
75 |
c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(p2)
|
76 |
c3 = BatchNormalization()(c3)
|
77 |
c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(c3)
|
78 |
c3 = BatchNormalization()(c3)
|
79 |
p3 = MaxPooling2D((2, 2))(c3)
|
80 |
+
|
81 |
# Bottleneck with reduced filters and attention
|
82 |
c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(p3)
|
83 |
c4 = BatchNormalization()(c4)
|
84 |
c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)
|
85 |
c4 = BatchNormalization()(c4)
|
86 |
+
c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4)
|
87 |
+
|
88 |
# Attention gate
|
89 |
def attention_gate(g, s, num_filters):
|
90 |
g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
|
|
|
93 |
attn = tf.keras.layers.Activation('relu')(attn)
|
94 |
attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
|
95 |
return s * attn
|
96 |
+
|
97 |
# Decoder with reduced filters
|
98 |
u5 = UpSampling2D((2, 2))(c4)
|
99 |
a5 = attention_gate(u5, c3, 64)
|
|
|
102 |
c5 = BatchNormalization()(c5)
|
103 |
c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(c5)
|
104 |
c5 = BatchNormalization()(c5)
|
105 |
+
|
106 |
u6 = UpSampling2D((2, 2))(c5)
|
107 |
a6 = attention_gate(u6, c2, 32)
|
108 |
u6 = Concatenate()([u6, a6])
|
|
|
110 |
c6 = BatchNormalization()(c6)
|
111 |
c6 = Conv2D(32, (3, 3), activation='relu', padding='same')(c6)
|
112 |
c6 = BatchNormalization()(c6)
|
113 |
+
|
114 |
u7 = UpSampling2D((2, 2))(c6)
|
115 |
a7 = attention_gate(u7, c1, 16)
|
116 |
u7 = Concatenate()([u7, a7])
|
|
|
118 |
c7 = BatchNormalization()(c7)
|
119 |
c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(c7)
|
120 |
c7 = BatchNormalization()(c7)
|
121 |
+
|
122 |
# Output layer
|
123 |
outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)
|
|
|
124 |
model = Model(inputs, outputs)
|
125 |
return model
|
126 |
|
|
|
|
|
|
|
|
|
127 |
if __name__ == "__main__":
|
128 |
# Define constants
|
129 |
+
HEIGHT, WIDTH = 256, 256 # Match function definition
|
130 |
# Build the model, print its summary, and compile it
|
131 |
model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
|
132 |
model.summary()
|
133 |
+
model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
|