import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.layers import (
    Input, Dense, LayerNormalization, Dropout,
    MultiHeadAttention, Add, Conv2D, Reshape, UpSampling2D
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.mixed_precision import set_global_policy
import cv2
from skimage.color import rgb2lab, lab2rgb
from skimage.metrics import peak_signal_noise_ratio
import matplotlib.pyplot as plt
import glob
import os

# Define Transformer model
def transformer_model(input_shape=(1024, 1024, 1), patch_size=8,
                      d_model=32, num_heads=4, ff_dim=64,
                      num_layers=2, dropout_rate=0.1):
    """Build a patch-based Transformer that maps a single-channel image to a
    two-channel output of the same spatial size.

    Given the surrounding imports (rgb2lab/lab2rgb), the 1-in / 2-out channel
    layout presumably corresponds to predicting the a/b chroma channels from
    the L channel for colorization — confirm against the training pipeline.

    Args:
        input_shape: (height, width, channels) of the input image.
        patch_size: Side length of square patches; height and width must be
            divisible by it.
        d_model: Embedding dimension of each patch token.
        num_heads: Number of attention heads (key_dim = d_model // num_heads).
        ff_dim: Hidden width of the per-token feed-forward sublayer.
        num_layers: Number of encoder blocks (post-norm residual style).
        dropout_rate: Dropout applied to attention and feed-forward outputs.

    Returns:
        An uncompiled tf.keras Model: (H, W, C) -> (H, W, 2), tanh-bounded.

    Raises:
        ValueError: If height or width is not divisible by patch_size.
    """
    HEIGHT, WIDTH, _ = input_shape
    if HEIGHT % patch_size != 0 or WIDTH % patch_size != 0:
        raise ValueError(
            f"Input dimensions ({HEIGHT}, {WIDTH}) must be divisible by "
            f"patch_size={patch_size}"
        )
    num_patches = (HEIGHT // patch_size) * (WIDTH // patch_size)

    inputs = Input(shape=input_shape)

    # Patch extraction: non-overlapping patch_size x patch_size patches,
    # each linearly projected to d_model, then flattened to a token sequence.
    x = Conv2D(d_model, (patch_size, patch_size), strides=(patch_size, patch_size), padding='valid')(inputs)
    x = Reshape((num_patches, d_model))(x)

    # Learned positional embedding: without it, self-attention is
    # permutation-equivariant and the spatial arrangement of patches is lost.
    positions = tf.range(start=0, limit=num_patches, delta=1)
    pos_embedding = layers.Embedding(input_dim=num_patches, output_dim=d_model)(positions)
    x = x + pos_embedding

    # Transformer encoder blocks (post-norm: residual add, then LayerNorm).
    for _ in range(num_layers):
        attn_output = MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)(x, x)
        attn_output = Dropout(dropout_rate)(attn_output)
        x = Add()([x, attn_output])
        x = LayerNormalization(epsilon=1e-6)(x)

        ff_output = Dense(ff_dim, activation='relu')(x)
        ff_output = Dense(d_model)(ff_output)
        ff_output = Dropout(dropout_rate)(ff_output)
        x = Add()([x, ff_output])
        x = LayerNormalization(epsilon=1e-6)(x)

    # Decoder: project each token to 2 channels, fold tokens back into a
    # low-res feature map, then bilinearly upsample to full resolution.
    x = Dense(2)(x)
    x = Reshape((HEIGHT // patch_size, WIDTH // patch_size, 2))(x)
    x = UpSampling2D(size=(patch_size, patch_size), interpolation='bilinear')(x)
    # tanh bounds outputs to [-1, 1]; targets are presumably scaled to match.
    outputs = Conv2D(2, (3, 3), activation='tanh', padding='same')(x)

    return Model(inputs, outputs)

if __name__ == "__main__":
    # Image dimensions (must be divisible by the patch size).
    HEIGHT, WIDTH = 1024, 1024
    LEARNING_RATE = 7e-5

    # Build the model and report its architecture.
    model = transformer_model(
        input_shape=(HEIGHT, WIDTH, 1),
        patch_size=8,
        d_model=32,
        num_heads=4,
        ff_dim=64,
        num_layers=2,
    )
    model.summary()

    # Compile with Adam and an MSE regression loss.
    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss=tf.keras.losses.MeanSquaredError(),
    )