# NOTE: removed extraction artifacts (file-size banner, git-blame hashes,
# and line-number gutter) that were not part of the Python source.
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
class ClassToken(layers.Layer):
    """Learnable [CLS] token, broadcast across the batch and prepended by the caller.

    Produces a tensor of shape (batch, 1, hidden_dim) where hidden_dim is the
    last dimension of the layer's input.
    """

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        # Use add_weight (instead of a raw tf.Variable) so the weight gets a
        # stable name and is tracked correctly for checkpointing/serialization.
        # random_normal_initializer matches the original initial values.
        self.w = self.add_weight(
            name="cls_token",
            shape=(1, 1, input_shape[-1]),
            initializer=tf.random_normal_initializer(),
            dtype=tf.float32,
            trainable=True,
        )

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        hidden_dim = self.w.shape[-1]
        # Tile the single learned token over the batch: (batch, 1, hidden_dim).
        cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
        # Match the input dtype (e.g. under mixed precision).
        cls = tf.cast(cls, dtype=inputs.dtype)
        return cls
def mlp(x, cf):
    """Transformer feed-forward sub-block.

    Expands to cf['mlp_dim'] with GELU, projects back to cf['hidden_dim'],
    applying dropout (cf['dropout_rate']) after each Dense layer.
    """
    hidden = layers.Dense(cf['mlp_dim'], activation='gelu')(x)
    hidden = layers.Dropout(cf['dropout_rate'])(hidden)
    out = layers.Dense(cf['hidden_dim'])(hidden)
    out = layers.Dropout(cf['dropout_rate'])(out)
    return out
def transformer_encoder(x, cf):
    """Pre-norm transformer encoder block.

    LayerNorm -> multi-head self-attention -> residual add, then
    LayerNorm -> MLP -> residual add. Uses cf['num_heads'] and
    cf['hidden_dim'] for the attention layer.
    """
    # Self-attention sub-block with residual connection.
    attn_in = layers.LayerNormalization()(x)
    attn_out = layers.MultiHeadAttention(
        num_heads=cf['num_heads'], key_dim=cf['hidden_dim']
    )(attn_in, attn_in)
    x = layers.Add()([attn_out, x])

    # Feed-forward sub-block with residual connection.
    ffn_in = layers.LayerNormalization()(x)
    ffn_out = mlp(ffn_in, cf)
    return layers.Add()([ffn_out, x])
def resnet_block(x, filters, strides=1):
    """Residual block: two 5x5 convs (Conv-BN-ReLU, Conv-BN) plus shortcut.

    Args:
        x: input feature map, (batch, h, w, c).
        filters: output channel count of both convolutions.
        strides: stride of the first conv (downsamples when > 1).

    Returns:
        Feature map of shape (batch, h/strides, w/strides, filters).
    """
    identity = x
    x = layers.Conv2D(filters, kernel_size=5, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(filters, kernel_size=5, strides=1, padding='same')(x)
    x = layers.BatchNormalization()(x)
    # Project the shortcut whenever the main path changes spatial size OR
    # channel count. The original only checked `strides > 1`, so a caller
    # changing `filters` at stride 1 would crash the Add() on a channel
    # mismatch. All existing call sites are unaffected.
    if strides > 1 or identity.shape[-1] != filters:
        identity = layers.Conv2D(filters, kernel_size=1, strides=strides, padding='same')(identity)
        identity = layers.BatchNormalization()(identity)
    x = layers.Add()([x, identity])
    x = layers.Activation('relu')(x)
    return x
def build_resnet(input_shape):
    """ResNet-style convolutional backbone.

    NOTE(review): despite the name, `input_shape` is a Keras tensor (the
    model input), not a shape tuple — kept as-is to avoid breaking
    keyword-argument callers; confirm before renaming.

    Stem: 7x7 stride-2 conv + BN/ReLU + 3x3 stride-2 max-pool, followed by
    four stages of two residual blocks each (32 -> 64 -> 128 -> 256 channels,
    downsampling at the start of stages 2-4).
    """
    x = layers.Conv2D(32, kernel_size=7, strides=2, padding='same')(input_shape)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(x)

    # (filters, strides) per block — identical sequence to the original
    # hand-unrolled calls.
    stage_plan = [
        (32, 1), (32, 1),
        (64, 2), (64, 1),
        (128, 2), (128, 1),
        (256, 2), (256, 1),
    ]
    for filters, strides in stage_plan:
        x = resnet_block(x, filters=filters, strides=strides)
    return x
def CNN_ViT(hp):
    """Build a hybrid CNN + Vision Transformer classifier.

    Pipeline: ResNet backbone -> Conv2D patch embedding to hidden_dim ->
    flatten patches -> add learned position embeddings -> prepend [CLS]
    token -> num_layers transformer encoders -> softmax head on the token.

    Args:
        hp: dict of hyperparameters. Keys used: image_size, num_channels,
            hidden_dim, patch_size, num_patches, num_heads, mlp_dim,
            dropout_rate, num_layers, num_classes.
            hp['num_patches'] must equal the backbone's output h*w
            (padding='same', stride 1 keeps the spatial size).

    Returns:
        A compiled-ready keras Model mapping images to class probabilities.
    """
    input_shape = (hp['image_size'], hp['image_size'], hp['num_channels'])
    inputs = layers.Input(input_shape)

    # CNN feature extractor.
    features = build_resnet(inputs)

    # Patch embedding: project backbone features to hidden_dim.
    # BUG FIX: the original reshaped the raw backbone output instead of this
    # projection, leaving the Conv2D disconnected from the graph and making
    # the token width the backbone's channel count rather than hidden_dim
    # (it only accidentally worked when the two coincided).
    patch_embed = layers.Conv2D(
        hp['hidden_dim'], kernel_size=hp['patch_size'], padding='same'
    )(features)
    _, h, w, _ = patch_embed.shape
    patch_embed = layers.Reshape((h * w, hp['hidden_dim']))(patch_embed)

    # Learned position embedding, one vector per patch index.
    positions = tf.range(start=0, limit=hp['num_patches'], delta=1)
    pos_embed = layers.Embedding(
        input_dim=hp['num_patches'], output_dim=hp['hidden_dim']
    )(positions)

    # Patch + position embedding (pos_embed broadcasts over the batch).
    embed = patch_embed + pos_embed

    # Prepend the learnable [CLS] token: (None, num_patches + 1, hidden_dim).
    token = ClassToken()(embed)
    x = layers.Concatenate(axis=1)([token, embed])

    # Transformer encoder stack.
    for _ in range(hp['num_layers']):
        x = transformer_encoder(x, hp)

    x = layers.LayerNormalization()(x)
    x = x[:, 0, :]  # classify from the [CLS] token only
    x = layers.Dense(hp['num_classes'], activation='softmax')(x)

    return Model(inputs, x)