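# vit.py: Vision Transformer (ViT) image classifier built with TensorFlow/Keras.
# Provenance (from the hosting page): uploaded by Sobit, commit 045c942 "Create vit.py", 4.13 kB.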
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Model hyperparameters
num_classes = 3
image_size = 256  # images are resized to image_size x image_size before patching
# (height, width, channels): derived from image_size so the two cannot drift apart.
input_shape = (image_size, image_size, 3)
patch_size = 16
# Non-overlapping square patches per image, e.g. (256 // 16) ** 2 = 256.
num_patches = (image_size // patch_size) ** 2
projection_dim = 64  # token width inside the transformer
num_heads = 8  # Increased from 4 -> 8
# Widths of the two Dense layers in each transformer MLP (expand x2, then back).
transformer_units = [projection_dim * 2, projection_dim]
transformer_layers = 12  # Increased from 8 -> 12
mlp_head_units = [2048, 1024]  # widths of the classification-head MLP
dropout_rate = 0.1
# Data augmentation / preprocessing pipeline, applied as the first stage of the model.
# Per the Keras docs, the Random* layers only transform inputs in training mode.
# RandomRotation(factor=0.1) rotates by up to +/-0.1 * 2*pi radians.
# NOTE(review): layers.Normalization() is never adapted in this file. Unless a
# training script calls data_augmentation.layers[0].adapt(...), it is not fitted
# to the data and pixel values are presumably passed through unscaled. Confirm.
data_augmentation = keras.Sequential(
    layers=[
        layers.Normalization(),
        layers.Resizing(height=image_size, width=image_size),
        layers.RandomFlip(mode="horizontal"),
        layers.RandomRotation(factor=0.1),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Patch Creation Layer
class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: Side length (in pixels) of each square patch; also used as
            the stride, so patches do not overlap.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``,
            ``dtype``). Accepting these lets ``from_config`` rebuild the layer
            when a saved model is loaded.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return patches shaped ``(batch, num_patches, patch_dim)``.

        ``tf.image.extract_patches`` requires a 4-D ``(batch, rows, cols,
        depth)`` input. ``patch_dim`` is the last dimension of its output
        (``patch_size * patch_size * depth``). With ``"VALID"`` padding, any
        border pixels that do not fill a whole patch are dropped.
        """
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Static last dimension; assumes the channel count is known at build time.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the layer config so the layer can be serialized and reloaded."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config
# Patch Encoding with Learnable [CLS] Token
class PatchEncoder(layers.Layer):
    """Project patches, prepend a learnable [CLS] token and add position embeddings.

    Args:
        num_patches: Number of patches per image. The output sequence has
            ``num_patches + 1`` tokens (the extra one is the [CLS] token).
        projection_dim: Width each patch is projected to.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``,
            ``dtype``), which also lets ``from_config`` rebuild the layer.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments. call() must use these, not the
        # module-level hyperparameters that happen to share the same names.
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(projection_dim)
        # Learnable [CLS] token, zero-initialized, shared across the batch.
        self.cls_token = self.add_weight(shape=(1, 1, projection_dim), initializer="zeros", trainable=True)
        # One learned position vector per token, including the [CLS] slot.
        self.position_embedding = layers.Embedding(input_dim=num_patches + 1, output_dim=projection_dim)

    def call(self, patch):
        """Return encoded tokens shaped ``(batch, num_patches + 1, projection_dim)``.

        Assumes ``patch`` is ``(batch, num_patches, patch_dim)``. If its
        sequence length differs from ``num_patches``, the position-embedding
        addition will fail to broadcast.
        """
        batch_size = tf.shape(patch)[0]
        # Replicate the single [CLS] token once per example in the batch.
        cls_token = tf.broadcast_to(self.cls_token, [batch_size, 1, self.projection_dim])
        patch = self.projection(patch)
        patch = tf.concat([cls_token, patch], axis=1)  # Add CLS token at position 0
        positions = tf.range(start=0, limit=self.num_patches + 1, delta=1)
        return patch + self.position_embedding(positions)

    def get_config(self):
        """Return the layer config so the layer can be serialized and reloaded."""
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config
# MLP Block
def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of Dense(GELU) -> Dropout -> LayerNormalization stages.

    Args:
        x: Input (symbolic) tensor.
        hidden_units: Sequence of widths; one stage is created per entry.
        dropout_rate: Dropout rate applied after every Dense layer.

    Returns:
        The output of the last stage, or ``x`` itself if ``hidden_units`` is
        empty.
    """
    out = x
    for width in hidden_units:
        out = layers.Dense(width, activation=tf.nn.gelu)(out)
        out = layers.Dropout(dropout_rate)(out)
        # Normalize the output of every stage.
        out = layers.LayerNormalization()(out)
    return out
# Vision Transformer Model
def create_vit_classifier():
    """Build the ViT classifier as a functional Keras model.

    Pipeline:
      1. augmentation
      2. patch extraction
      3. patch encoding with a [CLS] token
      4. ``transformer_layers`` pre-norm transformer blocks
      5. final LayerNorm
      6. MLP head on the [CLS] token
      7. softmax over ``num_classes``

    All sizes come from the module-level hyperparameters.

    Returns:
        An uncompiled ``keras.Model`` mapping inputs of shape ``input_shape``
        to class probabilities.
    """

    def encoder_block(tokens):
        # Self-attention sub-block: LayerNorm first, then a residual add.
        normed = layers.LayerNormalization()(tokens)
        attended = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate
        )(normed, normed)
        residual = layers.Add()([attended, tokens])
        # Feed-forward sub-block: LayerNorm first, then a second residual add.
        feed_forward = mlp(layers.LayerNormalization()(residual), transformer_units, dropout_rate)
        return layers.Add()([feed_forward, residual])

    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patch_tokens = Patches(patch_size)(augmented)
    tokens = PatchEncoder(num_patches, projection_dim)(patch_tokens)
    for _ in range(transformer_layers):
        tokens = encoder_block(tokens)
    # Final LayerNorm, then keep only the [CLS] token (sequence index 0).
    cls_embedding = layers.LayerNormalization()(tokens)[:, 0, :]
    head = mlp(cls_embedding, mlp_head_units, dropout_rate)
    probabilities = layers.Dense(num_classes, activation="softmax")(head)
    return keras.Model(inputs, probabilities)
# Cosine Decay Learning Rate Scheduler
def cosine_decay_schedule(initial_lr, total_steps):
    """Create a cosine-decay learning-rate schedule.

    Args:
        initial_lr: Learning rate at step 0.
        total_steps: Number of optimizer steps over which the rate decays.
            Other CosineDecay options keep their Keras defaults.

    Returns:
        A ``keras.optimizers.schedules.CosineDecay`` instance, usable as an
        optimizer's ``learning_rate``.
    """
    schedule_cls = keras.optimizers.schedules.CosineDecay
    return schedule_cls(decay_steps=total_steps, initial_learning_rate=initial_lr)
# Optimizer with Weight Decay & Cosine Decay
learning_rate = 3e-4
weight_decay = 0.03
num_epochs = 50
# Steps per epoch used to size the cosine schedule (previously an unnamed 100
# inside the schedule call).
# NOTE(review): this should match the real number of training batches per epoch
# (ceil(num_train_samples / batch_size)). Otherwise the schedule ends too early
# or too late. Confirm against the training script.
steps_per_epoch = 100
total_steps = num_epochs * steps_per_epoch
# NOTE(review): no exclude_from_weight_decay() call, so AdamW decays every
# trainable variable, including LayerNorm scales and biases. ViT recipes often
# exclude those. Consider whether that is intended.
optimizer = keras.optimizers.AdamW(
    learning_rate=cosine_decay_schedule(learning_rate, total_steps),
    weight_decay=weight_decay,
)
# Compile Model
vit_classifier = create_vit_classifier()
# sparse_categorical_crossentropy expects integer class labels (not one-hot).
vit_classifier.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
# Print Model Summary
vit_classifier.summary()