ErnestBeckham committed
Commit 323c575 · verified · 1 Parent(s): 914808d

Files changed (1):
  1. vit.py +15 -17
vit.py CHANGED

@@ -10,7 +10,7 @@ class ClassToken(layers.Layer):
         #initial values for the weight
         w_init = tf.random_normal_initializer()
         self.w = tf.Variable(
-            initial_value = w_init(shape=(1, 1, input_shape[-1]), dtype=tf.float32),
+            initial_value = w_init(shape=(1, 1, input_shape[-1]), dtype=tf.float32),
             trainable = True
         )
 
@@ -22,7 +22,8 @@ class ClassToken(layers.Layer):
         cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
         #change data type
         cls = tf.cast(cls, dtype=inputs.dtype)
-        return cls
+        return cls
+
 
 def mlp(x, cf):
     x = layers.Dense(cf['mlp_dim'], activation='gelu')(x)
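
For orientation, the two hunks above touch ClassToken.build and ClassToken.call. A minimal self-contained sketch of how the layer reads after this commit; the class scaffolding and the batch_size/hidden_dim lines are assumptions, since they fall outside the hunks:

    import tensorflow as tf
    from tensorflow.keras import layers

    class ClassToken(layers.Layer):
        # prepends a learnable classification token, as in ViT
        def build(self, input_shape):
            w_init = tf.random_normal_initializer()
            self.w = tf.Variable(
                initial_value=w_init(shape=(1, 1, input_shape[-1]), dtype=tf.float32),
                trainable=True,
            )

        def call(self, inputs):
            # assumed: batch size and token width recovered at call time
            # (these lines sit outside the hunks shown)
            batch_size = tf.shape(inputs)[0]
            hidden_dim = self.w.shape[-1]
            cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
            cls = tf.cast(cls, dtype=inputs.dtype)
            return cls

    # quick shape check: (2, 256, 256) -> (2, 1, 256)
    print(ClassToken()(tf.zeros([2, 256, 256])).shape)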
@@ -31,20 +32,20 @@ def mlp(x, cf):
     x = layers.Dropout(cf['dropout_rate'])(x)
     return x
 
-
 def transformer_encoder(x, cf):
     skip_1 = x
     x = layers.LayerNormalization()(x)
     x = layers.MultiHeadAttention(num_heads=cf['num_heads'], key_dim=cf['hidden_dim'])(x,x)
     x = layers.Add()([x, skip_1])
-
+
     skip_2 = x
     x = layers.LayerNormalization()(x)
     x = mlp(x, cf)
     x = layers.Add()([x, skip_2])
-
+
     return x
 
+
 def resnet_block(x, filters, strides=1):
     identity = x
 
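transformer_encoder is the standard pre-norm transformer block: LayerNorm before self-attention and before the MLP, with a residual add after each. A quick shape check, assuming the file's definitions are in scope and that mlp projects back to hidden_dim (that Dense sits outside the hunks); the config values are illustrative, not from the repo:

    import tensorflow as tf

    cf = {'num_heads': 4, 'hidden_dim': 256, 'mlp_dim': 512, 'dropout_rate': 0.1}
    x = tf.zeros([2, 257, 256])       # (batch, tokens, hidden_dim)
    y = transformer_encoder(x, cf)    # residual blocks preserve the shape
    print(y.shape)                    # (2, 257, 256)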
@@ -63,13 +64,14 @@ def resnet_block(x, filters, strides=1):
     x = layers.Activation('relu')(x)
     return x
 
+
 def build_resnet(input_shape):
 
     x = layers.Conv2D(32, kernel_size=7, strides=2, padding='same')(input_shape)
     x = layers.BatchNormalization()(x)
     x = layers.Activation('relu')(x)
     x = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(x)
-
+
     x = resnet_block(x, filters=32)
     x = resnet_block(x, filters=32)
 
@@ -78,13 +80,10 @@ def build_resnet(input_shape):
 
     x = resnet_block(x, filters=128, strides=2)
     x = resnet_block(x, filters=128)
-
+
     x = resnet_block(x, filters=256, strides=2)
     x = resnet_block(x, filters=256)
-
-    x = resnet_block(x, filters=512, strides=2)
-    x = resnet_block(x, filters=512)
-
+
     return x
 
 
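The one substantive change in this commit is here: the two filters=512 blocks are removed, so build_resnet now tops out at 256 channels with one less stride-2 stage. That lines up with the (None, 257, 256) comment in CNN_ViT below; a hedged back-of-envelope, assuming a 512×512 input and a filters=64, strides=2 stage between the 32 and 128 stages (both outside these hunks):

    # stem conv, maxpool, then the per-stage strides after this commit
    strides = [2, 2, 1, 2, 2, 2]
    total = 1
    for s in strides:
        total *= s
    print(total)         # 32x overall downsampling
    print(512 // total)  # 16 -> a 16x16 map = 256 patch tokens of width 256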
@@ -108,21 +107,20 @@ def CNN_ViT(hp):
     print(f"position embeding : {pos_embed.shape}")
     #Patch + Position Embedding
     embed = patch_embed + pos_embed
-
+
     #Token
     token = ClassToken()(embed)
     x = layers.Concatenate(axis=1)([token, embed]) #(None, 257, 256)
-
+
     #Transformer encoder
     for _ in range(hp['num_layers']):
         x = transformer_encoder(x, hp)
-
-
+
+
     x = layers.LayerNormalization()(x)
     x = x[:, 0, :]
     x = layers.Dense(hp['num_classes'], activation='softmax')(x)
-
+
     model = Model(inputs, x)
-
     return model
 
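Taken together, CNN_ViT adds position embeddings to the CNN patch embeddings, prepends the class token (256 tokens become 257), runs the encoder stack, and classifies from the class-token slot. A hedged usage sketch: the dict keys are the ones the diff actually reads, the values are illustrative, and the parts of CNN_ViT outside these hunks (inputs, patch_embed, pos_embed) are assumed to follow from them:

    # illustrative hyperparameters only; requires the definitions in vit.py
    hp = {
        'num_layers': 8,
        'num_heads': 4,
        'hidden_dim': 256,
        'mlp_dim': 512,
        'dropout_rate': 0.1,
        'num_classes': 2,
    }
    model = CNN_ViT(hp)
    model.summary()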