Sobit committed on
Commit fc729ee · verified · 1 Parent(s): 9a2c205

Delete models/vit.py

Files changed (1)
  1. models/vit.py +0 -110
models/vit.py DELETED
@@ -1,110 +0,0 @@
- import tensorflow as tf
- from tensorflow import keras
- from tensorflow.keras import layers
-
- # Model hyperparameters
- num_classes = 3
- input_shape = (256, 256, 3)
- image_size = 256
- patch_size = 16
- num_patches = (image_size // patch_size) ** 2
- projection_dim = 64
- num_heads = 8  # Increased from 4 → 8
- transformer_units = [projection_dim * 2, projection_dim]
- transformer_layers = 12  # Increased from 8 → 12
- mlp_head_units = [2048, 1024]
- dropout_rate = 0.1
-
- # Data Augmentation with stronger transformations
- data_augmentation = keras.Sequential(
-     [
-         layers.Normalization(),  # NOTE: adapt() this layer on the training images before fitting
-         layers.Resizing(image_size, image_size),
-         layers.RandomFlip("horizontal"),
-         layers.RandomRotation(factor=0.1),
-         layers.RandomZoom(0.2, 0.2),
-     ],
-     name="data_augmentation",
- )
-
- # Patch Creation Layer
- class Patches(layers.Layer):
-     def __init__(self, patch_size):
-         super().__init__()
-         self.patch_size = patch_size
-
-     def call(self, images):
-         batch_size = tf.shape(images)[0]
-         patches = tf.image.extract_patches(
-             images=images,
-             sizes=[1, self.patch_size, self.patch_size, 1],
-             strides=[1, self.patch_size, self.patch_size, 1],
-             rates=[1, 1, 1, 1],
-             padding="VALID",
-         )
-         patch_dims = patches.shape[-1]
-         return tf.reshape(patches, [batch_size, -1, patch_dims])
-
- # Patch Encoding with Learnable [CLS] Token
- class PatchEncoder(layers.Layer):
-     def __init__(self, num_patches, projection_dim):
-         super().__init__()
-         self.projection = layers.Dense(projection_dim)
-         self.cls_token = self.add_weight(shape=(1, 1, projection_dim), initializer="zeros", trainable=True)
-         self.position_embedding = layers.Embedding(input_dim=num_patches + 1, output_dim=projection_dim)
-
-     def call(self, patch):
-         batch_size = tf.shape(patch)[0]
-         cls_token = tf.broadcast_to(self.cls_token, [batch_size, 1, self.cls_token.shape[-1]])
-         patch = self.projection(patch)
-         patch = tf.concat([cls_token, patch], axis=1)  # Prepend the CLS token
-         positions = tf.range(start=0, limit=self.position_embedding.input_dim, delta=1)
-         return patch + self.position_embedding(positions)
-
- # MLP Block
- def mlp(x, hidden_units, dropout_rate):
-     for units in hidden_units:
-         x = layers.Dense(units, activation=tf.nn.gelu)(x)
-         x = layers.Dropout(dropout_rate)(x)
-         x = layers.LayerNormalization()(x)  # Added LayerNorm
-     return x
-
- # Vision Transformer Model
- def create_vit_classifier():
-     inputs = layers.Input(shape=input_shape)
-     augmented = data_augmentation(inputs)
-     patches = Patches(patch_size)(augmented)
-     encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
-
-     for _ in range(transformer_layers):
-         x1 = layers.LayerNormalization()(encoded_patches)
-         attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate)(x1, x1)
-         x2 = layers.Add()([attention_output, encoded_patches])
-         x3 = layers.LayerNormalization()(x2)
-         x3 = mlp(x3, transformer_units, dropout_rate)
-         encoded_patches = layers.Add()([x3, x2])
-
-     representation = layers.LayerNormalization()(encoded_patches)
-     cls_output = representation[:, 0, :]  # Extract CLS token representation
-     features = mlp(cls_output, mlp_head_units, dropout_rate)
-     outputs = layers.Dense(num_classes, activation="softmax")(features)
-
-     return keras.Model(inputs, outputs)
-
- # Cosine Decay Learning Rate Scheduler
- def cosine_decay_schedule(initial_lr, total_steps):
-     return keras.optimizers.schedules.CosineDecay(initial_learning_rate=initial_lr, decay_steps=total_steps)
-
- # Optimizer with Weight Decay & Cosine Decay
- learning_rate = 3e-4
- weight_decay = 0.03
- num_epochs = 50
-
- optimizer = keras.optimizers.AdamW(learning_rate=cosine_decay_schedule(learning_rate, num_epochs * 100), weight_decay=weight_decay)  # assumes ~100 steps per epoch; match to the real dataset size
-
- # Compile Model
- vit_classifier = create_vit_classifier()
- vit_classifier.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
-
- # Print Model Summary
- vit_classifier.summary()
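
For reference, a minimal sketch of how the deleted module was presumably driven. Everything below is an assumption rather than part of the file: the `x_train`/`y_train` arrays are stand-in data, and the snippet must run in the same session as the code above, since the file itself never called `adapt()` or `fit()`. Note the tokenization arithmetic: with image_size = 256 and patch_size = 16, Patches emits (256 // 16)**2 = 256 tokens of 16 * 16 * 3 = 768 values each, which PatchEncoder projects to 64 dimensions and prepends with one CLS token.

import numpy as np

# Stand-in data (assumption): float32 images in [0, 1] and integer class labels,
# matching the sparse_categorical_crossentropy loss compiled above.
x_train = np.random.rand(64, 256, 256, 3).astype("float32")
y_train = np.random.randint(0, num_classes, size=(64,))

# The Normalization layer only computes a meaningful mean/variance after adapt().
data_augmentation.layers[0].adapt(x_train)

history = vit_classifier.fit(
    x_train,
    y_train,
    batch_size=8,
    epochs=num_epochs,
    validation_split=0.1,
)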