|
import tensorflow as tf |
|
from tensorflow import keras |
|
from tensorflow.keras import layers |
|
|
|
|
|
# ---- Data / input geometry -------------------------------------------------
num_classes = 3
image_size = 256  # images are (re)sized to image_size x image_size
input_shape = (image_size, image_size, 3)  # height, width, RGB channels

# ---- Patch tokenization ----------------------------------------------------
patch_size = 16
# Non-overlapping square patches per image: (256 // 16) ** 2 == 256.
num_patches = (image_size // patch_size) ** 2

# ---- Transformer encoder ---------------------------------------------------
projection_dim = 64
num_heads = 8
# Hidden widths of the per-block MLP; the last width must equal projection_dim
# so the residual Add() in each encoder block has matching shapes.
transformer_units = [2 * projection_dim, projection_dim]
transformer_layers = 12

# ---- Classification head / regularization ----------------------------------
mlp_head_units = [2048, 1024]
dropout_rate = 0.1
|
|
|
|
|
# Preprocessing + augmentation stack applied as the first stage of the model.
# The Random* layers only perturb inputs during training; Keras runs them as
# pass-throughs at inference.
data_augmentation = keras.Sequential(
    [
        # NOTE(review): no `adapt()` call is visible in this file section;
        # presumably it is adapted on training data elsewhere -- TODO confirm,
        # otherwise it does not normalize with dataset statistics.
        layers.Normalization(),
        layers.Resizing(height=image_size, width=image_size),
        layers.RandomFlip(mode="horizontal"),
        layers.RandomRotation(factor=0.1),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
|
|
|
|
|
class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Input is a 4-D image batch (batch, height, width, channels), as required by
    ``tf.image.extract_patches``. Output has shape
    (batch, num_patches, patch_size * patch_size * channels).

    Args:
        patch_size: side length, in pixels, of each square patch.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        # sizes == strides with VALID padding -> tiles do not overlap; border
        # pixels that don't fill a whole patch are dropped.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Last axis holds one flattened patch; collapse the spatial grid of
        # patches into a single sequence axis.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config
|
|
|
|
|
class PatchEncoder(layers.Layer):
    """Embed patch vectors, prepend a learnable class token, and add positions.

    Expects a rank-3 input (batch, num_patches, features) -- presumably the
    output of ``Patches`` -- and returns (batch, num_patches + 1, projection_dim),
    where index 0 along axis 1 is the class token.

    Args:
        num_patches: number of patch tokens per example (excluding the class token).
        projection_dim: width of each token embedding.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments: `call` must use these, not the
        # module-level globals of the same name.
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(projection_dim)
        self.cls_token = self.add_weight(
            name="cls_token",
            shape=(1, 1, projection_dim),
            initializer="zeros",
            trainable=True,
        )
        # One extra position slot for the prepended class token.
        self.position_embedding = layers.Embedding(
            input_dim=num_patches + 1, output_dim=projection_dim
        )

    def call(self, patch):
        batch_size = tf.shape(patch)[0]
        # Replicate the single learned class token across the batch.
        cls_token = tf.broadcast_to(
            self.cls_token, [batch_size, 1, self.projection_dim]
        )
        patch = self.projection(patch)
        patch = tf.concat([cls_token, patch], axis=1)
        positions = tf.range(start=0, limit=self.num_patches + 1, delta=1)
        return patch + self.position_embedding(positions)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config
|
|
|
|
|
def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of Dense(GELU) -> Dropout -> LayerNormalization stages.

    One stage is created per entry in ``hidden_units``; each entry is that
    stage's Dense width.

    Args:
        x: input Keras tensor.
        hidden_units: iterable of Dense layer widths, applied in order.
        dropout_rate: dropout probability used after every Dense layer.

    Returns:
        The Keras tensor produced by the last stage (``x`` unchanged if
        ``hidden_units`` is empty).
    """
    # NOTE(review): normalizes after every hidden layer; confirm this (rather
    # than a single norm after the whole stack) is the intended design.
    for width in hidden_units:
        stage = (
            layers.Dense(width, activation=tf.nn.gelu),
            layers.Dropout(dropout_rate),
            layers.LayerNormalization(),
        )
        for layer in stage:
            x = layer(x)
    return x
|
|
|
|
|
def _transformer_block(tokens):
    """Apply one pre-norm encoder block: self-attention and MLP, each with a residual Add."""
    normed = layers.LayerNormalization()(tokens)
    attention = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate
    )
    attended = attention(normed, normed)
    after_attention = layers.Add()([attended, tokens])

    mlp_input = layers.LayerNormalization()(after_attention)
    mlp_output = mlp(mlp_input, transformer_units, dropout_rate)
    return layers.Add()([mlp_output, after_attention])


def create_vit_classifier():
    """Build the (uncompiled) ViT classification model from the module hyperparameters.

    Pipeline: augmentation -> patch extraction -> patch/class-token encoding ->
    ``transformer_layers`` encoder blocks -> LayerNorm -> class-token slice ->
    MLP head -> softmax over ``num_classes``.

    Returns:
        A ``keras.Model`` mapping images of ``input_shape`` to class probabilities.
    """
    inputs = layers.Input(shape=input_shape)

    tokens = data_augmentation(inputs)
    tokens = Patches(patch_size)(tokens)
    tokens = PatchEncoder(num_patches, projection_dim)(tokens)

    for _ in range(transformer_layers):
        tokens = _transformer_block(tokens)

    tokens = layers.LayerNormalization()(tokens)
    # Position 0 is the class token prepended by PatchEncoder.
    class_embedding = tokens[:, 0, :]
    head = mlp(class_embedding, mlp_head_units, dropout_rate)
    probabilities = layers.Dense(num_classes, activation="softmax")(head)

    return keras.Model(inputs, probabilities)
|
|
|
|
|
def cosine_decay_schedule(initial_lr, total_steps):
    """Create a Keras cosine-decay learning-rate schedule.

    Args:
        initial_lr: learning rate at step 0.
        total_steps: number of steps over which the rate is decayed
            (passed through as ``decay_steps``).

    Returns:
        A ``keras.optimizers.schedules.CosineDecay`` instance.
    """
    schedule_cls = keras.optimizers.schedules.CosineDecay
    return schedule_cls(decay_steps=total_steps, initial_learning_rate=initial_lr)
|
|
|
|
|
# ---- Optimization ----------------------------------------------------------
learning_rate = 3e-4
weight_decay = 0.03
num_epochs = 50
# Optimizer steps per epoch; together with num_epochs this sets the length of
# the cosine-decay horizon.
# NOTE(review): hard-coded; presumably it should equal
# ceil(num_train_examples / batch_size) for the real dataset -- TODO confirm.
steps_per_epoch = 100

optimizer = keras.optimizers.AdamW(
    learning_rate=cosine_decay_schedule(learning_rate, num_epochs * steps_per_epoch),
    weight_decay=weight_decay,
)
|
|
|
|
|
# Build and compile the model. The head ends in a softmax, so the loss works
# on probabilities; "sparse_" presumably means labels are integer class ids.
vit_classifier = create_vit_classifier()
vit_classifier.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=optimizer,
    metrics=["accuracy"],
)
vit_classifier.summary()
|
|