# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils.norm import build_norm_layer
from ..utils.sparse_modules import SparseHelper
from .base import BaseSelfSupervisor


@MODELS.register_module()
class SparK(BaseSelfSupervisor):
"""Implementation of SparK.
Implementation of `Designing BERT for Convolutional Networks: Sparse and
Hierarchical Masked Modeling <https://arxiv.org/abs/2301.03580>`_.
Modified from
https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py
"""
def __init__(
self,
backbone: dict,
neck: dict,
head: dict,
pretrained: Optional[str] = None,
data_preprocessor: Optional[dict] = None,
input_size: int = 224,
        downsample_ratio: int = 32,
mask_ratio: float = 0.6,
        enc_dec_norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
enc_dec_norm_dim: int = 2048,
init_cfg: Optional[dict] = None,
) -> None:
super().__init__(
backbone=backbone,
neck=neck,
head=head,
pretrained=pretrained,
data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

self.input_size = input_size
        self.downsample_ratio = downsample_ratio
        feature_map_size = input_size // downsample_ratio
self.feature_map_size = feature_map_size
self.mask_ratio = mask_ratio
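        # e.g. with the defaults: 224 // 32 = 7, so the final feature map has
        # 7 * 7 = 49 positions and round(49 * (1 - 0.6)) = 20 stay active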
self.len_keep = round(feature_map_size * feature_map_size *
(1 - mask_ratio))
self.enc_dec_norm_cfg = enc_dec_norm_cfg
self.enc_dec_norms = nn.ModuleList()
self.enc_dec_projectors = nn.ModuleList()
self.mask_tokens = nn.ParameterList()
proj_out_dim = self.neck.feature_dim
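        # Build one (norm, projector, mask token) triple per backbone stage,
        # walking the feature pyramid from the deepest stage (smallest map,
        # widest channels) outwards; channels halve and resolution doubles at
        # each step (see the updates at the end of the loop body).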
for i in range(len(self.backbone.out_indices)):
enc_dec_norm = build_norm_layer(self.enc_dec_norm_cfg,
enc_dec_norm_dim)
self.enc_dec_norms.append(enc_dec_norm)
kernel_size = 1 if i <= 0 else 3
proj_layer = nn.Conv2d(
enc_dec_norm_dim,
proj_out_dim,
kernel_size=kernel_size,
stride=1,
padding=kernel_size // 2,
bias=True)
if i == 0 and enc_dec_norm_dim == proj_out_dim:
proj_layer = nn.Identity()
self.enc_dec_projectors.append(proj_layer)
mask_token = nn.Parameter(torch.zeros(1, enc_dec_norm_dim, 1, 1))
trunc_normal_(mask_token, mean=0, std=.02, a=-.02, b=.02)
self.mask_tokens.append(mask_token)
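            # move one stage up the pyramid: half the channels, twice the
            # spatial resolution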
enc_dec_norm_dim //= 2
proj_out_dim //= 2
            feature_map_size *= 2

def mask(self,
shape: torch.Size,
device: Union[torch.device, str],
generator: Optional[torch.Generator] = None):
"""Mask generation.
Args:
shape (torch.Size): The shape of the input images.
device (Union[torch.device, str]): The device of the tensor.
generator (torch.Generator, optional): Generator for random
functions. Defaults to None
Returns:
torch.Tensor: The generated mask.
"""
B, C, H, W = shape
f = self.feature_map_size
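        # rand + argsort yields a random permutation of the f*f positions for
        # each sample; the first `len_keep` indices are the positions kept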
idx = torch.rand(B, f * f, generator=generator).argsort(dim=1)
idx = idx[:, :self.len_keep].to(device) # (B, len_keep)
return torch.zeros(
B, f * f, dtype=torch.bool, device=device).scatter_(
                dim=1, index=idx, value=True).view(B, 1, f, f)

def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[DataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
# active mask of feature map, (B, 1, f, f)
active_mask_feature_map = self.mask(inputs.shape, inputs.device)
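        # the sparse modules in the backbone consult this class attribute to
        # restrict computation to the active (unmasked) positions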
SparseHelper._cur_active = active_mask_feature_map
# active mask of original input, (B, 1, H, W)
        active_mask_origin = active_mask_feature_map.repeat_interleave(
            self.downsample_ratio,
            2).repeat_interleave(self.downsample_ratio, 3)
masked_img = inputs * active_mask_origin
# get hierarchical encoded sparse features in a list
# containing four feature maps
feature_maps = self.backbone(masked_img)
# from the smallest feature map to the largest
feature_maps = list(feature_maps)
feature_maps.reverse()
cur_active = active_mask_feature_map
feature_maps_to_dec = []
for i, feature_map in enumerate(feature_maps):
if feature_map is not None:
# fill in empty positions with [mask] embeddings
feature_map = self.enc_dec_norms[i](feature_map)
mask_token = self.mask_tokens[i].expand_as(feature_map)
feature_map = torch.where(
cur_active.expand_as(feature_map), feature_map,
mask_token.to(feature_map.dtype))
feature_map = self.enc_dec_projectors[i](feature_map)
feature_maps_to_dec.append(feature_map)
# dilate the mask map
cur_active = cur_active.repeat_interleave(
2, dim=2).repeat_interleave(
2, dim=3)
# decode and reconstruct
rec_img = self.neck(feature_maps_to_dec)
# compute loss
loss = self.head(rec_img, inputs, active_mask_feature_map)
losses = dict(loss=loss)
return losses
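

if __name__ == '__main__':
    # Minimal standalone sketch of the mask generation implemented above,
    # using the default hyper-parameters (input_size=224, downsample_ratio=32,
    # mask_ratio=0.6). It needs no registered backbone/neck/head and mirrors
    # the tensor ops in ``mask()`` and ``loss()`` one-to-one.
    B = 2
    f = 224 // 32  # feature_map_size = 7
    len_keep = round(f * f * (1 - 0.6))  # keep 20 of 49 positions
    idx = torch.rand(B, f * f).argsort(dim=1)[:, :len_keep]
    active = torch.zeros(B, f * f, dtype=torch.bool).scatter_(
        dim=1, index=idx, value=True).view(B, 1, f, f)
    # broadcast the feature-map mask back to input resolution, as in loss()
    active_img = active.repeat_interleave(32, 2).repeat_interleave(32, 3)
    print(active.float().mean().item())  # ~0.408, the kept fraction (20/49)
    print(active_img.shape)  # torch.Size([2, 1, 224, 224])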