File size: 6,079 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils.norm import build_norm_layer
from ..utils.sparse_modules import SparseHelper
from .base import BaseSelfSupervisor


@MODELS.register_module()
class SparK(BaseSelfSupervisor):
    """Implementation of SparK.

    Implementation of `Designing BERT for Convolutional Networks: Sparse and
    Hierarchical Masked Modeling <https://arxiv.org/abs/2301.03580>`_.

    Modified from
    https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py

    Args:
        backbone (dict): Config of the backbone, forwarded to the base class.
        neck (dict): Config of the neck, forwarded to the base class. The
            built neck must expose a ``feature_dim`` attribute.
        head (dict): Config of the head, forwarded to the base class.
        pretrained (str, optional): Forwarded to the base class.
            Defaults to None.
        data_preprocessor (dict, optional): Forwarded to the base class.
            Defaults to None.
        input_size (int): Input image size used to derive the size of the
            smallest feature map. Defaults to 224.
        downsample_raito (int): Ratio between the input size and the size of
            the smallest feature map; the feature-level mask is upsampled by
            this factor to mask the input image. The (misspelled) argument
            name is kept for backward compatibility. Defaults to 32.
        mask_ratio (float): Fraction of positions of the smallest feature
            map that are masked out. Defaults to 0.6.
        enc_dec_norm_cfg (dict): Config of the normalization layers applied
            to each encoder feature map before the mask tokens are filled
            in. Defaults to ``dict(type='SparseSyncBatchNorm2d')``.
        enc_dec_norm_dim (int): Number of channels of the first norm layer,
            projector input and mask token; it is halved for every
            following stage. Defaults to 2048.
        init_cfg (dict, optional): Forwarded to the base class.
            Defaults to None.
    """

    def __init__(
        self,
        backbone: dict,
        neck: dict,
        head: dict,
        pretrained: Optional[str] = None,
        data_preprocessor: Optional[dict] = None,
        input_size: int = 224,
        downsample_raito: int = 32,
        mask_ratio: float = 0.6,
        enc_dec_norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
        enc_dec_norm_dim: int = 2048,
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        self.input_size = input_size
        self.downsample_raito = downsample_raito
        # side length of the smallest feature map, on which masks are drawn
        feature_map_size = input_size // downsample_raito
        self.feature_map_size = feature_map_size

        self.mask_ratio = mask_ratio
        # number of positions that stay active (un-masked) per sample
        self.len_keep = round(feature_map_size * feature_map_size *
                              (1 - mask_ratio))

        self.enc_dec_norm_cfg = enc_dec_norm_cfg
        self.enc_dec_norms = nn.ModuleList()
        self.enc_dec_projectors = nn.ModuleList()
        self.mask_tokens = nn.ParameterList()

        # One (norm, projector, mask token) triple per backbone output
        # stage. Index 0 is paired with the first entry of the reversed
        # feature list in ``loss``; channel widths halve at every step.
        proj_out_dim = self.neck.feature_dim
        num_stages = len(self.backbone.out_indices)
        for i in range(num_stages):
            enc_dec_norm = build_norm_layer(self.enc_dec_norm_cfg,
                                            enc_dec_norm_dim)
            self.enc_dec_norms.append(enc_dec_norm)

            # 1x1 conv for the first stage, 3x3 conv for the others;
            # padding keeps the spatial size unchanged
            kernel_size = 1 if i <= 0 else 3
            proj_layer = nn.Conv2d(
                enc_dec_norm_dim,
                proj_out_dim,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
                bias=True)
            # no projection is needed when the first stage already has the
            # channel width the neck expects
            if i == 0 and enc_dec_norm_dim == proj_out_dim:
                proj_layer = nn.Identity()
            self.enc_dec_projectors.append(proj_layer)

            mask_token = nn.Parameter(torch.zeros(1, enc_dec_norm_dim, 1, 1))
            trunc_normal_(mask_token, mean=0, std=.02, a=-.02, b=.02)
            self.mask_tokens.append(mask_token)

            enc_dec_norm_dim //= 2
            proj_out_dim //= 2

    def mask(self,
             shape: torch.Size,
             device: Union[torch.device, str],
             generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Mask generation.

        Args:
            shape (torch.Size): The shape of the input images, which must
                unpack into four dimensions (B, C, H, W). Only the batch
                size is used.
            device (Union[torch.device, str]): The device of the tensor.
            generator (torch.Generator, optional): Generator for random
                functions. Defaults to None
        Returns:
            torch.Tensor: The generated boolean mask of shape (B, 1, f, f),
            where ``f`` is ``self.feature_map_size``. ``True`` marks the
            ``self.len_keep`` positions per sample that are kept active.
        """
        batch_size, _, _, _ = shape
        f = self.feature_map_size
        num_positions = f * f

        # argsort of uniform noise gives a random permutation per sample;
        # its first ``len_keep`` entries are the positions that stay active
        noise = torch.rand(batch_size, num_positions, generator=generator)
        keep_idx = noise.argsort(dim=1)[:, :self.len_keep].to(device)

        active = torch.zeros(
            batch_size, num_positions, dtype=torch.bool, device=device)
        active.scatter_(dim=1, index=keep_idx, value=True)
        return active.view(batch_size, 1, f, f)

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images, of shape (B, C, H, W).
            data_samples (List[DataSample]): All elements required
                during the forward function.
        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """

        # active mask of feature map, (B, 1, f, f)
        active_mask_feature_map = self.mask(inputs.shape, inputs.device)
        # publish the mask on SparseHelper before running the backbone
        SparseHelper._cur_active = active_mask_feature_map

        # active mask of original input, (B, 1, H, W)
        active_mask_origin = active_mask_feature_map.repeat_interleave(
            self.downsample_raito,
            2).repeat_interleave(self.downsample_raito, 3)
        masked_img = inputs * active_mask_origin

        # get hierarchical encoded sparse features, one per backbone
        # output stage, ordered from the smallest feature map to the largest
        feature_maps = list(self.backbone(masked_img))[::-1]

        cur_active = active_mask_feature_map
        feature_maps_to_dec = []
        for i, feature_map in enumerate(feature_maps):
            if feature_map is not None:
                # fill in empty positions with [mask] embeddings
                feature_map = self.enc_dec_norms[i](feature_map)
                mask_token = self.mask_tokens[i].expand_as(feature_map)
                feature_map = torch.where(
                    cur_active.expand_as(feature_map), feature_map,
                    mask_token.to(feature_map.dtype))
                feature_map = self.enc_dec_projectors[i](feature_map)
            # ``None`` entries are passed through to the neck unchanged
            feature_maps_to_dec.append(feature_map)

            # dilate the mask map to the resolution of the next feature map
            cur_active = cur_active.repeat_interleave(
                2, dim=2).repeat_interleave(
                    2, dim=3)

        # decode and reconstruct
        rec_img = self.neck(feature_maps_to_dec)

        # compute loss
        loss = self.head(rec_img, inputs, active_mask_feature_map)
        losses = dict(loss=loss)
        return losses