# liyy201912's picture
# Upload folder using huggingface_hub
# cc0dd3c
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils.norm import build_norm_layer
from ..utils.sparse_modules import SparseHelper
from .base import BaseSelfSupervisor
@MODELS.register_module()
class SparK(BaseSelfSupervisor):
    """Implementation of SparK.

    Implementation of `Designing BERT for Convolutional Networks: Sparse and
    Hierarchical Masked Modeling <https://arxiv.org/abs/2301.03580>`_.

    Modified from
    https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py

    Args:
        backbone (dict): Config dict of the backbone (encoder). Forwarded to
            the base class.
        neck (dict): Config dict of the neck (decoder). Forwarded to the
            base class.
        head (dict): Config dict of the head that computes the loss.
            Forwarded to the base class.
        pretrained (str, optional): Forwarded to the base class.
            Defaults to None.
        data_preprocessor (dict, optional): Forwarded to the base class.
            Defaults to None.
        input_size (int): Side length of the input images. Defaults to 224.
        downsample_raito (int): Total downsample ratio of the backbone; the
            mask is generated on a grid of side
            ``input_size // downsample_raito``. The (misspelled) name is
            kept as-is because it is part of the public interface.
            Defaults to 32.
        mask_ratio (float): Fraction of grid cells that are masked out;
            ``round(f * f * (1 - mask_ratio))`` cells are kept per sample.
            Defaults to 0.6.
        enc_dec_norm_cfg (dict): Config passed to ``build_norm_layer`` for
            the norm applied to every encoder feature map before it is
            handed to the decoder.
            Defaults to ``dict(type='SparseSyncBatchNorm2d')``.
        enc_dec_norm_dim (int): Channel number of the feature map handled
            first (the smallest one); it is halved for every following
            stage. Defaults to 2048.
        init_cfg (dict, optional): Forwarded to the base class.
            Defaults to None.
    """

    def __init__(
        self,
        backbone: dict,
        neck: dict,
        head: dict,
        pretrained: Optional[str] = None,
        data_preprocessor: Optional[dict] = None,
        input_size: int = 224,
        downsample_raito: int = 32,
        mask_ratio: float = 0.6,
        enc_dec_norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
        enc_dec_norm_dim: int = 2048,
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        self.input_size = input_size
        self.downsample_raito = downsample_raito
        feature_map_size = input_size // downsample_raito
        self.feature_map_size = feature_map_size

        self.mask_ratio = mask_ratio
        # number of grid cells that stay visible (active) in every sample
        self.len_keep = round(feature_map_size * feature_map_size *
                              (1 - mask_ratio))

        self.enc_dec_norm_cfg = enc_dec_norm_cfg
        self.enc_dec_norms = nn.ModuleList()
        self.enc_dec_projectors = nn.ModuleList()
        self.mask_tokens = nn.ParameterList()

        # One (norm, projector, mask token) triple per backbone output,
        # ordered from the smallest feature map to the largest. Channel
        # numbers are halved after every stage.
        proj_out_dim = self.neck.feature_dim
        for i in range(len(self.backbone.out_indices)):
            enc_dec_norm = build_norm_layer(self.enc_dec_norm_cfg,
                                            enc_dec_norm_dim)
            self.enc_dec_norms.append(enc_dec_norm)

            # 1x1 conv for the first stage, 3x3 conv for the others
            kernel_size = 1 if i <= 0 else 3
            proj_layer = nn.Conv2d(
                enc_dec_norm_dim,
                proj_out_dim,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
                bias=True)
            # the first stage needs no projection when the channel numbers
            # already match
            if i == 0 and enc_dec_norm_dim == proj_out_dim:
                proj_layer = nn.Identity()
            self.enc_dec_projectors.append(proj_layer)

            # learnable [mask] embedding, broadcast over batch and space
            mask_token = nn.Parameter(torch.zeros(1, enc_dec_norm_dim, 1, 1))
            trunc_normal_(mask_token, mean=0, std=.02, a=-.02, b=.02)
            self.mask_tokens.append(mask_token)

            enc_dec_norm_dim //= 2
            proj_out_dim //= 2

    def mask(self,
             shape: torch.Size,
             device: Union[torch.device, str],
             generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Mask generation.

        Args:
            shape (torch.Size): The shape of the input images, unpacked as
                ``(B, C, H, W)``; only ``B`` is used.
            device (Union[torch.device, str]): The device of the tensor.
            generator (torch.Generator, optional): Generator for random
                functions. Defaults to None

        Returns:
            torch.Tensor: Boolean mask of shape ``(B, 1, f, f)`` with
            ``f = self.feature_map_size``. Exactly ``self.len_keep``
            positions per sample are ``True`` (kept), the rest are ``False``.
        """
        # unpacking four values keeps the requirement of a 4-element shape
        batch_size, _, _, _ = shape
        f = self.feature_map_size
        num_cells = f * f

        # sorting random noise yields a random permutation per sample; its
        # first ``len_keep`` entries are the cells that stay visible
        idx = torch.rand(
            batch_size, num_cells, generator=generator).argsort(dim=1)
        idx = idx[:, :self.len_keep].to(device)  # (B, len_keep)

        active = torch.zeros(
            batch_size, num_cells, dtype=torch.bool, device=device)
        active.scatter_(dim=1, index=idx, value=True)
        return active.view(batch_size, 1, f, f)

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Also stores the generated active mask in
        ``SparseHelper._cur_active`` before the backbone is called.

        Args:
            inputs (torch.Tensor): The input images, a 4-dim tensor
                ``(B, C, H, W)``. ``H`` and ``W`` are expected to equal
                ``feature_map_size * downsample_raito`` so that the
                upsampled mask can be multiplied with the images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        # active mask of feature map, (B, 1, f, f)
        active_mask_feature_map = self.mask(inputs.shape, inputs.device)
        SparseHelper._cur_active = active_mask_feature_map

        # active mask of original input, (B, 1, H, W)
        active_mask_origin = active_mask_feature_map.repeat_interleave(
            self.downsample_raito,
            2).repeat_interleave(self.downsample_raito, 3)
        masked_img = inputs * active_mask_origin

        # get hierarchical encoded sparse features in a list
        feature_maps = self.backbone(masked_img)

        # from the smallest feature map to the largest
        feature_maps = list(feature_maps)
        feature_maps.reverse()

        cur_active = active_mask_feature_map
        feature_maps_to_dec = []
        for i, feature_map in enumerate(feature_maps):
            if feature_map is not None:
                # fill in empty positions with [mask] embeddings
                feature_map = self.enc_dec_norms[i](feature_map)
                mask_token = self.mask_tokens[i].expand_as(feature_map)
                feature_map = torch.where(
                    cur_active.expand_as(feature_map), feature_map,
                    mask_token.to(feature_map.dtype))
                feature_map = self.enc_dec_projectors[i](feature_map)
            # ``None`` entries are forwarded unchanged so that the list
            # stays aligned with the stage index
            feature_maps_to_dec.append(feature_map)

            # dilate the mask map to the resolution of the next stage
            cur_active = cur_active.repeat_interleave(
                2, dim=2).repeat_interleave(
                    2, dim=3)

        # decode and reconstruct
        rec_img = self.neck(feature_maps_to_dec)

        # compute loss
        loss = self.head(rec_img, inputs, active_mask_feature_map)
        losses = dict(loss=loss)
        return losses