Sobit committed (verified)
Commit 045c942 · 1 parent: f16aaf6

Create vit.py

Files changed (1)
  1. models/vit.py +115 -0
models/vit.py ADDED
@@ -0,0 +1,115 @@
+ import tensorflow as tf
+ from tensorflow import keras
+ from tensorflow.keras import layers
+
+ # Model hyperparameters
+ num_classes = 3
+ input_shape = (256, 256, 3)
+ image_size = 256
+ patch_size = 16
+ num_patches = (image_size // patch_size) ** 2  # (256 / 16)^2 = 256 patches
+ projection_dim = 64
+ num_heads = 8 # Increased from 4 → 8
+ transformer_units = [projection_dim * 2, projection_dim]
+ transformer_layers = 12 # Increased from 8 → 12
+ mlp_head_units = [2048, 1024]
+ dropout_rate = 0.1
+
+ # Data Augmentation with stronger transformations.
+ # Note: the Normalization layer must be adapted to the training data
+ # (e.g. data_augmentation.layers[0].adapt(x_train)) before training.
+ data_augmentation = keras.Sequential(
+     [
+         layers.Normalization(),
+         layers.Resizing(image_size, image_size),
+         layers.RandomFlip("horizontal"),
+         layers.RandomRotation(factor=0.1),
+         layers.RandomZoom(0.2, 0.2),
+     ],
+     name="data_augmentation",
+ )
+
+ # Patch Creation Layer: splits each image into non-overlapping patches
+ class Patches(layers.Layer):
+     def __init__(self, patch_size):
+         super().__init__()
+         self.patch_size = patch_size
+
+     def call(self, images):
+         batch_size = tf.shape(images)[0]
+         patches = tf.image.extract_patches(
+             images=images,
+             sizes=[1, self.patch_size, self.patch_size, 1],
+             strides=[1, self.patch_size, self.patch_size, 1],
+             rates=[1, 1, 1, 1],
+             padding="VALID",
+         )
+         patch_dims = patches.shape[-1]
+         return tf.reshape(patches, [batch_size, -1, patch_dims])
+
+ # Patch Encoding with Learnable [CLS] Token and positional embeddings
+ class PatchEncoder(layers.Layer):
+     def __init__(self, num_patches, projection_dim):
+         super().__init__()
+         self.num_patches = num_patches
+         self.projection_dim = projection_dim
+         self.projection = layers.Dense(projection_dim)
+         self.cls_token = self.add_weight(shape=(1, 1, projection_dim), initializer="zeros", trainable=True)
+         self.position_embedding = layers.Embedding(input_dim=num_patches + 1, output_dim=projection_dim)
+
+     def call(self, patch):
+         batch_size = tf.shape(patch)[0]
+         cls_token = tf.broadcast_to(self.cls_token, [batch_size, 1, self.projection_dim])
+         patch = self.projection(patch)
+         patch = tf.concat([cls_token, patch], axis=1) # Prepend CLS token
+         positions = tf.range(start=0, limit=self.num_patches + 1, delta=1)
+         return patch + self.position_embedding(positions)
+
+ # MLP Block
+ def mlp(x, hidden_units, dropout_rate):
+     for units in hidden_units:
+         x = layers.Dense(units, activation=tf.nn.gelu)(x)
+         x = layers.Dropout(dropout_rate)(x)
+         x = layers.LayerNormalization()(x) # Added LayerNorm
+     return x
+
+ # Vision Transformer Model
+ def create_vit_classifier():
+     inputs = layers.Input(shape=input_shape)
+     augmented = data_augmentation(inputs)
+     patches = Patches(patch_size)(augmented)
+     encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+
+     # Transformer encoder blocks: pre-LayerNorm attention and MLP, each with a residual connection
+     for _ in range(transformer_layers):
+         x1 = layers.LayerNormalization()(encoded_patches)
+         attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate)(x1, x1)
+         x2 = layers.Add()([attention_output, encoded_patches])
+         x3 = layers.LayerNormalization()(x2)
+         x3 = mlp(x3, transformer_units, dropout_rate)
+         encoded_patches = layers.Add()([x3, x2])
+
+     representation = layers.LayerNormalization()(encoded_patches)
+     cls_output = representation[:, 0, :] # Extract CLS token representation
+     features = mlp(cls_output, mlp_head_units, dropout_rate)
+     outputs = layers.Dense(num_classes, activation="softmax")(features)
+
+     return keras.Model(inputs, outputs)
+
+ # Cosine Decay Learning Rate Scheduler
+ def cosine_decay_schedule(initial_lr, total_steps):
+     return keras.optimizers.schedules.CosineDecay(initial_learning_rate=initial_lr, decay_steps=total_steps)
+
+ # Optimizer with Weight Decay & Cosine Decay (decay_steps assumes roughly 100 training steps per epoch)
+ learning_rate = 3e-4
+ weight_decay = 0.03
+ num_epochs = 50
+
+ optimizer = keras.optimizers.AdamW(learning_rate=cosine_decay_schedule(learning_rate, num_epochs * 100), weight_decay=weight_decay)
+
+ # Compile Model (sparse_categorical_crossentropy expects integer class labels)
+ vit_classifier = create_vit_classifier()
+ vit_classifier.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
+
+ # Print Model Summary
+ vit_classifier.summary()
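
For reference, a minimal training sketch under stated assumptions (not part of the commit): it assumes x_train/y_train and x_val/y_val already exist as arrays of 256x256 RGB images with integer labels in {0, 1, 2}, and the batch size of 32 is arbitrary since the commit does not specify one. It mainly shows that the Normalization layer inside data_augmentation must be adapted to the training images before fitting.

# Training sketch; x_train, y_train, x_val, y_val and batch_size=32 are assumptions, not part of vit.py
data_augmentation.layers[0].adapt(x_train)  # compute per-channel mean/variance for the Normalization layer

history = vit_classifier.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    batch_size=32,
    epochs=num_epochs,
)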