danhtran2mind committed on
Commit
055cdae
·
verified ·
1 Parent(s): 3ebaa82

Update models/unet_gray2color.py

Browse files
Files changed (1) hide show
  1. models/unet_gray2color.py +17 -45
models/unet_gray2color.py CHANGED
@@ -1,13 +1,12 @@
1
  import numpy as np
2
  import tensorflow as tf
3
  from tensorflow.keras.layers import (
4
- Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Concatenate,
5
- BatchNormalization, LayerNormalization, Dropout, MultiHeadAttention, Add, Reshape, Layer
6
  )
7
  from tensorflow.keras.models import Model
8
  from tensorflow.keras.optimizers import Adam
9
  from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
10
- from tensorflow.keras.mixed_precision import set_global_policy
11
  import cv2
12
  import glob
13
  import os
@@ -15,6 +14,9 @@ from skimage.color import rgb2lab, lab2rgb
15
  from skimage.metrics import peak_signal_noise_ratio
16
  import matplotlib.pyplot as plt
17
 
 
 
 
18
  # Custom self-attention layer with serialization support
19
  @tf.keras.utils.register_keras_serializable()
20
  class SelfAttentionLayer(Layer):
@@ -27,7 +29,6 @@ class SelfAttentionLayer(Layer):
27
 
28
  def build(self, input_shape):
29
  # input_shape: (batch_size, height, width, channels)
30
- # For self-attention, query, key, and value have the same shape
31
  batch_size, height, width, channels = input_shape
32
  attention_shape = (batch_size, height * width, channels) # Shape after reshape
33
  # Build MultiHeadAttention with query, key, and value shapes
@@ -40,7 +41,7 @@ class SelfAttentionLayer(Layer):
40
  b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
41
  attention_input = tf.reshape(x, [b, h * w, c])
42
  attention_output = self.mha(attention_input, attention_input)
43
- # Cast attention_output to match x's dtype to avoid type mismatch
44
  attention_output = tf.cast(attention_output, dtype=x.dtype)
45
  attention_output = tf.reshape(attention_output, [b, h, w, c])
46
  return self.ln(x + attention_output)
@@ -55,59 +56,35 @@ class SelfAttentionLayer(Layer):
55
  'key_dim': self.key_dim
56
  })
57
  return config
58
-
59
- # class SelfAttentionLayer(Layer):
60
- # def __init__(self, num_heads, key_dim, **kwargs):
61
- # super(SelfAttentionLayer, self).__init__(**kwargs)
62
- # self.num_heads = num_heads
63
- # self.key_dim = key_dim
64
- # self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
65
- # self.ln = LayerNormalization()
66
-
67
- # def call(self, x):
68
- # b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
69
- # attention_input = tf.reshape(x, [b, h * w, c])
70
- # attention_output = self.mha(attention_input, attention_input)
71
- # attention_output = tf.reshape(attention_output, [b, h, w, c])
72
- # return self.ln(x + attention_output)
73
-
74
- # def get_config(self):
75
- # config = super(SelfAttentionLayer, self).get_config()
76
- # config.update({
77
- # 'num_heads': self.num_heads,
78
- # 'key_dim': self.key_dim
79
- # })
80
- # return config
81
 
82
  def attention_unet_model(input_shape=(256, 256, 1)):
83
  inputs = Input(input_shape)
84
-
85
  # Encoder with reduced filters
86
  c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
87
  c1 = BatchNormalization()(c1)
88
  c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
89
  c1 = BatchNormalization()(c1)
90
  p1 = MaxPooling2D((2, 2))(c1)
91
-
92
  c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(p1)
93
  c2 = BatchNormalization()(c2)
94
  c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(c2)
95
  c2 = BatchNormalization()(c2)
96
  p2 = MaxPooling2D((2, 2))(c2)
97
-
98
  c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(p2)
99
  c3 = BatchNormalization()(c3)
100
  c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(c3)
101
  c3 = BatchNormalization()(c3)
102
  p3 = MaxPooling2D((2, 2))(c3)
103
-
104
  # Bottleneck with reduced filters and attention
105
  c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(p3)
106
  c4 = BatchNormalization()(c4)
107
  c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)
108
  c4 = BatchNormalization()(c4)
109
- c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4) # Reduced heads and key_dim
110
-
111
  # Attention gate
112
  def attention_gate(g, s, num_filters):
113
  g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
@@ -116,7 +93,7 @@ def attention_unet_model(input_shape=(256, 256, 1)):
116
  attn = tf.keras.layers.Activation('relu')(attn)
117
  attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
118
  return s * attn
119
-
120
  # Decoder with reduced filters
121
  u5 = UpSampling2D((2, 2))(c4)
122
  a5 = attention_gate(u5, c3, 64)
@@ -125,7 +102,7 @@ def attention_unet_model(input_shape=(256, 256, 1)):
125
  c5 = BatchNormalization()(c5)
126
  c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(c5)
127
  c5 = BatchNormalization()(c5)
128
-
129
  u6 = UpSampling2D((2, 2))(c5)
130
  a6 = attention_gate(u6, c2, 32)
131
  u6 = Concatenate()([u6, a6])
@@ -133,7 +110,7 @@ def attention_unet_model(input_shape=(256, 256, 1)):
133
  c6 = BatchNormalization()(c6)
134
  c6 = Conv2D(32, (3, 3), activation='relu', padding='same')(c6)
135
  c6 = BatchNormalization()(c6)
136
-
137
  u7 = UpSampling2D((2, 2))(c6)
138
  a7 = attention_gate(u7, c1, 16)
139
  u7 = Concatenate()([u7, a7])
@@ -141,21 +118,16 @@ def attention_unet_model(input_shape=(256, 256, 1)):
141
  c7 = BatchNormalization()(c7)
142
  c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(c7)
143
  c7 = BatchNormalization()(c7)
144
-
145
  # Output layer
146
  outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)
147
-
148
  model = Model(inputs, outputs)
149
  return model
150
 
151
- # # Instantiate and compile the model
152
- # model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
153
- # model.summary()
154
-
155
  if __name__ == "__main__":
156
  # Define constants
157
- HEIGHT, WIDTH = 1024, 1024
158
  # Compile model
159
  model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
160
  model.summary()
161
- model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
 
1
  import numpy as np
2
  import tensorflow as tf
3
  from tensorflow.keras.layers import (
4
+ Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, BatchNormalization,
5
+ LayerNormalization, Dropout, MultiHeadAttention, Add, Reshape, Layer
6
  )
7
  from tensorflow.keras.models import Model
8
  from tensorflow.keras.optimizers import Adam
9
  from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
 
10
  import cv2
11
  import glob
12
  import os
 
14
  from skimage.metrics import peak_signal_noise_ratio
15
  import matplotlib.pyplot as plt
16
 
17
+ # Disable mixed precision to avoid dtype mismatches
18
+ tf.keras.mixed_precision.set_global_policy('float32')
19
+
20
  # Custom self-attention layer with serialization support
21
  @tf.keras.utils.register_keras_serializable()
22
  class SelfAttentionLayer(Layer):
 
29
 
30
  def build(self, input_shape):
31
  # input_shape: (batch_size, height, width, channels)
 
32
  batch_size, height, width, channels = input_shape
33
  attention_shape = (batch_size, height * width, channels) # Shape after reshape
34
  # Build MultiHeadAttention with query, key, and value shapes
 
41
  b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
42
  attention_input = tf.reshape(x, [b, h * w, c])
43
  attention_output = self.mha(attention_input, attention_input)
44
+ # Cast attention_output to match x's dtype
45
  attention_output = tf.cast(attention_output, dtype=x.dtype)
46
  attention_output = tf.reshape(attention_output, [b, h, w, c])
47
  return self.ln(x + attention_output)
 
56
  'key_dim': self.key_dim
57
  })
58
  return config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def attention_unet_model(input_shape=(256, 256, 1)):
61
  inputs = Input(input_shape)
 
62
  # Encoder with reduced filters
63
  c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
64
  c1 = BatchNormalization()(c1)
65
  c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
66
  c1 = BatchNormalization()(c1)
67
  p1 = MaxPooling2D((2, 2))(c1)
68
+
69
  c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(p1)
70
  c2 = BatchNormalization()(c2)
71
  c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(c2)
72
  c2 = BatchNormalization()(c2)
73
  p2 = MaxPooling2D((2, 2))(c2)
74
+
75
  c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(p2)
76
  c3 = BatchNormalization()(c3)
77
  c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(c3)
78
  c3 = BatchNormalization()(c3)
79
  p3 = MaxPooling2D((2, 2))(c3)
80
+
81
  # Bottleneck with reduced filters and attention
82
  c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(p3)
83
  c4 = BatchNormalization()(c4)
84
  c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)
85
  c4 = BatchNormalization()(c4)
86
+ c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4)
87
+
88
  # Attention gate
89
  def attention_gate(g, s, num_filters):
90
  g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
 
93
  attn = tf.keras.layers.Activation('relu')(attn)
94
  attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
95
  return s * attn
96
+
97
  # Decoder with reduced filters
98
  u5 = UpSampling2D((2, 2))(c4)
99
  a5 = attention_gate(u5, c3, 64)
 
102
  c5 = BatchNormalization()(c5)
103
  c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(c5)
104
  c5 = BatchNormalization()(c5)
105
+
106
  u6 = UpSampling2D((2, 2))(c5)
107
  a6 = attention_gate(u6, c2, 32)
108
  u6 = Concatenate()([u6, a6])
 
110
  c6 = BatchNormalization()(c6)
111
  c6 = Conv2D(32, (3, 3), activation='relu', padding='same')(c6)
112
  c6 = BatchNormalization()(c6)
113
+
114
  u7 = UpSampling2D((2, 2))(c6)
115
  a7 = attention_gate(u7, c1, 16)
116
  u7 = Concatenate()([u7, a7])
 
118
  c7 = BatchNormalization()(c7)
119
  c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(c7)
120
  c7 = BatchNormalization()(c7)
121
+
122
  # Output layer
123
  outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)
 
124
  model = Model(inputs, outputs)
125
  return model
126
 
 
 
 
 
127
  if __name__ == "__main__":
128
  # Define constants
129
+ HEIGHT, WIDTH = 256, 256 # Match function definition
130
  # Compile model
131
  model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
132
  model.summary()
133
+ model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())