import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


# Model and training hyperparameters.
num_classes = 3
input_shape = (256, 256, 3)
image_size = 256
patch_size = 16  # size of the square patches extracted from each image
num_patches = (image_size // patch_size) ** 2  # 256 patches per image
projection_dim = 64
num_heads = 8
transformer_units = [projection_dim * 2, projection_dim]  # MLP widths inside each transformer block
transformer_layers = 12
mlp_head_units = [2048, 1024]  # widths of the final classification head
dropout_rate = 0.1


# Preprocessing and augmentation pipeline. Note that the Normalization layer
# must be adapted to the training images before training, e.g.
# data_augmentation.layers[0].adapt(x_train).
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.1),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)


class Patches(layers.Layer):
    """Splits a batch of images into flattened, non-overlapping patches."""

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        # Flatten the patch grid into a sequence: (batch, num_patches, patch_dims).
        return tf.reshape(patches, [batch_size, -1, patch_dims])
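
# A quick sanity check of the patch shapes (illustrative sketch, assuming the
# TF2 default of eager execution): each 256x256x3 image splits into
# (256 // 16) ** 2 = 256 patches, each flattened to 16 * 16 * 3 = 768 values.
_demo_patches = Patches(patch_size)(tf.zeros([1, image_size, image_size, 3]))
print(_demo_patches.shape)  # (1, 256, 768)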


class PatchEncoder(layers.Layer):
    """Projects patches, prepends a learnable CLS token, and adds position embeddings."""

    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(projection_dim)
        # Learnable classification token, prepended to the patch sequence.
        self.cls_token = self.add_weight(
            name="cls_token", shape=(1, 1, projection_dim), initializer="zeros", trainable=True
        )
        # num_patches + 1 positions to account for the CLS token.
        self.position_embedding = layers.Embedding(input_dim=num_patches + 1, output_dim=projection_dim)

    def call(self, patch):
        batch_size = tf.shape(patch)[0]
        cls_token = tf.broadcast_to(self.cls_token, [batch_size, 1, self.projection_dim])
        patch = self.projection(patch)
        patch = tf.concat([cls_token, patch], axis=1)
        positions = tf.range(start=0, limit=self.num_patches + 1, delta=1)
        return patch + self.position_embedding(positions)
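
# Illustrative follow-up check: prepending the CLS token gives num_patches + 1
# tokens, each projected to projection_dim.
_demo_encoded = PatchEncoder(num_patches, projection_dim)(_demo_patches)
print(_demo_encoded.shape)  # (1, 257, 64)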


def mlp(x, hidden_units, dropout_rate):
    """Feed-forward stack: Dense(GELU) -> Dropout -> LayerNormalization per hidden width."""
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
        x = layers.LayerNormalization()(x)
    return x


def create_vit_classifier():
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Stack of pre-norm transformer encoder blocks with residual connections.
    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization()(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])
        x3 = layers.LayerNormalization()(x2)
        x3 = mlp(x3, transformer_units, dropout_rate)
        encoded_patches = layers.Add()([x3, x2])

    representation = layers.LayerNormalization()(encoded_patches)
    # Classify from the CLS token (position 0) rather than pooling all tokens.
    cls_output = representation[:, 0, :]
    features = mlp(cls_output, mlp_head_units, dropout_rate)
    outputs = layers.Dense(num_classes, activation="softmax")(features)

    return keras.Model(inputs, outputs)


def cosine_decay_schedule(initial_lr, total_steps):
    """Cosine decay from initial_lr down to zero over total_steps optimizer steps."""
    return keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=initial_lr, decay_steps=total_steps
    )


learning_rate = 3e-4
weight_decay = 0.03
num_epochs = 50

# The decay horizon below assumes roughly 100 optimizer steps per epoch; see
# the sketch just below for deriving it from the actual dataset size.
optimizer = keras.optimizers.AdamW(
    learning_rate=cosine_decay_schedule(learning_rate, num_epochs * 100),
    weight_decay=weight_decay,
)
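
# A sketch of sizing the decay horizon from data instead of the hard-coded
# factor of 100 (num_train_examples and batch_size are hypothetical names,
# not defined in this script):
#   steps_per_epoch = num_train_examples // batch_size
#   optimizer = keras.optimizers.AdamW(
#       learning_rate=cosine_decay_schedule(learning_rate, steps_per_epoch * num_epochs),
#       weight_decay=weight_decay,
#   )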


vit_classifier = create_vit_classifier()
# Labels are integer class indices, hence the sparse categorical loss.
vit_classifier.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

vit_classifier.summary()
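
# A minimal training sketch (illustrative; x_train and y_train are hypothetical
# arrays of images and integer labels, not defined in this script):
#   data_augmentation.layers[0].adapt(x_train)  # fit the Normalization statistics
#   vit_classifier.fit(
#       x_train,
#       y_train,
#       batch_size=32,
#       epochs=num_epochs,
#       validation_split=0.1,
#   )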