"""Vision Transformer (ViT) image classifier built with Keras.

Defines the model hyperparameters, an in-model preprocessing/augmentation
pipeline, the patch-extraction and patch-encoding layers (with a learnable
[CLS] token), and compiles the classifier with AdamW and a cosine-decay
learning-rate schedule.
"""
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# ---------------------------------------------------------------------------
# Model hyperparameters
# ---------------------------------------------------------------------------
num_classes = 3
input_shape = (256, 256, 3)
image_size = 256
patch_size = 16
num_patches = (image_size // patch_size) ** 2  # patches per image (16 x 16 grid)
projection_dim = 64
num_heads = 8
transformer_units = [projection_dim * 2, projection_dim]  # per-block MLP widths
transformer_layers = 12
mlp_head_units = [2048, 1024]  # classification-head MLP widths
dropout_rate = 0.1

# ---------------------------------------------------------------------------
# Preprocessing / data augmentation (applied inside the model)
# ---------------------------------------------------------------------------
# NOTE(review): the Normalization layer below is never adapted in this file.
# Unless `data_augmentation.layers[0].adapt(train_images)` is called elsewhere
# before training, it will not standardize the inputs -- confirm.
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.1),
        layers.RandomZoom(0.2, 0.2),
    ],
    name="data_augmentation",
)


class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: Side length (in pixels) of each square patch. It is also
            the stride, so patches do not overlap; with "VALID" padding any
            remainder at the right/bottom edge is dropped.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).

    Input: images laid out as (batch, height, width, channels).
    Output: (batch, num_patches, patch_size * patch_size * channels).
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Collapse the (rows, cols) grid of patches into one sequence axis.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        # Needed so models containing this layer can be saved / reloaded:
        # the layer has a constructor argument the base config doesn't know.
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config


class PatchEncoder(layers.Layer):
    """Project patches, prepend a learnable [CLS] token, add position embeddings.

    Args:
        num_patches: Number of patches per image (sequence length before the
            [CLS] token is prepended).
        projection_dim: Width of every token embedding.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).

    Input: (batch, num_patches, patch_dims).
    Output: (batch, num_patches + 1, projection_dim); position 0 is [CLS].
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        # Stored so `call` uses this instance's sizes, not module globals.
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(projection_dim)
        self.cls_token = self.add_weight(
            shape=(1, 1, projection_dim), initializer="zeros", trainable=True
        )
        # One learned embedding per position, including the [CLS] slot.
        self.position_embedding = layers.Embedding(
            input_dim=num_patches + 1, output_dim=projection_dim
        )

    def call(self, patch):
        batch_size = tf.shape(patch)[0]
        # Replicate the single learned [CLS] token for every example.
        cls_token = tf.broadcast_to(
            self.cls_token, [batch_size, 1, self.projection_dim]
        )
        patch = self.projection(patch)
        patch = tf.concat([cls_token, patch], axis=1)  # prepend [CLS]
        positions = tf.range(start=0, limit=self.num_patches + 1, delta=1)
        # (num_patches + 1, dim) embeddings broadcast across the batch axis.
        return patch + self.position_embedding(positions)

    def get_config(self):
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config


def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of Dense(GELU) -> Dropout -> LayerNorm blocks.

    Args:
        x: Input tensor; Dense acts on its last axis.
        hidden_units: Output width of each Dense layer, in order.
        dropout_rate: Dropout rate applied after every Dense layer.

    Returns:
        Tensor whose last axis has size ``hidden_units[-1]`` (``x`` itself if
        ``hidden_units`` is empty).
    """
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
        x = layers.LayerNormalization()(x)
    return x


def create_vit_classifier():
    """Build the (uncompiled) ViT classifier from the module hyperparameters.

    Pipeline: images -> preprocessing/augmentation -> patches -> patch
    encoding with [CLS] token -> ``transformer_layers`` pre-norm Transformer
    blocks -> LayerNorm -> [CLS] token -> MLP head -> softmax over
    ``num_classes``.
    """
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    for _ in range(transformer_layers):
        # Self-attention sub-block: LayerNorm first, then residual add.
        x1 = layers.LayerNormalization()(encoded_patches)
        # NOTE(review): Keras `key_dim` is the per-head size, so every one of
        # the `num_heads` heads is `projection_dim` wide here (standard ViT
        # uses projection_dim // num_heads). Kept as-is; confirm intent.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])
        # MLP sub-block: LayerNorm first, then residual add.
        x3 = layers.LayerNormalization()(x2)
        x3 = mlp(x3, transformer_units, dropout_rate)
        encoded_patches = layers.Add()([x3, x2])

    representation = layers.LayerNormalization()(encoded_patches)
    cls_output = representation[:, 0, :]  # [CLS] token is at sequence position 0
    features = mlp(cls_output, mlp_head_units, dropout_rate)
    outputs = layers.Dense(num_classes, activation="softmax")(features)
    return keras.Model(inputs, outputs)


def cosine_decay_schedule(initial_lr, total_steps, alpha=0.0):
    """Return a cosine-decay learning-rate schedule.

    Args:
        initial_lr: Learning rate at step 0.
        total_steps: Number of optimizer steps over which the rate decays.
        alpha: Final learning rate as a fraction of ``initial_lr``; the
            default 0.0 decays all the way to zero.
    """
    return keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=initial_lr, decay_steps=total_steps, alpha=alpha
    )


# ---------------------------------------------------------------------------
# Optimizer: AdamW (decoupled weight decay) with cosine learning-rate decay
# ---------------------------------------------------------------------------
learning_rate = 3e-4
weight_decay = 0.03
num_epochs = 50
# Assumed optimizer steps per epoch; the schedule decays over
# num_epochs * steps_per_epoch steps.
# NOTE(review): derive this from the real dataset size and batch size.
steps_per_epoch = 100

optimizer = keras.optimizers.AdamW(
    learning_rate=cosine_decay_schedule(learning_rate, num_epochs * steps_per_epoch),
    weight_decay=weight_decay,
)

# ---------------------------------------------------------------------------
# Build, compile and summarize the model
# ---------------------------------------------------------------------------
vit_classifier = create_vit_classifier()
vit_classifier.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
vit_classifier.summary()