Sobit committed
Commit 5a9b64e · verified · 1 parent: 3be8c20

Rename vit.py to vit_model.py

Files changed (2)
  1. vit.py +0 -106
  2. vit_model.py +53 -0
vit.py DELETED
@@ -1,106 +0,0 @@
- import tensorflow as tf
- from tensorflow import keras
- from tensorflow.keras import layers
-
- # Model hyperparameters
- num_classes = 3
- input_shape = (256, 256, 3)
- image_size = 256
- patch_size = 16
- num_patches = (image_size // patch_size) ** 2
- projection_dim = 64
- num_heads = 8  # Increased from 4 → 8
- transformer_units = [projection_dim * 2, projection_dim]
- transformer_layers = 12  # Increased from 8 → 12
- mlp_head_units = [2048, 1024]
- dropout_rate = 0.1
-
- # Data Augmentation with stronger transformations
- data_augmentation = keras.Sequential(
-     [
-         layers.Normalization(),
-         layers.Resizing(image_size, image_size),
-         layers.RandomFlip("horizontal"),
-         layers.RandomRotation(factor=0.1),
-         layers.RandomZoom(0.2, 0.2),
-     ],
-     name="data_augmentation",
- )
-
- # Patch Creation Layer
- class Patches(layers.Layer):
-     def __init__(self, patch_size):
-         super().__init__()
-         self.patch_size = patch_size
-
-     def call(self, images):
-         batch_size = tf.shape(images)[0]
-         patches = tf.image.extract_patches(
-             images=images,
-             sizes=[1, self.patch_size, self.patch_size, 1],
-             strides=[1, self.patch_size, self.patch_size, 1],
-             rates=[1, 1, 1, 1],
-             padding="VALID",
-         )
-         patch_dims = patches.shape[-1]
-         return tf.reshape(patches, [batch_size, -1, patch_dims])
-
- # Patch Encoding with Learnable [CLS] Token
- class PatchEncoder(layers.Layer):
-     def __init__(self, num_patches, projection_dim):
-         super().__init__()
-         self.projection = layers.Dense(projection_dim)
-         self.cls_token = self.add_weight(shape=(1, 1, projection_dim), initializer="zeros", trainable=True)
-         self.position_embedding = layers.Embedding(input_dim=num_patches + 1, output_dim=projection_dim)
-
-     def call(self, patch):
-         batch_size = tf.shape(patch)[0]
-         cls_token = tf.broadcast_to(self.cls_token, [batch_size, 1, projection_dim])
-         patch = self.projection(patch)
-         patch = tf.concat([cls_token, patch], axis=1)  # Add CLS token
-         positions = tf.range(start=0, limit=num_patches + 1, delta=1)
-         return patch + self.position_embedding(positions)
-
- # MLP Block
- def mlp(x, hidden_units, dropout_rate):
-     for units in hidden_units:
-         x = layers.Dense(units, activation=tf.nn.gelu)(x)
-         x = layers.Dropout(dropout_rate)(x)
-         x = layers.LayerNormalization()(x)  # Added LayerNorm
-     return x
-
- # Vision Transformer Model
- def create_vit_classifier():
-     inputs = layers.Input(shape=input_shape)
-     augmented = data_augmentation(inputs)
-     patches = Patches(patch_size)(augmented)
-     encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
-
-     for _ in range(transformer_layers):
-         x1 = layers.LayerNormalization()(encoded_patches)
-         attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate)(x1, x1)
-         x2 = layers.Add()([attention_output, encoded_patches])
-         x3 = layers.LayerNormalization()(x2)
-         x3 = mlp(x3, transformer_units, dropout_rate)
-         encoded_patches = layers.Add()([x3, x2])
-
-     representation = layers.LayerNormalization()(encoded_patches)
-     cls_output = representation[:, 0, :]  # Extract CLS token representation
-     features = mlp(cls_output, mlp_head_units, dropout_rate)
-     outputs = layers.Dense(num_classes, activation="softmax")(features)
-
-     return keras.Model(inputs, outputs)
-
- # Cosine Decay Learning Rate Scheduler
- def cosine_decay_schedule(initial_lr, total_steps):
-     return keras.optimizers.schedules.CosineDecay(initial_learning_rate=initial_lr, decay_steps=total_steps)
-
- # Optimizer with Weight Decay & Cosine Decay
- learning_rate = 3e-4
- weight_decay = 0.03
- num_epochs = 50
-
- optimizer = keras.optimizers.AdamW(learning_rate=cosine_decay_schedule(learning_rate, num_epochs * 100), weight_decay=weight_decay)
-
- # Compile Model
- vit_classifier = create_vit_classifier()
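
Note: despite the closing "# Compile Model" comment, the deleted script only instantiates the model; it never calls compile() or fit(), and its decay_steps of num_epochs * 100 implies roughly 100 training steps per epoch. A minimal sketch of the missing step is below; the loss, metric, and the train_ds/val_ds datasets are assumptions, not part of the original file.

# Hypothetical completion of the deleted vit.py: train_ds and val_ds are
# assumed tf.data.Dataset objects yielding (image, label) batches.
vit_classifier.compile(
    optimizer=optimizer,  # AdamW with the cosine-decay schedule defined above
    loss="sparse_categorical_crossentropy",  # assumes integer labels for num_classes=3
    metrics=["accuracy"],
)
history = vit_classifier.fit(train_ds, validation_data=val_ds, epochs=num_epochs)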
vit_model.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class PatchEmbedding(nn.Module):
+     def __init__(self, img_size=128, patch_size=8, in_channels=3, embed_dim=768):
+         super().__init__()
+         self.patch_size = patch_size
+         self.num_patches = (img_size // patch_size) ** 2
+         self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+         self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
+         self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
+
+     def forward(self, x):
+         B = x.shape[0]
+         x = self.proj(x).flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
+         cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, embed_dim]
+         x = torch.cat([cls_tokens, x], dim=1)  # Add CLS token
+         x += self.pos_embedding
+         return x
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(embed_dim)
+         self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
+         self.norm2 = nn.LayerNorm(embed_dim)
+         self.mlp = nn.Sequential(
+             nn.Linear(embed_dim, mlp_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(mlp_dim, embed_dim),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]  # Pre-LN
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+ class VisionTransformer(nn.Module):
+     def __init__(self, img_size=128, patch_size=8, num_classes=10, embed_dim=768, depth=8, num_heads=12, mlp_dim=2048, dropout=0.1):
+         super().__init__()
+         self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels=3, embed_dim=embed_dim)
+         self.transformer = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, mlp_dim, dropout) for _ in range(depth)])
+         self.norm = nn.LayerNorm(embed_dim)
+         self.head = nn.Linear(embed_dim, num_classes)
+
+     def forward(self, x):
+         x = self.patch_embed(x)
+         x = self.transformer(x)
+         x = self.norm(x[:, 0])  # CLS token output
+         return self.head(x)
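
For reference, a minimal smoke test for the new module (illustrative only, not part of the commit; the batch size and class count below are arbitrary):

# Sanity check (not in the commit): run a dummy batch through the model.
model = VisionTransformer(img_size=128, patch_size=8, num_classes=10)
x = torch.randn(2, 3, 128, 128)  # [batch, channels, height, width]
logits = model(x)
print(logits.shape)  # expected: torch.Size([2, 10])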