Rename vit.py to vit_model.py
- vit.py: +0 -106
- vit_model.py: +53 -0
vit.py (DELETED)
@@ -1,106 +0,0 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Model hyperparameters
num_classes = 3
input_shape = (256, 256, 3)
image_size = 256
patch_size = 16
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 8  # Increased from 4 → 8
transformer_units = [projection_dim * 2, projection_dim]
transformer_layers = 12  # Increased from 8 → 12
mlp_head_units = [2048, 1024]
dropout_rate = 0.1

# Data Augmentation with stronger transformations
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.1),
        layers.RandomZoom(0.2, 0.2),
    ],
    name="data_augmentation",
)

# Patch Creation Layer
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

# Patch Encoding with Learnable [CLS] Token
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.projection = layers.Dense(projection_dim)
        self.cls_token = self.add_weight(shape=(1, 1, projection_dim), initializer="zeros", trainable=True)
        self.position_embedding = layers.Embedding(input_dim=num_patches + 1, output_dim=projection_dim)

    def call(self, patch):
        batch_size = tf.shape(patch)[0]
        cls_token = tf.broadcast_to(self.cls_token, [batch_size, 1, projection_dim])
        patch = self.projection(patch)
        patch = tf.concat([cls_token, patch], axis=1)  # Add CLS token
        positions = tf.range(start=0, limit=num_patches + 1, delta=1)
        return patch + self.position_embedding(positions)

# MLP Block
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
        x = layers.LayerNormalization()(x)  # Added LayerNorm
    return x

# Vision Transformer Model
def create_vit_classifier():
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization()(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])
        x3 = layers.LayerNormalization()(x2)
        x3 = mlp(x3, transformer_units, dropout_rate)
        encoded_patches = layers.Add()([x3, x2])

    representation = layers.LayerNormalization()(encoded_patches)
    cls_output = representation[:, 0, :]  # Extract CLS token representation
    features = mlp(cls_output, mlp_head_units, dropout_rate)
    outputs = layers.Dense(num_classes, activation="softmax")(features)

    return keras.Model(inputs, outputs)

# Cosine Decay Learning Rate Scheduler
def cosine_decay_schedule(initial_lr, total_steps):
    return keras.optimizers.schedules.CosineDecay(initial_learning_rate=initial_lr, decay_steps=total_steps)

# Optimizer with Weight Decay & Cosine Decay
learning_rate = 3e-4
weight_decay = 0.03
num_epochs = 50

optimizer = keras.optimizers.AdamW(
    learning_rate=cosine_decay_schedule(learning_rate, num_epochs * 100), weight_decay=weight_decay
)

# Compile Model
vit_classifier = create_vit_classifier()
vit_model.py (ADDED)
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=128, patch_size=8, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))

    def forward(self, x):
        B = x.shape[0]
        x = self.proj(x).flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, embed_dim]
        x = torch.cat([cls_tokens, x], dim=1)  # Add CLS token
        x = x + self.pos_embedding
        return x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        h = self.norm1(x)  # Pre-LN: normalize before attention
        x = x + self.attn(h, h, h)[0]
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=128, patch_size=8, num_classes=10, embed_dim=768, depth=8, num_heads=12, mlp_dim=2048, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels=3, embed_dim=embed_dim)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, mlp_dim, dropout) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.transformer(x)
        x = self.norm(x[:, 0])  # CLS token output
        return self.head(x)
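A short usage sketch may help verify shapes for the new module; none of it is part of vit_model.py. With the defaults above, a 128×128 RGB image yields 16×16 = 256 patches plus one CLS token. The batch size, optimizer, scheduler, and loss below are illustrative assumptions: the AdamW and cosine-decay values simply mirror the hyperparameters the removed Keras file used, and CrossEntropyLoss is assumed because the head returns raw logits rather than softmax probabilities.

# Usage sketch (not part of the commit); values below are illustrative assumptions.
model = VisionTransformer()                     # defaults: 128x128 input, 8x8 patches, 10 classes
dummy = torch.randn(4, 3, 128, 128)             # batch of 4 RGB images
logits = model(dummy)
print(logits.shape)                             # torch.Size([4, 10])

criterion = nn.CrossEntropyLoss()               # head outputs raw logits, no softmax
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.03)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5000)  # mirrors 50 epochs * 100 steps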