danhtran2mind commited on
Commit
437b632
·
verified ·
1 Parent(s): ba81626

Update models/unet_gray2color.py

Browse files
Files changed (1) hide show
  1. models/unet_gray2color.py +4 -16
models/unet_gray2color.py CHANGED
@@ -14,10 +14,9 @@ from skimage.color import rgb2lab, lab2rgb
14
  from skimage.metrics import peak_signal_noise_ratio
15
  import matplotlib.pyplot as plt
16
 
17
- # Disable mixed precision to avoid dtype mismatches
18
  tf.keras.mixed_precision.set_global_policy('float32')
19
 
20
- # Custom self-attention layer with serialization support
21
  @tf.keras.utils.register_keras_serializable()
22
  class SelfAttentionLayer(Layer):
23
  def __init__(self, num_heads, key_dim, **kwargs):
@@ -28,12 +27,9 @@ class SelfAttentionLayer(Layer):
28
  self.ln = LayerNormalization(epsilon=1e-6)
29
 
30
  def build(self, input_shape):
31
- # input_shape: (batch_size, height, width, channels)
32
  batch_size, height, width, channels = input_shape
33
- attention_shape = (batch_size, height * width, channels) # Shape after reshape
34
- # Build MultiHeadAttention with query, key, and value shapes
35
  self.mha.build(query_shape=attention_shape, value_shape=attention_shape)
36
- # Build LayerNormalization with the original input shape
37
  self.ln.build(input_shape)
38
  super(SelfAttentionLayer, self).build(input_shape)
39
 
@@ -41,7 +37,6 @@ class SelfAttentionLayer(Layer):
41
  b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
42
  attention_input = tf.reshape(x, [b, h * w, c])
43
  attention_output = self.mha(attention_input, attention_input)
44
- # Cast attention_output to match x's dtype
45
  attention_output = tf.cast(attention_output, dtype=x.dtype)
46
  attention_output = tf.reshape(attention_output, [b, h, w, c])
47
  return self.ln(x + attention_output)
@@ -57,9 +52,8 @@ class SelfAttentionLayer(Layer):
57
  })
58
  return config
59
 
60
- def attention_unet_model(input_shape=(256, 256, 1)):
61
  inputs = Input(input_shape)
62
- # Encoder with reduced filters
63
  c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
64
  c1 = BatchNormalization()(c1)
65
  c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
@@ -78,14 +72,12 @@ def attention_unet_model(input_shape=(256, 256, 1)):
78
  c3 = BatchNormalization()(c3)
79
  p3 = MaxPooling2D((2, 2))(c3)
80
 
81
- # Bottleneck with reduced filters and attention
82
  c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(p3)
83
  c4 = BatchNormalization()(c4)
84
  c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)
85
  c4 = BatchNormalization()(c4)
86
  c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4)
87
 
88
- # Attention gate
89
  def attention_gate(g, s, num_filters):
90
  g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
91
  s_conv = Conv2D(num_filters, (1, 1), padding='same')(s)
@@ -94,7 +86,6 @@ def attention_unet_model(input_shape=(256, 256, 1)):
94
  attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
95
  return s * attn
96
 
97
- # Decoder with reduced filters
98
  u5 = UpSampling2D((2, 2))(c4)
99
  a5 = attention_gate(u5, c3, 64)
100
  u5 = Concatenate()([u5, a5])
@@ -119,15 +110,12 @@ def attention_unet_model(input_shape=(256, 256, 1)):
119
  c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(c7)
120
  c7 = BatchNormalization()(c7)
121
 
122
- # Output layer
123
  outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)
124
  model = Model(inputs, outputs)
125
  return model
126
 
127
  if __name__ == "__main__":
128
- # Define constants
129
- HEIGHT, WIDTH = 256, 256 # Match function definition
130
- # Compile model
131
  model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
132
  model.summary()
133
  model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
 
14
  from skimage.metrics import peak_signal_noise_ratio
15
  import matplotlib.pyplot as plt
16
 
17
+ # Disable mixed precision
18
  tf.keras.mixed_precision.set_global_policy('float32')
19
 
 
20
  @tf.keras.utils.register_keras_serializable()
21
  class SelfAttentionLayer(Layer):
22
  def __init__(self, num_heads, key_dim, **kwargs):
 
27
  self.ln = LayerNormalization(epsilon=1e-6)
28
 
29
  def build(self, input_shape):
 
30
  batch_size, height, width, channels = input_shape
31
+ attention_shape = (batch_size, height * width, channels)
 
32
  self.mha.build(query_shape=attention_shape, value_shape=attention_shape)
 
33
  self.ln.build(input_shape)
34
  super(SelfAttentionLayer, self).build(input_shape)
35
 
 
37
  b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
38
  attention_input = tf.reshape(x, [b, h * w, c])
39
  attention_output = self.mha(attention_input, attention_input)
 
40
  attention_output = tf.cast(attention_output, dtype=x.dtype)
41
  attention_output = tf.reshape(attention_output, [b, h, w, c])
42
  return self.ln(x + attention_output)
 
52
  })
53
  return config
54
 
55
+ def attention_unet_model(input_shape=(1024, 1024, 1)):
56
  inputs = Input(input_shape)
 
57
  c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
58
  c1 = BatchNormalization()(c1)
59
  c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
 
72
  c3 = BatchNormalization()(c3)
73
  p3 = MaxPooling2D((2, 2))(c3)
74
 
 
75
  c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(p3)
76
  c4 = BatchNormalization()(c4)
77
  c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)
78
  c4 = BatchNormalization()(c4)
79
  c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4)
80
 
 
81
  def attention_gate(g, s, num_filters):
82
  g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
83
  s_conv = Conv2D(num_filters, (1, 1), padding='same')(s)
 
86
  attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
87
  return s * attn
88
 
 
89
  u5 = UpSampling2D((2, 2))(c4)
90
  a5 = attention_gate(u5, c3, 64)
91
  u5 = Concatenate()([u5, a5])
 
110
  c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(c7)
111
  c7 = BatchNormalization()(c7)
112
 
 
113
  outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)
114
  model = Model(inputs, outputs)
115
  return model
116
 
117
  if __name__ == "__main__":
118
+ HEIGHT, WIDTH = 1024, 1024
 
 
119
  model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
120
  model.summary()
121
  model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())