"""A minimal Vision Transformer (ViT) for image classification in PyTorch."""

import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
    """Splits an image into non-overlapping patches, linearly projects each
    patch to embed_dim, prepends a learnable [CLS] token, and adds learnable
    positional embeddings."""

    def __init__(self, img_size=128, patch_size=8, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # A conv with kernel_size == stride == patch_size patchifies and projects in one step.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))

    def forward(self, x):
        B = x.shape[0]
        # (B, C, H, W) -> (B, embed_dim, H/P, W/P) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        # Out-of-place add; the original in-place `x +=` works here but is fragile under autograd.
        x = x + self.pos_embedding
        return x

class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block: LayerNorm -> multi-head
    self-attention -> residual add, then LayerNorm -> MLP -> residual add."""

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Normalize once and reuse for query, key, and value instead of
        # recomputing self.norm1(x) three times.
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    """ViT-style classifier: patch embedding, a stack of Transformer encoder
    blocks, and a linear head applied to the final [CLS] token."""

    def __init__(self, img_size=128, patch_size=8, num_classes=10, embed_dim=768,
                 depth=8, num_heads=12, mlp_dim=2048, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels=3, embed_dim=embed_dim)
        self.transformer = nn.Sequential(
            *[TransformerBlock(embed_dim, num_heads, mlp_dim, dropout) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.transformer(x)
        # Classify from the [CLS] token (sequence position 0) after a final LayerNorm.
        x = self.norm(x[:, 0])
        return self.head(x)
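
# A minimal smoke test, not part of the original module: builds the model with
# its default hyperparameters and checks output shapes on a random batch. The
# batch size and inputs below are illustrative assumptions.
if __name__ == "__main__":
    model = VisionTransformer()  # defaults: 128x128 input, 8x8 patches, 10 classes
    images = torch.randn(2, 3, 128, 128)  # batch of 2 random RGB images

    tokens = model.patch_embed(images)
    print(tokens.shape)  # torch.Size([2, 257, 768]): 256 patches + 1 [CLS] token

    logits = model(images)
    print(logits.shape)  # torch.Size([2, 10]): one logit per class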