ErnestBeckham commited on
Commit
d80c331
·
verified ·
1 Parent(s): 0b75f11

model created

Browse files
Files changed (1) hide show
  1. vit.py +127 -0
vit.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import layers
3
+
4
+ class ClassToken(layers.Layer):
5
+ def __init__(self):
6
+ super().__init__()
7
+
8
+ def build(self, input_shape):
9
+ #initial values for the weight
10
+ w_init = tf.random_normal_initializer()
11
+ self.w = tf.Variable(
12
+ initial_value = w_init(shape=(1, 1, input_shape[-1]), dtype=tf.float32),
13
+ trainable = True
14
+ )
15
+
16
+ def call(self, inputs):
17
+ batch_size = tf.shape(inputs)[0]
18
+ hidden_dim = self.w.shape[-1]
19
+
20
+ #reshape
21
+ cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
22
+ #change data type
23
+ cls = tf.cast(cls, dtype=inputs.dtype)
24
+ return cls
25
+
26
+ def mlp(x, cf):
27
+ x = layers.Dense(cf['mlp_dim'], activation='gelu')(x)
28
+ x = layers.Dropout(cf['dropout_rate'])(x)
29
+ x = layers.Dense(cf['hidden_dim'])(x)
30
+ x = layers.Dropout(cf['dropout_rate'])(x)
31
+ return x
32
+
33
+
34
+ def transformer_encoder(x, cf):
35
+ skip_1 = x
36
+ x = layers.LayerNormalization()(x)
37
+ x = layers.MultiHeadAttention(num_heads=cf['num_heads'], key_dim=cf['hidden_dim'])(x,x)
38
+ x = layers.Add()([x, skip_1])
39
+
40
+ skip_2 = x
41
+ x = layers.LayerNormalization()(x)
42
+ x = mlp(x, cf)
43
+ x = layers.Add()([x, skip_2])
44
+
45
+ return x
46
+
47
+ def resnet_block(x, filters, strides=1):
48
+ identity = x
49
+
50
+ x = layers.Conv2D(filters, kernel_size=5, strides=strides, padding='same')(x)
51
+ x = layers.BatchNormalization()(x)
52
+ x = layers.Activation('relu')(x)
53
+
54
+ x = layers.Conv2D(filters, kernel_size=5, strides=1, padding='same')(x)
55
+ x = layers.BatchNormalization()(x)
56
+
57
+ if strides > 1:
58
+ identity = layers.Conv2D(filters, kernel_size=1, strides=strides, padding='same')(identity)
59
+ identity = layers.BatchNormalization()(identity)
60
+
61
+ x = layers.Add()([x, identity])
62
+ x = layers.Activation('relu')(x)
63
+ return x
64
+
65
+ def build_resnet(input_shape):
66
+
67
+ x = layers.Conv2D(32, kernel_size=7, strides=2, padding='same')(input_shape)
68
+ x = layers.BatchNormalization()(x)
69
+ x = layers.Activation('relu')(x)
70
+ x = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(x)
71
+
72
+ x = resnet_block(x, filters=32)
73
+ x = resnet_block(x, filters=32)
74
+
75
+ x = resnet_block(x, filters=64, strides=2)
76
+ x = resnet_block(x, filters=64)
77
+
78
+ x = resnet_block(x, filters=128, strides=2)
79
+ x = resnet_block(x, filters=128)
80
+
81
+ x = resnet_block(x, filters=256, strides=2)
82
+ x = resnet_block(x, filters=256)
83
+
84
+ x = resnet_block(x, filters=512, strides=2)
85
+ x = resnet_block(x, filters=512)
86
+
87
+ return x
88
+
89
+
90
+ def CNN_ViT(hp):
91
+ input_shape = (hp['image_size'], hp['image_size'], hp['num_channels'])
92
+ inputs = layers.Input(input_shape)
93
+ print(inputs.shape)
94
+ output = build_resnet(inputs)
95
+ print(output.shape)
96
+
97
+ patch_embed = layers.Conv2D(hp['hidden_dim'], kernel_size=(hp['patch_size']), padding='same')(output)
98
+ print(patch_embed.shape)
99
+ _, h, w, f = output.shape
100
+ patch_embed = layers.Reshape((h*w,f))(output)
101
+
102
+ #Position Embedding
103
+ positions = tf.range(start=0, limit=hp['num_patches'], delta=1)
104
+ pos_embed = layers.Embedding(input_dim=hp['num_patches'], output_dim=hp['hidden_dim'])(positions)
105
+
106
+ print(f"patch embedding : {patch_embed.shape}")
107
+ print(f"position embeding : {pos_embed.shape}")
108
+ #Patch + Position Embedding
109
+ embed = patch_embed + pos_embed
110
+
111
+ #Token
112
+ token = ClassToken()(embed)
113
+ x = layers.Concatenate(axis=1)([token, embed]) #(None, 257, 256)
114
+
115
+ #Transformer encoder
116
+ for _ in range(hp['num_layers']):
117
+ x = transformer_encoder(x, hp)
118
+
119
+
120
+ x = layers.LayerNormalization()(x)
121
+ x = x[:, 0, :]
122
+ x = layers.Dense(hp['num_classes'], activation='softmax')(x)
123
+
124
+ model = Model(inputs, x)
125
+
126
+ return model
127
+