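"""SimMIM-style masked image modeling (MIM) pre-training on SwinTransformerV2.

The encoder replaces the embeddings of masked patches with a learnable mask
token; a lightweight PixelShuffle decoder reconstructs the input image, and
the loss is the L1 reconstruction error averaged over masked pixels only.
"""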
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from ..swinv2_model import SwinTransformerV2


class SwinTransformerV2ForSimMIM(SwinTransformerV2):
    """SwinV2 encoder adapted for SimMIM: masked patch embeddings are
    replaced with a learnable mask token before the transformer stages."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        assert self.num_classes == 0

        # Learnable token substituted for the embeddings of masked patches.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.mask_token, mean=0., std=.02)

    def forward(self, x, mask):
        x = self.patch_embed(x)

        assert mask is not None
        B, L, _ = x.shape

        # Swap masked patch embeddings for the mask token (w == 1 where masked).
        mask_tokens = self.mask_token.expand(B, L, -1)
        w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
        x = x * (1. - w) + mask_tokens * w

        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)

        # Reshape the token sequence (B, L, C) into a 2-D feature map
        # (B, C, H, W) for the convolutional decoder.
        x = x.transpose(1, 2)
        B, C, L = x.shape
        H = W = int(L ** 0.5)
        x = x.reshape(B, C, H, W)
        return x

    @torch.jit.ignore
    def no_weight_decay(self):
        # Exclude the mask token from weight decay, like norm/bias parameters.
        return super().no_weight_decay() | {'mask_token'}


class MiMModel(nn.Module):
    """SimMIM pre-training wrapper: an encoder plus a one-layer PixelShuffle
    decoder, trained with an L1 loss computed on masked pixels only."""

    def __init__(self, encoder, encoder_stride, in_chans, patch_size):
        super().__init__()
        self.encoder = encoder
        self.encoder_stride = encoder_stride
        self.in_chans = in_chans
        self.patch_size = patch_size

        # Project encoder features to (stride^2 * in_chans) channels, then
        # PixelShuffle rearranges them back to the input resolution.
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=self.encoder.num_features,
                out_channels=self.encoder_stride ** 2 * self.in_chans,
                kernel_size=1),
            nn.PixelShuffle(self.encoder_stride),
        )

    def forward(self, x, mask):
        z = self.encoder(x, mask)
        x_rec = self.decoder(z)

        # Upsample the patch-level mask (B, H/p, W/p) to pixel resolution
        # (B, 1, H, W) so it can gate the per-pixel reconstruction loss.
        mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(
            self.patch_size, 2).unsqueeze(1).contiguous()

        # Mean L1 error over masked pixels, averaged across channels.
        loss_recon = F.l1_loss(x, x_rec, reduction='none')
        loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
        return loss

    @torch.jit.ignore
    def no_weight_decay(self):
        if hasattr(self.encoder, 'no_weight_decay'):
            return {'encoder.' + i for i in self.encoder.no_weight_decay()}
        return {}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        if hasattr(self.encoder, 'no_weight_decay_keywords'):
            return {'encoder.' + i for i in
                    self.encoder.no_weight_decay_keywords()}
        return {}


def build_mim_model(config):
    """Build the MIM pre-training model described by `config`."""
    model_type = config.MODEL.TYPE
    if model_type == 'swinv2':
        encoder = SwinTransformerV2ForSimMIM(
            img_size=config.DATA.IMG_SIZE,
            patch_size=config.MODEL.SWINV2.PATCH_SIZE,
            in_chans=config.MODEL.SWINV2.IN_CHANS,
            num_classes=0,
            embed_dim=config.MODEL.SWINV2.EMBED_DIM,
            depths=config.MODEL.SWINV2.DEPTHS,
            num_heads=config.MODEL.SWINV2.NUM_HEADS,
            window_size=config.MODEL.SWINV2.WINDOW_SIZE,
            mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
            qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
            drop_rate=config.MODEL.DROP_RATE,
            drop_path_rate=config.MODEL.DROP_PATH_RATE,
            ape=config.MODEL.SWINV2.APE,
            patch_norm=config.MODEL.SWINV2.PATCH_NORM,
            use_checkpoint=config.TRAIN.USE_CHECKPOINT)
        # Total SwinV2 downsampling: patch embed (x4) plus three patch-merging
        # stages (x2 each) = 32.
        encoder_stride = 32
        in_chans = config.MODEL.SWINV2.IN_CHANS
        patch_size = config.MODEL.SWINV2.PATCH_SIZE
    else:
        raise NotImplementedError(f"Unknown pre-train model: {model_type}")

    model = MiMModel(encoder=encoder, encoder_stride=encoder_stride,
                     in_chans=in_chans, patch_size=patch_size)
    return model
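

# --- Usage sketch (not part of the original module) ---
# A minimal smoke test, assuming the encoder accepts the keyword arguments
# used in build_mim_model above. The hyperparameters here (192x192 input,
# patch size 4, window size 6, Swin-T depths/heads) are illustrative
# placeholders, not values from any particular config.
if __name__ == '__main__':
    encoder = SwinTransformerV2ForSimMIM(
        img_size=192, patch_size=4, in_chans=3, num_classes=0,
        embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
        window_size=6)
    model = MiMModel(encoder=encoder, encoder_stride=32,
                     in_chans=3, patch_size=4)

    x = torch.randn(2, 3, 192, 192)
    # Binary patch-level mask (1 = masked), shape (B, H/patch, W/patch).
    mask = (torch.rand(2, 192 // 4, 192 // 4) > 0.5).float()

    loss = model(x, mask)  # scalar masked-L1 reconstruction loss
    print(loss.item())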