Delete models/vit.py
Browse files- models/vit.py +0 -110
models/vit.py
DELETED
@@ -1,110 +0,0 @@
|
|
1 |
-
import tensorflow as tf
|
2 |
-
from tensorflow import keras
|
3 |
-
from tensorflow.keras import layers
|
4 |
-
|
5 |
-
# Model hyperparameters
|
6 |
-
num_classes = 3
|
7 |
-
input_shape = (256, 256, 3)
|
8 |
-
image_size = 256
|
9 |
-
patch_size = 16
|
10 |
-
num_patches = (image_size // patch_size) ** 2
|
11 |
-
projection_dim = 64
|
12 |
-
num_heads = 8 # Increased from 4 → 8
|
13 |
-
transformer_units = [projection_dim * 2, projection_dim]
|
14 |
-
transformer_layers = 12 # Increased from 8 → 12
|
15 |
-
mlp_head_units = [2048, 1024]
|
16 |
-
dropout_rate = 0.1
|
17 |
-
|
18 |
-
# Data Augmentation with stronger transformations
|
19 |
-
data_augmentation = keras.Sequential(
|
20 |
-
[
|
21 |
-
layers.Normalization(),
|
22 |
-
layers.Resizing(image_size, image_size),
|
23 |
-
layers.RandomFlip("horizontal"),
|
24 |
-
layers.RandomRotation(factor=0.1),
|
25 |
-
layers.RandomZoom(0.2, 0.2),
|
26 |
-
],
|
27 |
-
name="data_augmentation",
|
28 |
-
)
|
29 |
-
|
30 |
-
# Patch Creation Layer
|
31 |
-
class Patches(layers.Layer):
|
32 |
-
def __init__(self, patch_size):
|
33 |
-
super().__init__()
|
34 |
-
self.patch_size = patch_size
|
35 |
-
|
36 |
-
def call(self, images):
|
37 |
-
batch_size = tf.shape(images)[0]
|
38 |
-
patches = tf.image.extract_patches(
|
39 |
-
images=images,
|
40 |
-
sizes=[1, self.patch_size, self.patch_size, 1],
|
41 |
-
strides=[1, self.patch_size, self.patch_size, 1],
|
42 |
-
rates=[1, 1, 1, 1],
|
43 |
-
padding="VALID",
|
44 |
-
)
|
45 |
-
patch_dims = patches.shape[-1]
|
46 |
-
return tf.reshape(patches, [batch_size, -1, patch_dims])
|
47 |
-
|
48 |
-
# Patch Encoding with Learnable [CLS] Token
|
49 |
-
class PatchEncoder(layers.Layer):
|
50 |
-
def __init__(self, num_patches, projection_dim):
|
51 |
-
super().__init__()
|
52 |
-
self.projection = layers.Dense(projection_dim)
|
53 |
-
self.cls_token = self.add_weight(shape=(1, 1, projection_dim), initializer="zeros", trainable=True)
|
54 |
-
self.position_embedding = layers.Embedding(input_dim=num_patches + 1, output_dim=projection_dim)
|
55 |
-
|
56 |
-
def call(self, patch):
|
57 |
-
batch_size = tf.shape(patch)[0]
|
58 |
-
cls_token = tf.broadcast_to(self.cls_token, [batch_size, 1, projection_dim])
|
59 |
-
patch = self.projection(patch)
|
60 |
-
patch = tf.concat([cls_token, patch], axis=1) # Add CLS token
|
61 |
-
positions = tf.range(start=0, limit=num_patches + 1, delta=1)
|
62 |
-
return patch + self.position_embedding(positions)
|
63 |
-
|
64 |
-
# MLP Block
|
65 |
-
def mlp(x, hidden_units, dropout_rate):
|
66 |
-
for units in hidden_units:
|
67 |
-
x = layers.Dense(units, activation=tf.nn.gelu)(x)
|
68 |
-
x = layers.Dropout(dropout_rate)(x)
|
69 |
-
x = layers.LayerNormalization()(x) # Added LayerNorm
|
70 |
-
return x
|
71 |
-
|
72 |
-
# Vision Transformer Model
|
73 |
-
def create_vit_classifier():
|
74 |
-
inputs = layers.Input(shape=input_shape)
|
75 |
-
augmented = data_augmentation(inputs)
|
76 |
-
patches = Patches(patch_size)(augmented)
|
77 |
-
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
|
78 |
-
|
79 |
-
for _ in range(transformer_layers):
|
80 |
-
x1 = layers.LayerNormalization()(encoded_patches)
|
81 |
-
attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate)(x1, x1)
|
82 |
-
x2 = layers.Add()([attention_output, encoded_patches])
|
83 |
-
x3 = layers.LayerNormalization()(x2)
|
84 |
-
x3 = mlp(x3, transformer_units, dropout_rate)
|
85 |
-
encoded_patches = layers.Add()([x3, x2])
|
86 |
-
|
87 |
-
representation = layers.LayerNormalization()(encoded_patches)
|
88 |
-
cls_output = representation[:, 0, :] # Extract CLS token representation
|
89 |
-
features = mlp(cls_output, mlp_head_units, dropout_rate)
|
90 |
-
outputs = layers.Dense(num_classes, activation="softmax")(features)
|
91 |
-
|
92 |
-
return keras.Model(inputs, outputs)
|
93 |
-
|
94 |
-
# Cosine Decay Learning Rate Scheduler
|
95 |
-
def cosine_decay_schedule(initial_lr, total_steps):
|
96 |
-
return keras.optimizers.schedules.CosineDecay(initial_learning_rate=initial_lr, decay_steps=total_steps)
|
97 |
-
|
98 |
-
# Optimizer with Weight Decay & Cosine Decay
|
99 |
-
learning_rate = 3e-4
|
100 |
-
weight_decay = 0.03
|
101 |
-
num_epochs = 50
|
102 |
-
|
103 |
-
optimizer = keras.optimizers.AdamW(learning_rate=cosine_decay_schedule(learning_rate, num_epochs * 100), weight_decay=weight_decay)
|
104 |
-
|
105 |
-
# Compile Model
|
106 |
-
vit_classifier = create_vit_classifier()
|
107 |
-
vit_classifier.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
|
108 |
-
|
109 |
-
# Print Model Summary
|
110 |
-
vit_classifier.summary()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|