File size: 4,127 Bytes
045c942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Model hyperparameters
num_classes = 3  # width of the final softmax layer
image_size = 256  # images are resized to image_size x image_size
# Derived from image_size so the model input and the Resizing target cannot drift apart.
input_shape = (image_size, image_size, 3)
patch_size = 16  # side length (pixels) of each square, non-overlapping patch
num_patches = (image_size // patch_size) ** 2  # patches per image
projection_dim = 64  # token embedding width
num_heads = 8  # Increased from 4 -> 8
transformer_units = [projection_dim * 2, projection_dim]  # per-block MLP widths
transformer_layers = 12  # Increased from 8 -> 12
mlp_head_units = [2048, 1024]  # classification-head MLP widths
dropout_rate = 0.1

# Data Augmentation with stronger transformations
# Preprocessing + augmentation stack; create_vit_classifier() applies it to the
# raw model input before patch extraction.
# NOTE(review): the Normalization layer below is never adapted (nor given an
# explicit mean/variance) anywhere in the visible code. Presumably the training
# script should call `data_augmentation.layers[0].adapt(x_train)`; confirm.
data_augmentation = keras.Sequential(
    layers=[
        layers.Normalization(),
        layers.Resizing(height=image_size, width=image_size),
        layers.RandomFlip(mode="horizontal"),
        layers.RandomRotation(factor=0.1),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)

# Patch Creation Layer
class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Input:  4-D image batch (batch, height, width, channels).
    Output: (batch, n_patches, patch_size * patch_size * channels), with the
    patches in row-major order. Border pixels that do not fill a whole patch
    are dropped (``padding="VALID"``).

    Args:
        patch_size: Side length in pixels of each square patch. It is also used
            as the stride, so patches do not overlap.
        **kwargs: Standard Keras layer arguments (e.g. ``name``), forwarded to
            ``layers.Layer`` so the layer can be rebuilt from ``get_config()``.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        # The batch dimension may be unknown while tracing, so read it dynamically.
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # The flattened patch length (last axis) is static once the channel count is.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the constructor arguments so saved models can re-create this layer."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

# Patch Encoding with Learnable [CLS] Token
class PatchEncoder(layers.Layer):
    """Project patches, prepend a learnable [CLS] token, and add position embeddings.

    Input:  (batch, num_patches, patch_dim) flattened patches.
    Output: (batch, num_patches + 1, projection_dim). Position 0 along axis 1
    is the [CLS] token.

    Args:
        num_patches: Number of patches per image. The position table has
            ``num_patches + 1`` rows (one extra for [CLS]).
        projection_dim: Width of every token embedding.
        **kwargs: Standard Keras layer arguments (e.g. ``name``), forwarded to
            ``layers.Layer`` so the layer can be rebuilt from ``get_config()``.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        # Stored so call() uses this instance's sizes rather than module-level
        # globals that merely happen to share these names.
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(projection_dim)
        self.cls_token = self.add_weight(shape=(1, 1, projection_dim), initializer="zeros", trainable=True)
        self.position_embedding = layers.Embedding(input_dim=num_patches + 1, output_dim=projection_dim)

    def call(self, patch):
        batch_size = tf.shape(patch)[0]
        # One copy of the shared [CLS] embedding per example in the batch.
        cls_token = tf.broadcast_to(self.cls_token, [batch_size, 1, self.projection_dim])
        patch = self.projection(patch)
        patch = tf.concat([cls_token, patch], axis=1)  # Add CLS token
        positions = tf.range(start=0, limit=self.num_patches + 1, delta=1)
        # The (num_patches + 1, projection_dim) table broadcasts over the batch axis.
        return patch + self.position_embedding(positions)

    def get_config(self):
        """Return the constructor arguments so saved models can re-create this layer."""
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config

# MLP Block
def mlp(x, hidden_units, dropout_rate):
    """Stack one Dense(GELU) -> Dropout -> LayerNormalization stage per width.

    Args:
        x: Input tensor fed to the first stage.
        hidden_units: Iterable of Dense widths, one stage per entry, applied in order.
        dropout_rate: Dropout rate used in every stage.

    Returns:
        The output tensor of the last stage (``x`` itself if ``hidden_units`` is empty).
    """
    for width in hidden_units:
        dense = layers.Dense(width, activation=tf.nn.gelu)
        drop = layers.Dropout(dropout_rate)
        norm = layers.LayerNormalization()  # extra per-stage LayerNorm (deliberate)
        x = norm(drop(dense(x)))
    return x

# Vision Transformer Model
def create_vit_classifier():
    """Build the (uncompiled) ViT classifier as a functional Keras model.

    Pipeline: data_augmentation -> Patches -> PatchEncoder (adds [CLS]) ->
    ``transformer_layers`` pre-norm encoder blocks -> LayerNormalization ->
    MLP head on the [CLS] token -> softmax over ``num_classes``.
    All sizes come from the module-level hyperparameters.

    Returns:
        keras.Model mapping ``input_shape`` images to class probabilities.
    """

    def _encoder_block(tokens):
        # Self-attention sub-layer: LayerNorm first, then a residual add.
        normed = layers.LayerNormalization()(tokens)
        attended = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate
        )(normed, normed)
        residual = layers.Add()([attended, tokens])
        # Feed-forward sub-layer: LayerNorm first, then a residual add.
        ff_input = layers.LayerNormalization()(residual)
        ff_output = mlp(ff_input, transformer_units, dropout_rate)
        return layers.Add()([ff_output, residual])

    inputs = layers.Input(shape=input_shape)
    pixels = data_augmentation(inputs)
    patch_vectors = Patches(patch_size)(pixels)
    tokens = PatchEncoder(num_patches, projection_dim)(patch_vectors)

    for _ in range(transformer_layers):
        tokens = _encoder_block(tokens)

    normalized = layers.LayerNormalization()(tokens)
    cls_vector = normalized[:, 0, :]  # the [CLS] token prepended by PatchEncoder
    head = mlp(cls_vector, mlp_head_units, dropout_rate)
    probabilities = layers.Dense(num_classes, activation="softmax")(head)

    return keras.Model(inputs, probabilities)

# Cosine Decay Learning Rate Scheduler
def cosine_decay_schedule(initial_lr, total_steps, alpha=0.0):
    """Return a Keras cosine-decay learning-rate schedule.

    Args:
        initial_lr: Learning rate at step 0.
        total_steps: Number of optimizer steps over which the rate decays.
        alpha: Final learning rate as a fraction of ``initial_lr``. The default
            of 0.0 is Keras' own default, so existing two-argument calls are unchanged.

    Returns:
        A ``keras.optimizers.schedules.CosineDecay`` instance.
    """
    return keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=initial_lr,
        decay_steps=total_steps,
        alpha=alpha,
    )

# Optimizer with Weight Decay & Cosine Decay
learning_rate = 3e-4
weight_decay = 0.03
num_epochs = 50
# Optimizer steps per epoch assumed by the LR schedule (previously an inline magic 100).
# NOTE(review): this should equal ceil(num_train_examples / batch_size) for the
# real training data. Neither value is visible here, so 100 is kept — confirm.
steps_per_epoch = 100

optimizer = keras.optimizers.AdamW(
    learning_rate=cosine_decay_schedule(learning_rate, num_epochs * steps_per_epoch),
    weight_decay=weight_decay,
)

# Compile Model
# Build the ViT, then attach the AdamW optimizer, a sparse categorical
# cross-entropy loss and an accuracy metric.
vit_classifier = create_vit_classifier()
vit_classifier.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=optimizer,
    metrics=["accuracy"],
)

# Print Model Summary
vit_classifier.summary()