import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, BatchNormalization,
    LayerNormalization, Dropout, MultiHeadAttention, Add, Reshape, Layer
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
import cv2
import glob
import os
from skimage.color import rgb2lab, lab2rgb
from skimage.metrics import peak_signal_noise_ratio
import matplotlib.pyplot as plt
# Disable mixed precision
tf.keras.mixed_precision.set_global_policy('float32')
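

# SelfAttentionLayer flattens the HxWxC feature map into a sequence of H*W
# tokens, applies multi-head self-attention, and adds the result back to the
# input through a residual connection followed by layer normalization.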
@tf.keras.utils.register_keras_serializable()
class SelfAttentionLayer(Layer):
    def __init__(self, num_heads, key_dim, **kwargs):
        super(SelfAttentionLayer, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
        self.ln = LayerNormalization(epsilon=1e-6)

    def build(self, input_shape):
        batch_size, height, width, channels = input_shape
        attention_shape = (batch_size, height * width, channels)
        self.mha.build(query_shape=attention_shape, value_shape=attention_shape)
        self.ln.build(input_shape)
        super(SelfAttentionLayer, self).build(input_shape)

    def call(self, x):
        b, h, w, c = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
        attention_input = tf.reshape(x, [b, h * w, c])
        attention_output = self.mha(attention_input, attention_input)
        attention_output = tf.cast(attention_output, dtype=x.dtype)
        attention_output = tf.reshape(attention_output, [b, h, w, c])
        return self.ln(x + attention_output)

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = super(SelfAttentionLayer, self).get_config()
        config.update({
            'num_heads': self.num_heads,
            'key_dim': self.key_dim
        })
        return config
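

# Attention U-Net for colorization: a three-level encoder-decoder that takes a
# single-channel (grayscale / L) image and predicts a two-channel output map
# (presumably the a/b chrominance channels, given the Lab conversion imports),
# squashed to [-1, 1] by the final tanh. Skip connections are weighted by
# additive attention gates before being concatenated into the decoder.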
def attention_unet_model(input_shape=(1024, 1024, 1)):
    inputs = Input(input_shape)

    # Encoder block 1
    c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
    c1 = BatchNormalization()(c1)
    c1 = Conv2D(16, (3, 3), activation='relu', padding='same')(c1)
    c1 = BatchNormalization()(c1)
    p1 = MaxPooling2D((2, 2))(c1)

    # Encoder block 2
    c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(p1)
    c2 = BatchNormalization()(c2)
    c2 = Conv2D(32, (3, 3), activation='relu', padding='same')(c2)
    c2 = BatchNormalization()(c2)
    p2 = MaxPooling2D((2, 2))(c2)

    # Encoder block 3
    c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(p2)
    c3 = BatchNormalization()(c3)
    c3 = Conv2D(64, (3, 3), activation='relu', padding='same')(c3)
    c3 = BatchNormalization()(c3)
    p3 = MaxPooling2D((2, 2))(c3)

    # Bottleneck with self-attention over the downsampled feature map
    c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(p3)
    c4 = BatchNormalization()(c4)
    c4 = Conv2D(128, (3, 3), activation='relu', padding='same')(c4)
    c4 = BatchNormalization()(c4)
    c4 = SelfAttentionLayer(num_heads=2, key_dim=32)(c4)

    # Additive attention gate: the gating signal g (from the decoder path)
    # produces a sigmoid mask that weights the skip connection s.
    def attention_gate(g, s, num_filters):
        g_conv = Conv2D(num_filters, (1, 1), padding='same')(g)
        s_conv = Conv2D(num_filters, (1, 1), padding='same')(s)
        attn = tf.keras.layers.add([g_conv, s_conv])
        attn = tf.keras.layers.Activation('relu')(attn)
        attn = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(attn)
        return s * attn

    # Decoder block 1
    u5 = UpSampling2D((2, 2))(c4)
    a5 = attention_gate(u5, c3, 64)
    u5 = Concatenate()([u5, a5])
    c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(u5)
    c5 = BatchNormalization()(c5)
    c5 = Conv2D(64, (3, 3), activation='relu', padding='same')(c5)
    c5 = BatchNormalization()(c5)

    # Decoder block 2
    u6 = UpSampling2D((2, 2))(c5)
    a6 = attention_gate(u6, c2, 32)
    u6 = Concatenate()([u6, a6])
    c6 = Conv2D(32, (3, 3), activation='relu', padding='same')(u6)
    c6 = BatchNormalization()(c6)
    c6 = Conv2D(32, (3, 3), activation='relu', padding='same')(c6)
    c6 = BatchNormalization()(c6)

    # Decoder block 3
    u7 = UpSampling2D((2, 2))(c6)
    a7 = attention_gate(u7, c1, 16)
    u7 = Concatenate()([u7, a7])
    c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(u7)
    c7 = BatchNormalization()(c7)
    c7 = Conv2D(16, (3, 3), activation='relu', padding='same')(c7)
    c7 = BatchNormalization()(c7)

    # Two output channels in [-1, 1]
    outputs = Conv2D(2, (1, 1), activation='tanh', padding='same')(c7)

    model = Model(inputs, outputs)
    return model


if __name__ == "__main__":
    HEIGHT, WIDTH = 1024, 1024
    model = attention_unet_model(input_shape=(HEIGHT, WIDTH, 1))
    model.summary()
    model.compile(optimizer=Adam(learning_rate=7e-5), loss=tf.keras.losses.MeanSquaredError())
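
    # Optional smoke test (not part of the original script; a minimal sketch that
    # assumes inputs are single-channel images scaled to whatever range the model
    # is trained on). It pushes one random image through the network to verify
    # the output shape.
    dummy_l = np.random.rand(1, HEIGHT, WIDTH, 1).astype("float32")
    dummy_ab = model.predict(dummy_l, verbose=0)
    print("Predicted output shape:", dummy_ab.shape)  # expected: (1, 1024, 1024, 2)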