Update models/unet_gray2color.py
Browse files- models/unet_gray2color.py +4 -16
models/unet_gray2color.py
CHANGED
@@ -14,10 +14,9 @@ from skimage.color import rgb2lab, lab2rgb
|
|
14 |
from skimage.metrics import peak_signal_noise_ratio
|
15 |
import matplotlib.pyplot as plt
|
16 |
|
17 |
-
# Disable mixed precision
|
18 |
tf.keras.mixed_precision.set_global_policy('float32')
|
19 |
|
20 |
-
# Custom self-attention layer with serialization support
|
21 |
@tf.keras.utils.register_keras_serializable()
|
22 |
class SelfAttentionLayer(Layer):
|
23 |
def __init__(self, num_heads, key_dim, **kwargs):
|
@@ -28,12 +27,9 @@ class SelfAttentionLayer(Layer):
|
|
28 |
self.ln = LayerNormalization(epsilon=1e-6)
|
29 |
|
30 |
def build(self, input_shape):
|
31 |
-
# input_shape: (batch_size, height, width, channels)
|
32 |
batch_size, height, width, channels = input_shape
|
33 |
-
attention_shape = (batch_size, height * width, channels)
|
34 |
-
# Build MultiHeadAttention with query, key, and value shapes
|
35 |
self.mha.build(query_shape=attention_shape, value_shape=attention_shape)
|
36 |
-
# Build LayerNormalization with the original input shape
|
37 |
self.ln.build(input_shape)
|
38 |
super(SelfAttentionLayer, self).build(input_shape)
|
39 |
|
@@ -41,7 +37,6 @@ class SelfAttentionLayer(Layer):
|
|
41 |
b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
|
42 |
attention_input = tf.reshape(x, [b, h * w, c])
|
43 |
attention_output = self.mha(attention_input, attention_input)
|
44 |
-
# Cast attention_output to match x's dtype
|
45 |
attention_output = tf.cast(attention_output, dtype=x.dtype)
|
46 |
attention_output = tf.reshape(attention_output, [b, h, w, c])
|
47 |
return self.ln(x + attention_output)
|
@@ -57,9 +52,8 @@ class SelfAttentionLayer(Layer):
|
|
57 |
})
|
58 |
return config
|
59 |
|
60 |
-
def attention_unet_model(input_shape=(
|
61 |
inputs = Input(input_shape)
|
62 |
-
# Encoder with reduced filters
|
63 |
c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
|
64 |
c1 = BatchNormalization()(c1)
|
65 |
c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
|
@@ -78,14 +72,12 @@ def attention_unet_model(input_shape=(256, 256, 1)):
|
|
78 |
c3 = BatchNormalization()(c3)
|
79 |
p3 = MaxPooling2D((2, 2))(c3)
|
80 |
|
81 |
-
# Bottleneck with reduced filters and attention
|
82 |
c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(p3)
|
83 |
c4 = BatchNormalization()(c4)
|
84 |
c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)
|
85 |
c4 = BatchNormalization()(c4)
|
86 |
c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4)
|
87 |
|
88 |
-
# Attention gate
|
89 |
def attention_gate(g, s, num_filters):
|
90 |
g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
|
91 |
s_conv = Conv2D(num_filters, (1, 1), padding='same')(s)
|
@@ -94,7 +86,6 @@ def attention_unet_model(input_shape=(256, 256, 1)):
|
|
94 |
attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
|
95 |
return s * attn
|
96 |
|
97 |
-
# Decoder with reduced filters
|
98 |
u5 = UpSampling2D((2, 2))(c4)
|
99 |
a5 = attention_gate(u5, c3, 64)
|
100 |
u5 = Concatenate()([u5, a5])
|
@@ -119,15 +110,12 @@ def attention_unet_model(input_shape=(256, 256, 1)):
|
|
119 |
c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(c7)
|
120 |
c7 = BatchNormalization()(c7)
|
121 |
|
122 |
-
# Output layer
|
123 |
outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)
|
124 |
model = Model(inputs, outputs)
|
125 |
return model
|
126 |
|
127 |
if __name__ == "__main__":
|
128 |
-
|
129 |
-
HEIGHT, WIDTH = 256, 256 # Match function definition
|
130 |
-
# Compile model
|
131 |
model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
|
132 |
model.summary()
|
133 |
model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
|
|
|
14 |
from skimage.metrics import peak_signal_noise_ratio
|
15 |
import matplotlib.pyplot as plt
|
16 |
|
17 |
+
# Disable mixed precision
|
18 |
tf.keras.mixed_precision.set_global_policy('float32')
|
19 |
|
|
|
20 |
@tf.keras.utils.register_keras_serializable()
|
21 |
class SelfAttentionLayer(Layer):
|
22 |
def __init__(self, num_heads, key_dim, **kwargs):
|
|
|
27 |
self.ln = LayerNormalization(epsilon=1e-6)
|
28 |
|
29 |
def build(self, input_shape):
|
|
|
30 |
batch_size, height, width, channels = input_shape
|
31 |
+
attention_shape = (batch_size, height * width, channels)
|
|
|
32 |
self.mha.build(query_shape=attention_shape, value_shape=attention_shape)
|
|
|
33 |
self.ln.build(input_shape)
|
34 |
super(SelfAttentionLayer, self).build(input_shape)
|
35 |
|
|
|
37 |
b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
|
38 |
attention_input = tf.reshape(x, [b, h * w, c])
|
39 |
attention_output = self.mha(attention_input, attention_input)
|
|
|
40 |
attention_output = tf.cast(attention_output, dtype=x.dtype)
|
41 |
attention_output = tf.reshape(attention_output, [b, h, w, c])
|
42 |
return self.ln(x + attention_output)
|
|
|
52 |
})
|
53 |
return config
|
54 |
|
55 |
+
def attention_unet_model(input_shape=(1024, 1024, 1)):
|
56 |
inputs = Input(input_shape)
|
|
|
57 |
c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
|
58 |
c1 = BatchNormalization()(c1)
|
59 |
c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
|
|
|
72 |
c3 = BatchNormalization()(c3)
|
73 |
p3 = MaxPooling2D((2, 2))(c3)
|
74 |
|
|
|
75 |
c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(p3)
|
76 |
c4 = BatchNormalization()(c4)
|
77 |
c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)
|
78 |
c4 = BatchNormalization()(c4)
|
79 |
c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4)
|
80 |
|
|
|
81 |
def attention_gate(g, s, num_filters):
|
82 |
g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
|
83 |
s_conv = Conv2D(num_filters, (1, 1), padding='same')(s)
|
|
|
86 |
attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
|
87 |
return s * attn
|
88 |
|
|
|
89 |
u5 = UpSampling2D((2, 2))(c4)
|
90 |
a5 = attention_gate(u5, c3, 64)
|
91 |
u5 = Concatenate()([u5, a5])
|
|
|
110 |
c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(c7)
|
111 |
c7 = BatchNormalization()(c7)
|
112 |
|
|
|
113 |
outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)
|
114 |
model = Model(inputs, outputs)
|
115 |
return model
|
116 |
|
117 |
if __name__ == "__main__":
|
118 |
+
HEIGHT, WIDTH = 1024, 1024
|
|
|
|
|
119 |
model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
|
120 |
model.summary()
|
121 |
model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
|