# biomass/models/backbone.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import logging
import math
import random
import warnings
from collections import OrderedDict
from functools import partial
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple

# DropPath, trunc_normal_, named_apply and the ViT weight-init helpers used
# below are taken from timm, which this file otherwise follows closely
# (assumption: a timm version exposing these symbols is installed).
from timm.models.layers import DropPath, trunc_normal_
from timm.models.helpers import named_apply
from timm.models.vision_transformer import (
    _load_weights,
    get_init_weights_vit,
    init_weights_vit_timm,
)

logger = logging.getLogger(__name__)

# NOTE: get_2d_sincos_pos_embed_cached_device (used when sin_cos_embeddings is
# enabled) is referenced but not defined in this file; it is assumed to be
# provided by a companion module of the original codebase.
class MaskingGenerator:
def __init__(
self,
input_size,
num_masking_patches=None,
min_num_patches=4,
max_num_patches=None,
min_aspect=0.3,
max_aspect=None,
):
if not isinstance(input_size, tuple):
input_size = (input_size,) * 2
self.height, self.width = input_size
self.num_patches = self.height * self.width
self.num_masking_patches = num_masking_patches
self.min_num_patches = min_num_patches
self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
def __repr__(self):
repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
self.height,
self.width,
self.min_num_patches,
self.max_num_patches,
self.num_masking_patches,
self.log_aspect_ratio[0],
self.log_aspect_ratio[1],
)
return repr_str
def get_shape(self):
return self.height, self.width
def _mask(self, mask, max_mask_patches):
delta = 0
for attempt in range(10):
target_area = random.uniform(self.min_num_patches, max_mask_patches)
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < self.width and h < self.height:
top = random.randint(0, self.height - h)
left = random.randint(0, self.width - w)
num_masked = mask[top : top + h, left : left + w].sum()
# Overlap
if 0 < h * w - num_masked <= max_mask_patches:
for i in range(top, top + h):
for j in range(left, left + w):
if mask[i, j] == 0:
mask[i, j] = 1
delta += 1
if delta > 0:
break
return delta
def __call__(self, num_masking_patches=0):
        mask = np.zeros(shape=self.get_shape(), dtype=bool)
mask_count = 0
while mask_count < num_masking_patches:
max_mask_patches = num_masking_patches - mask_count
max_mask_patches = min(max_mask_patches, self.max_num_patches)
delta = self._mask(mask, max_mask_patches)
if delta == 0:
break
else:
mask_count += delta
return mask
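# Usage sketch (illustrative, not from the original file): block-wise masking of
# a 14x14 patch grid, covering at most half of the 196 patches.
#
#   gen = MaskingGenerator(input_size=14)
#   mask = gen(num_masking_patches=98)   # np.ndarray of shape (14, 14), dtype=bool
#   assert mask.sum() <= 98              # _mask() never overshoots the budget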
def resize(input,
size=None,
scale_factor=None,
mode='nearest',
align_corners=None,
warning=False):
if warning:
if size is not None and align_corners:
input_h, input_w = tuple(int(x) for x in input.shape[2:])
output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would be more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
return F.interpolate(input, size, scale_factor, mode, align_corners)
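# Example (sketch): upsampling a feature map with the warning check enabled.
#
#   feats = torch.randn(2, 256, 32, 32)
#   up = resize(feats, size=(128, 128), mode='bilinear', align_corners=False)
#   # up.shape == (2, 256, 128, 128)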
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
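# The constructor of DinoVisionTransformer below accepts ffn_layer="swiglu" but
# SwiGLUFFN is not defined in this file. The definition below is a sketch that
# follows the DINOv2 reference implementation (dinov2/layers/swiglu_ffn.py),
# added here so that option is usable; the fused-projection layout (w12, w3) is
# an assumption carried over from that codebase.
class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,
        drop: float = 0.0,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # single projection producing both the gate and the value branches
        self.w12 = nn.Linear(in_features, 2 * hidden_features)
        self.w3 = nn.Linear(hidden_features, out_features)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        # SwiGLU: silu-gated linear unit
        return self.w3(F.silu(x1) * x2)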
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: Tensor) -> Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
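# Shape walk-through (sketch): with dim=768 and num_heads=12, a (B, N, 768)
# token sequence produces a (B, 12, N, N) attention map and a (B, N, 768) output.
#
#   attn = Attention(dim=768, num_heads=12, qkv_bias=True)
#   tokens = torch.randn(2, 197, 768)
#   out = attn(tokens)                   # out.shape == (2, 197, 768)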
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: Union[float, Tensor] = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
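# Block.forward below calls drop_add_residual_stochastic_depth, which is not
# defined in this file. The version below follows the DINOv2 reference
# implementation (dinov2/layers/block.py): instead of dropping the residual
# branch per sample (DropPath), it runs the branch on a random subset of the
# batch and rescales the result, which matches DropPath in expectation while
# skipping compute for the dropped samples.
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) select a random subset of the batch to run the residual branch on
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]
    # 2) apply the residual branch to the subset only
    residual = residual_func(x_subset)
    x_flat = x.flatten(1)
    residual = residual.flatten(1)
    # 3) scatter-add the rescaled residual back at the sampled indices
    residual_scale_factor = b / sample_subset_size
    x_plus_residual = torch.index_add(
        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
    )
    return x_plus_residual.view_as(x)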
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: Tensor) -> Tensor:
def attn_residual_func(x: Tensor) -> Tensor:
return self.ls1(self.attn(self.norm1(x)))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
x = drop_add_residual_stochastic_depth(
x,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
x = drop_add_residual_stochastic_depth(
x,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x))
x = x + self.drop_path1(ffn_residual_func(x))
else:
x = x + attn_residual_func(x)
x = x + ffn_residual_func(x)
return x
def make_2tuple(x):
if isinstance(x, tuple):
        assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
Args:
img_size: Image size.
patch_size: Patch token size.
in_chans: Number of input image channels.
embed_dim: Number of linear projection output channels.
norm_layer: Normalization layer.
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (
image_HW[0] // patch_HW[0],
image_HW[1] // patch_HW[1],
)
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x
def flops(self) -> float:
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
return flops
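# Example (sketch): with the defaults (224x224 images, 16x16 patches, 768-dim
# embeddings) a batch of images becomes a 196-token sequence.
#
#   embed = PatchEmbed()
#   imgs = torch.randn(2, 3, 224, 224)
#   tokens = embed(imgs)                 # tokens.shape == (2, 196, 768)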
class DinoVisionTransformer(nn.Module):
"""Vision Transformer
    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=0,
global_pool="token",
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4.0,
qkv_bias=True,
representation_size=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
weight_init="",
init_values=1.,
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
block_fn=Block,
ffn_layer="mlp",
drop_path_uniform=False,
patch_drop=0.0,
sin_cos_embeddings=False,
local_crops_size=96,
multiple_pos_embeddings=False,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'token')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            act_layer (nn.Module): MLP activation layer
"""
super().__init__()
assert global_pool in ("", "avg", "token")
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1
self.grad_checkpointing = False
self.sin_cos_embeddings = sin_cos_embeddings
self.multiple_pos_embeddings = multiple_pos_embeddings
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if self.sin_cos_embeddings:
            logger.info("using sin-cos fixed embeddings")
            # placeholder; the actual embedding is computed on the fly in
            # interpolate_pos_encoding()
            self.pos_embed = torch.Tensor(())
elif self.multiple_pos_embeddings:
logger.info("using multiple position embeddings (one for global one for local)")
self.pos_embeds = nn.ParameterDict()
self.pos_embeds[str(img_size)] = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
n_local_patches = (local_crops_size // patch_size) ** 2
self.pos_embeds[str(local_crops_size)] = nn.Parameter(torch.zeros(1, n_local_patches, embed_dim))
self.pos_embed = None
else:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
if drop_path_uniform is True:
dpr = [drop_path_rate] * depth
else:
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
if ffn_layer == "mlp":
#print("using MLP layer as FFN")
ffn_layer = Mlp
elif ffn_layer == "swiglu":
#print("using SwiGLU layer as FFN")
ffn_layer = SwiGLUFFN
elif ffn_layer == "identity":
#print("using Identity layer as FFN")
def f(*args, **kwargs):
return nn.Identity()
ffn_layer = f
        else:
            raise NotImplementedError(f"unknown ffn_layer: {ffn_layer}")
self.blocks = nn.ModuleList(
[
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
ffn_layer=ffn_layer,
init_values=init_values,
)
for i in range(depth)
]
)
use_fc_norm = self.global_pool == "avg"
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
# Representation layer. Used for original ViT models w/ in21k pretraining.
self.representation_size = representation_size
self.pre_logits = nn.Identity()
if representation_size:
self._reset_representation(representation_size)
# Classifier Head
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
final_chs = self.representation_size if self.representation_size else self.embed_dim
self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
self.mask_generator = MaskingGenerator(
input_size=(img_size // patch_size, img_size // patch_size),
            max_num_patches=0.5 * (img_size // patch_size) * (img_size // patch_size),
)
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
# if weight_init != "skip":
# self.init_weights(weight_init)
def _reset_representation(self, representation_size):
self.representation_size = representation_size
if self.representation_size:
self.pre_logits = nn.Sequential(
OrderedDict([("fc", nn.Linear(self.embed_dim, self.representation_size)), ("act", nn.Tanh())])
)
else:
self.pre_logits = nn.Identity()
def init_weights(self, mode=""):
assert mode in ("jax", "jax_nlhb", "moco", "")
head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=0.02)
elif self.pos_embeds:
for v in self.pos_embeds.values():
trunc_normal_(v, std=0.02)
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(get_init_weights_vit(mode, head_bias), self)
def _init_weights(self, m):
# this fn left here for compat with downstream users
init_weights_vit_timm(m)
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path, prefix=""):
_load_weights(self, checkpoint_path, prefix)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token", "dist_token"}
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ("", "avg", "token")
self.global_pool = global_pool
if representation_size is not None:
self._reset_representation(representation_size)
final_chs = self.representation_size if self.representation_size else self.embed_dim
self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == "avg" else x[:, 0]
x = self.fc_norm(x)
x = self.pre_logits(x)
return x if pre_logits else self.head(x)
def interpolate_pos_encoding(self, x, w, h):
if self.sin_cos_embeddings:
w0 = w // self.patch_embed.patch_size[0]
step_coef = (w0-1) / 3.14
omega_coef = 10000
sin_cos_embed = get_2d_sincos_pos_embed_cached_device(
embed_dim=x.shape[-1], grid_size=w0, step_coef=step_coef, omega_coef=omega_coef, device=x.device, cls_token=True
)
return sin_cos_embed
elif self.multiple_pos_embeddings:
            # touch every pos_embed so DDP sees all parameters used in the graph
            _m = sum(v.mean() * 0 for v in self.pos_embeds.values())
pos_embed = self.pos_embeds[str(w)] + _m
class_pos_embed = torch.zeros_like(pos_embed[:1,:1])
return torch.cat((class_pos_embed, pos_embed), dim=1)
else:
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
class_pos_embed = self.pos_embed[:, 0]
patch_pos_embed = self.pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_embed.patch_size[0]
h0 = h // self.patch_embed.patch_size[0]
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode="bicubic", align_corners=True, recompute_scale_factor=True
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
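    # Sketch of the learned-embedding branch above: for a 336x336 input with
    # 16x16 patches (441 patch tokens + CLS), the stored 14x14 grid of
    # positional embeddings is bicubically resized to 21x21 before being added
    # to the patch tokens.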
def mask_patches_with_probability_p(self, x, mask_ratio_tuple, p):
B, N, _ = x.shape
n_samples_masked = int(B * p)
mask_ratio_min, mask_ratio_max = mask_ratio_tuple
masks = torch.stack(
[
torch.BoolTensor(self.mask_generator(int(N * random.uniform(mask_ratio_min, mask_ratio_max))))
for _ in range(0, n_samples_masked)
]
+ [torch.BoolTensor(self.mask_generator(0)) for _ in range(n_samples_masked, B)]
).to(
x.device
)
masks = masks[torch.randperm(B, device=x.device)].flatten(1)
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
return x, masks
def mask_patches_with_probability_p_upperbound(self, x, mask_ratio_tuple, p):
B, N, _ = x.shape
n_samples_masked = int(B * p)
probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
upperbound = 0
masks_list = []
for i in range(0, n_samples_masked):
prob_min = probs[i]
prob_max = probs[i+1]
masks_list.append(torch.BoolTensor(self.mask_generator(int(N * random.uniform(prob_min, prob_max)))))
upperbound += int(N * prob_max)
for i in range(n_samples_masked, B):
masks_list.append(torch.BoolTensor(self.mask_generator(0)))
masks = torch.stack(masks_list).to(x.device)
masks = masks[torch.randperm(B, device=x.device)].flatten(1)
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
return x, masks, upperbound
def prepare_tokens(self, x, mask_ratio_tuple=(0.0, 0.0), mask_sample_probability=0.0, ibot_balanced_masking=False):
B, nc, w, h = x.shape
x = self.patch_embed(x)
masks = None
n_masked_patches_upperbound = None
cls_token = self.cls_token
do_ibot = max(mask_ratio_tuple) > 0.0 and mask_sample_probability > 0.0
if do_ibot:
if ibot_balanced_masking:
logger.debug("using balanced masking")
x, masks, n_masked_patches_upperbound = self.mask_patches_with_probability_p_upperbound(
x, mask_ratio_tuple=mask_ratio_tuple, p=mask_sample_probability
)
else:
logger.debug("not using balanced masking")
x, masks = self.mask_patches_with_probability_p(
x, mask_ratio_tuple=mask_ratio_tuple, p=mask_sample_probability
)
else:
cls_token = cls_token + 0 * self.mask_token # hack to use the mask_token param to not crash ddp...
x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.interpolate_pos_encoding(x, w, h))
return x, masks, n_masked_patches_upperbound
def forward_features(self, x, mask_ratio_tuple=(0.0, 0.0), mask_sample_probability=0.0, ibot_balanced_masking=False):
x, masks, n_masked_patches_upperbound = self.prepare_tokens(x, mask_ratio_tuple, mask_sample_probability, ibot_balanced_masking)
for blk in self.blocks:
x = blk(x)
x_norm = self.norm(x)
return {
"x_norm_clstoken": x_norm[:, 0],
"x_norm_patchtokens": x_norm[:, 1:],
"x_prenorm": x,
"masks": masks,
"n_masked_patches_upperbound": n_masked_patches_upperbound,
}
def get_intermediate_layers(self, x, n=1):
x, _, _ = self.prepare_tokens(x)
# we return the output tokens from the `n` last blocks
output = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if len(self.blocks) - i <= n:
output.append(self.norm(x))
return output
def forward(self, *args, is_training=False, **kwargs):
ret = self.forward_features(*args, **kwargs)
if is_training:
return ret
else:
return ret["x_norm_clstoken"]
class AdaptivePadding(nn.Module):
"""Applies padding to input (if needed) so that input can get fully covered
by filter you specified. It support two modes "same" and "corner". The
"same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
input. The "corner" mode would pad zero to bottom right.
Args:
kernel_size (int | tuple): Size of the kernel:
stride (int | tuple): Stride of the filter. Default: 1:
dilation (int | tuple): Spacing between kernel elements.
Default: 1.
padding (str): Support "same" and "corner", "corner" mode
would pad zero to bottom right, and "same" mode would
pad zero around input. Default: "corner".
Example:
>>> kernel_size = 16
>>> stride = 16
>>> dilation = 1
>>> input = torch.rand(1, 1, 15, 17)
>>> adap_pad = AdaptivePadding(
>>> kernel_size=kernel_size,
>>> stride=stride,
>>> dilation=dilation,
>>> padding="corner")
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
>>> input = torch.rand(1, 1, 16, 17)
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
"""
def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
super(AdaptivePadding, self).__init__()
assert padding in ('same', 'corner')
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
self.padding = padding
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
def get_pad_shape(self, input_shape):
input_h, input_w = input_shape
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.stride
output_h = math.ceil(input_h / stride_h)
output_w = math.ceil(input_w / stride_w)
pad_h = max((output_h - 1) * stride_h +
(kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
pad_w = max((output_w - 1) * stride_w +
(kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
return pad_h, pad_w
def forward(self, x):
pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
if pad_h > 0 or pad_w > 0:
if self.padding == 'corner':
x = F.pad(x, [0, pad_w, 0, pad_h])
elif self.padding == 'same':
x = F.pad(x, [
pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
pad_h - pad_h // 2
])
return x
class SSLVisionTransformer(DinoVisionTransformer):
"""Vision Transformer.
"""
def __init__(self,
interpolate_mode='bicubic',
init_cfg=None,
pretrained=None,
img_size=224,
patch_size=16,
#embed_dim=1024,
#depth=24,
#num_heads=16,
mlp_ratio=4,
qkv_bias=True,
init_values=1.,
out_indices=(4, 11, 17, 23),
final_norm=False,
with_cls_token=True,
output_cls_token=True,
frozen_stages=100,
*args, **kwargs):
super(SSLVisionTransformer, self).__init__(*args, **kwargs)
if output_cls_token:
            assert with_cls_token is True, 'with_cls_token must be True if ' \
                f'output_cls_token is True, but got with_cls_token={with_cls_token}'
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
        if len(self.blocks) == 1:
            # unwrap a nested/chunked block container into a flat list of blocks
            self.blocks = self.blocks[0]
if isinstance(out_indices, int):
if out_indices == -1:
out_indices = len(self.blocks) - 1
self.out_indices = [out_indices]
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
self.out_indices = out_indices
else:
raise TypeError('out_indices must be type of int, list or tuple')
self.interpolate_mode = interpolate_mode
self.pretrained = pretrained
self.frozen_stages = frozen_stages
self.detach = False
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.final_norm = final_norm
self.patch_size = self.patch_embed.patch_size
self.adapad = AdaptivePadding(kernel_size=self.patch_size, stride=self.patch_size, padding='same')
if pretrained:
self.init_weights(pretrained)
self._freeze_stages()
@staticmethod
    def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
"""Resize pos_embed weights.
Resize pos_embed using bicubic interpolate method.
Args:
            pos_embed (torch.Tensor): Position embedding weights.
            input_shape (tuple): Tuple for (downsampled input image height,
                downsampled input image width).
            pos_shape (tuple): The resolution of the downsampled original
                training image.
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``.
Return:
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
"""
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
pos_h, pos_w = pos_shape
cls_token_weight = pos_embed[:, 0]
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = resize(
            pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
cls_token_weight = cls_token_weight.unsqueeze(1)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
return pos_embed
def init_weights(self, pretrained):
print("init_weights", pretrained)
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
checkpoint = torch.load(pretrained, map_location='cpu')
if 'state_dict' in checkpoint:
# timm checkpoint
state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
# deit checkpoint
state_dict = checkpoint['model']
elif 'teacher' in checkpoint:
# dino eval checkpoint
state_dict = checkpoint['teacher']
else:
state_dict = checkpoint
if len([k for k in state_dict.keys() if 'teacher.backbone.' in k]) > 0:
state_dict = {k.replace('teacher.backbone.', ''):v for k,v in state_dict.items() if 'teacher.backbone' in k}
if len([k for k in state_dict.keys() if 'backbone.' in k]) > 0:
state_dict = {k.replace('backbone.', ''):v for k,v in state_dict.items()}
if 'pos_embed' in state_dict.keys():
if self.pos_embed.shape != state_dict['pos_embed'].shape:
                    logger.info(
                        f'Resize the pos_embed shape from '
                        f'{state_dict["pos_embed"].shape} to '
                        f'{self.pos_embed.shape}')
h, w = (224, 224) # self.img_size
pos_size = int(
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
state_dict['pos_embed'] = self.resize_pos_embed(
state_dict['pos_embed'],
(h // self.patch_size[0], w // self.patch_size[1]),
(pos_size, pos_size), self.interpolate_mode)
self.load_state_dict(state_dict)
else:
super(SSLVisionTransformer, self).init_weights()
def forward(self, x):
with torch.set_grad_enabled(not self.detach):
_, _, old_w, old_h = x.shape
xx = self.adapad(x)
x = F.pad(x, (0, xx.shape[-1] - x.shape[-1], 0, xx.shape[-2] - x.shape[-2]))
B, nc, w, h = x.shape
x, _, _ = self.prepare_tokens(x)
# we return the output tokens from the `n` last blocks
outs = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in self.out_indices:
if self.with_cls_token:
out = x[:, 1:]
else:
out = x
B, _, C = out.shape
out = out.reshape(B, w // self.patch_size[0], h // self.patch_size[1],
C).permute(0, 3, 1, 2).contiguous()
if self.output_cls_token:
out = [out, x[:, 0]]
else:
out = [out]
if self.final_norm:
out = [self.norm(o) for o in out]
if self.detach:
out = [o.detach() for o in out]
outs.append(out)
return tuple(outs)
def train(self, mode=True):
super(SSLVisionTransformer, self).train(mode)
self.detach = False
self._freeze_stages()
def _freeze_stages(self):
"""Freeze stages param and norm stats."""
if self.frozen_stages >= 0:
self.patch_embed.eval()
for m in [self.patch_embed]:
for param in m.parameters():
param.requires_grad = False
            self.cls_token.requires_grad = False
            # pos_embed is None when multiple_pos_embeddings is enabled
            if self.pos_embed is not None:
                self.pos_embed.requires_grad = False
            self.mask_token.requires_grad = False
if self.frozen_stages >= len(self.blocks) - 1:
self.norm.eval()
for param in self.norm.parameters():
param.requires_grad = False
self.detach = True
for i, layer in enumerate(self.blocks):
if i <= self.frozen_stages:
layer.eval()
for param in layer.parameters():
param.requires_grad = False
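# Usage sketch (illustrative only; default hyper-parameters assumed): the
# backbone pads inputs up to a multiple of the patch size and returns one
# (feature_map, cls_token) pair per entry in out_indices.
#
#   backbone = SSLVisionTransformer()             # out_indices=(4, 11, 17, 23)
#   imgs = torch.randn(1, 3, 250, 250)            # padded internally to 256x256
#   outs = backbone(imgs)                         # tuple with 4 entries
#   feat, cls = outs[-1]                          # feat: (1, 1024, 16, 16), cls: (1, 1024)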