import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Concatenate,
    BatchNormalization, LayerNormalization, Dropout, MultiHeadAttention,
    Add, Reshape, Layer
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.mixed_precision import set_global_policy
import cv2
import glob
import os
from skimage.color import rgb2lab, lab2rgb
from skimage.metrics import peak_signal_noise_ratio
import matplotlib.pyplot as plt
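
# Note: set_global_policy is imported but not enabled here; if mixed-precision
# training is desired (e.g. on GPUs with tensor cores), it could be turned on
# with: set_global_policy('mixed_float16')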
# Custom self-attention layer with serialization support
@tf.keras.utils.register_keras_serializable()
class SelfAttentionLayer(Layer):
    def __init__(self, num_heads, key_dim, **kwargs):
        super(SelfAttentionLayer, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
        self.ln = LayerNormalization()

    def call(self, x):
        # Flatten the spatial grid into a sequence so MultiHeadAttention can
        # attend across all h * w positions, then restore the grid layout.
        b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
        attention_input = tf.reshape(x, [b, h * w, c])
        attention_output = self.mha(attention_input, attention_input)
        attention_output = tf.reshape(attention_output, [b, h, w, c])
        # Residual connection followed by layer normalization
        return self.ln(x + attention_output)

    def get_config(self):
        config = super(SelfAttentionLayer, self).get_config()
        config.update({
            'num_heads': self.num_heads,
            'key_dim': self.key_dim
        })
        return config
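
# Illustrative shape check (not in the original file): the layer preserves its
# input shape (batch, h, w, c), e.g. a 32x32x128 bottleneck map stays 32x32x128.
#   _x = tf.zeros([1, 32, 32, 128])
#   assert SelfAttentionLayer(num_heads=2, key_dim=32)(_x).shape == (1, 32, 32, 128)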
def attention_unet_model(input_shape=(256, 256, 1)):
    inputs = Input(input_shape)

    # Encoder with reduced filters
    c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
    c1 = BatchNormalization()(c1)
    c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
    c1 = BatchNormalization()(c1)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(p1)
    c2 = BatchNormalization()(c2)
    c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(c2)
    c2 = BatchNormalization()(c2)
    p2 = MaxPooling2D((2, 2))(c2)

    c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(p2)
    c3 = BatchNormalization()(c3)
    c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(c3)
    c3 = BatchNormalization()(c3)
    p3 = MaxPooling2D((2, 2))(c3)

    # Bottleneck with reduced filters and self-attention
    c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(p3)
    c4 = BatchNormalization()(c4)
    c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)
    c4 = BatchNormalization()(c4)
    c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4)  # Reduced heads and key_dim

    # Attention gate: the gating signal g (upsampled decoder feature) weights
    # the skip connection s (encoder feature) with a learned sigmoid mask.
    def attention_gate(g, s, num_filters):
        g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
        s_conv = Conv2D(num_filters, (1, 1), padding='same')(s)
        attn = tf.keras.layers.add([g_conv, s_conv])
        attn = tf.keras.layers.Activation('relu')(attn)
        attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
        return s * attn

    # Decoder with reduced filters
    u5 = UpSampling2D((2, 2))(c4)
    a5 = attention_gate(u5, c3, 64)
    u5 = Concatenate()([u5, a5])
    c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(u5)
    c5 = BatchNormalization()(c5)
    c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(c5)
    c5 = BatchNormalization()(c5)

    u6 = UpSampling2D((2, 2))(c5)
    a6 = attention_gate(u6, c2, 32)
    u6 = Concatenate()([u6, a6])
    c6 = Conv2D(32, (3, 3), activation='relu', padding='same')(u6)
    c6 = BatchNormalization()(c6)
    c6 = Conv2D(32, (3, 3), activation='relu', padding='same')(c6)
    c6 = BatchNormalization()(c6)

    u7 = UpSampling2D((2, 2))(c6)
    a7 = attention_gate(u7, c1, 16)
    u7 = Concatenate()([u7, a7])
    c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(u7)
    c7 = BatchNormalization()(c7)
    c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(c7)
    c7 = BatchNormalization()(c7)

    # Output layer: two channels in [-1, 1] via tanh
    outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)
    model = Model(inputs, outputs)
    return model
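
# --- Illustrative preprocessing helper (an assumption, not in the original) ---
# A minimal sketch of how L/ab training pairs for this model are commonly built
# in a Lab colorization setup, consistent with the rgb2lab import and the
# two-channel tanh output: the L channel (scaled to [0, 1]) is the input, and
# the a/b channels (scaled to roughly [-1, 1]) are the target.
# `load_lab_pair` is a hypothetical helper name.
def load_lab_pair(path, height=1024, width=1024):
    bgr = cv2.imread(path)                          # OpenCV reads images as BGR
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (width, height))
    lab = rgb2lab(rgb / 255.0)                      # L in [0, 100]; a/b roughly in [-128, 127]
    L = (lab[..., :1] / 100.0).astype(np.float32)   # model input, shape (height, width, 1)
    ab = (lab[..., 1:] / 128.0).astype(np.float32)  # model target, shape (height, width, 2)
    return L, ab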
if __name__ == "__main__":
    # Define constants
    HEIGHT, WIDTH = 1024, 1024

    # Instantiate and compile the model
    model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
    model.summary()
    model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
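
    # --- Illustrative training sketch (assumptions, not the original script) ---
    # Wires the imported callbacks into model.fit; `X_train` / `Y_train` are
    # hypothetical arrays of L-channel inputs and a/b targets (see the
    # load_lab_pair sketch above).
    # callbacks = [
    #     ModelCheckpoint('attention_unet.keras', save_best_only=True),
    #     ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
    #     EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    # ]
    # model.fit(X_train, Y_train, batch_size=2, epochs=100,
    #           validation_split=0.1, callbacks=callbacks)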