import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Concatenate,
    BatchNormalization, LayerNormalization, Dropout, MultiHeadAttention,
    Add, Reshape, Layer
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.mixed_precision import set_global_policy
import cv2
import glob
import os
from skimage.color import rgb2lab, lab2rgb
from skimage.metrics import peak_signal_noise_ratio
import matplotlib.pyplot as plt

# Custom self-attention layer with serialization support
@tf.keras.utils.register_keras_serializable()
class SelfAttentionLayer(Layer):
    def __init__(self, num_heads, key_dim, **kwargs):
        super(SelfAttentionLayer, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
        self.ln = LayerNormalization()

    def call(self, x):
        # Flatten the spatial grid so MultiHeadAttention attends over all
        # H*W positions, then restore the original feature-map shape.
        b = tf.shape(x)[0]
        h, w, c = x.shape[1], x.shape[2], x.shape[3]
        attention_input = tf.reshape(x, [b, h * w, c])
        attention_output = self.mha(attention_input, attention_input)
        attention_output = tf.reshape(attention_output, [b, h, w, c])
        # Residual connection followed by layer normalization
        return self.ln(x + attention_output)

    def get_config(self):
        config = super(SelfAttentionLayer, self).get_config()
        config.update({
            'num_heads': self.num_heads,
            'key_dim': self.key_dim
        })
        return config
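
# Because the layer is registered via @register_keras_serializable and exposes
# its constructor arguments in get_config(), a saved model containing it can be
# reloaded without passing custom_objects. Illustrative usage (the file name is
# a placeholder):
#     model.save('attention_unet.keras')
#     restored = tf.keras.models.load_model('attention_unet.keras')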

def attention_unet_model(input_shape=(256, 256, 1)):
    inputs = Input(input_shape)

    # Encoder with reduced filters
    c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
    c1 = BatchNormalization()(c1)
    c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
    c1 = BatchNormalization()(c1)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(p1)
    c2 = BatchNormalization()(c2)
    c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(c2)
    c2 = BatchNormalization()(c2)
    p2 = MaxPooling2D((2, 2))(c2)

    c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(p2)
    c3 = BatchNormalization()(c3)
    c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(c3)
    c3 = BatchNormalization()(c3)
    p3 = MaxPooling2D((2, 2))(c3)

    # Bottleneck with reduced filters and attention
    c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(p3)
    c4 = BatchNormalization()(c4)
    c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)
    c4 = BatchNormalization()(c4)
    c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4)  # Reduced heads and key_dim

    # Additive attention gate: project the gating signal g and the skip
    # connection s to a common channel count, sum, apply ReLU, and collapse
    # to a single-channel sigmoid mask that reweights s.
    def attention_gate(g, s, num_filters):
        g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
        s_conv = Conv2D(num_filters, (1, 1), padding='same')(s)
        attn = tf.keras.layers.add([g_conv, s_conv])
        attn = tf.keras.layers.Activation('relu')(attn)
        attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
        return s * attn

    # Decoder with reduced filters
    u5 = UpSampling2D((2, 2))(c4)
    a5 = attention_gate(u5, c3, 64)
    u5 = Concatenate()([u5, a5])
    c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(u5)
    c5 = BatchNormalization()(c5)
    c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(c5)
    c5 = BatchNormalization()(c5)

    u6 = UpSampling2D((2, 2))(c5)
    a6 = attention_gate(u6, c2, 32)
    u6 = Concatenate()([u6, a6])
    c6 = Conv2D(32, (3, 3), activation='relu', padding='same')(u6)
    c6 = BatchNormalization()(c6)
    c6 = Conv2D(32, (3, 3), activation='relu', padding='same')(c6)
    c6 = BatchNormalization()(c6)

    u7 = UpSampling2D((2, 2))(c6)
    a7 = attention_gate(u7, c1, 16)
    u7 = Concatenate()([u7, a7])
    c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(u7)
    c7 = BatchNormalization()(c7)
    c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(c7)
    c7 = BatchNormalization()(c7)

    # Output layer: two channels with tanh activation (values in [-1, 1])
    outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)

    model = Model(inputs, outputs)
    return model
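
# --- Illustrative data-preparation helper (an assumption, not part of the
# original snippet): builds (L, ab) training pairs from RGB images on disk,
# matching the 1-channel input and 2-channel tanh output of the model above.
# The function name, directory layout, and normalization constants
# (L / 100 -> [0, 1], ab / 128 -> roughly [-1, 1]) are placeholders.
def load_lab_pairs(image_dir, height=256, width=256, pattern='*.jpg'):
    L_channels, ab_channels = [], []
    for path in sorted(glob.glob(os.path.join(image_dir, pattern))):
        bgr = cv2.imread(path)
        if bgr is None:
            continue  # skip unreadable files
        rgb = cv2.cvtColor(cv2.resize(bgr, (width, height)), cv2.COLOR_BGR2RGB)
        lab = rgb2lab(rgb / 255.0)
        L_channels.append(lab[..., :1] / 100.0)   # lightness channel as input
        ab_channels.append(lab[..., 1:] / 128.0)  # chrominance channels as target
    return (np.array(L_channels, dtype=np.float32),
            np.array(ab_channels, dtype=np.float32))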

if __name__ == "__main__":
    # Training resolution; height and width must be divisible by 8 so the
    # three pooling / upsampling stages line up.
    HEIGHT, WIDTH = 1024, 1024

    # Instantiate and compile the model
    model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
    model.summary()
    model.compile(optimizer=Adam(learning_rate=7e-5),
                  loss=tf.keras.losses.MeanSquaredError())
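
    # --- Illustrative training setup (an assumption, not part of the original
    # snippet). The checkpoint path and the X_train / Y_train arrays (L inputs
    # and ab targets, e.g. produced by load_lab_pairs above) are placeholders.
    callbacks = [
        ModelCheckpoint('attention_unet_best.keras', monitor='val_loss',
                        save_best_only=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6),
        EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    ]
    # model.fit(X_train, Y_train, validation_split=0.1,
    #           batch_size=2, epochs=100, callbacks=callbacks)
    #
    # For evaluation, predicted ab maps can be rescaled, recombined with the
    # L channel, converted back with lab2rgb, and scored with
    # peak_signal_noise_ratio, e.g.:
    #     pred_ab = model.predict(X_val) * 128.0
    #     lab = np.concatenate([X_val * 100.0, pred_ab], axis=-1)
    #     rgb_pred = lab2rgb(lab[0])
    #     score = peak_signal_noise_ratio(rgb_true[0], rgb_pred)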