# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
from einops import rearrange
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_
from torch import nn

from mmpretrain.models.backbones import BEiTViT
from mmpretrain.models.utils import NormEMAVectorQuantizer, resize_pos_embed
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from .base import BaseSelfSupervisor


@MODELS.register_module()
class VQKD(BaseModule):
    """Vector-Quantized Knowledge Distillation.

    The module only contains the encoder and the VectorQuantizer part.
    Modified from https://github.com/microsoft/unilm/blob/master/beit2/modeling_vqkd.py

    Args:
        encoder_config (dict): The config of the encoder.
        decoder_config (dict, optional): The config of the decoder. Currently,
            VQKD only supports building the encoder. Defaults to None.
        num_embed (int): Number of embedding vectors in the codebook.
            Defaults to 8192.
        embed_dims (int): The dimension of embedding vectors in the codebook.
            Defaults to 32.
        decay (float): The decay parameter of EMA. Defaults to 0.99.
        beta (float): The multiplier for the VectorQuantizer loss.
            Defaults to 1.
        quantize_kmeans_init (bool): Whether to use k-means to initialize the
            VectorQuantizer. Defaults to True.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 encoder_config: dict,
                 decoder_config: Optional[dict] = None,
                 num_embed: int = 8192,
                 embed_dims: int = 32,
                 decay: float = 0.99,
                 beta: float = 1.0,
                 quantize_kmeans_init: bool = True,
                 init_cfg: Optional[dict] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.encoder = BEiTViT(**encoder_config)
        if decoder_config is not None:
            self.decoder = BEiTViT(**decoder_config)

        self.quantize = NormEMAVectorQuantizer(
            num_embed=num_embed,
            embed_dims=embed_dims,
            beta=beta,
            decay=decay,
            kmeans_init=quantize_kmeans_init,
        )

        # task layer
        self.encode_task_layer = nn.Sequential(
            nn.Linear(self.encoder.arch_settings['embed_dims'],
                      self.encoder.arch_settings['embed_dims']), nn.Tanh(),
            nn.Linear(self.encoder.arch_settings['embed_dims'], embed_dims))

    def get_tokens(self, x: torch.Tensor) -> dict:
        """Get tokens for BEiT pre-training."""
        _, embed_ind, _ = self.encode(x)
        output = {}
        output['token'] = embed_ind.view(x.shape[0], -1)
        output['input_img'] = x

        return output

    def encode(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode the input images and get the corresponding results."""
        encoder_features = self.encoder(x)[0]
        B, C, N1, N2 = encoder_features.shape
        encoder_features = encoder_features.permute(0, 2, 3,
                                                    1).reshape(B, N1 * N2, C)

        with torch.cuda.amp.autocast(enabled=False):
            to_quantizer_features = self.encode_task_layer(
                encoder_features.type_as(self.encode_task_layer[-1].weight))

        N = to_quantizer_features.shape[1]
        h, w = int(math.sqrt(N)), int(math.sqrt(N))

        to_quantizer_features = rearrange(
            to_quantizer_features, 'b (h w) c -> b c h w', h=h,
            w=w)  # reshape for quantizer
        quantize, loss, embed_ind = self.quantize(to_quantizer_features)

        return quantize, embed_ind, loss

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """The forward function.

        Currently, it only supports getting tokens.
        """
        return self.get_tokens(x)['token']
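

# A minimal usage sketch, assuming a base-size encoder: turn a batch of images
# into discrete codebook indices with VQKD. The encoder settings below are
# illustrative assumptions; in practice the tokenizer is built from the
# released BEiT v2 config and loaded with pretrained weights, otherwise the
# returned token ids carry no semantics.
def _example_vqkd_tokens() -> torch.Tensor:
    tokenizer = VQKD(
        encoder_config=dict(
            arch='base', img_size=224, patch_size=16, out_type='featmap'))
    tokenizer.eval()  # the tokenizer is frozen during BEiT pre-training

    images = torch.rand(2, 3, 224, 224)
    with torch.no_grad():
        # Each 224x224 image becomes a 14x14 grid of codebook indices,
        # flattened to 196 token ids per image.
        tokens = tokenizer(images)  # shape: (2, 196)
    return tokens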


@MODELS.register_module()
class BEiTPretrainViT(BEiTViT):
    """Vision Transformer for BEiT pre-training.

    Args:
        arch (str | dict): Vision Transformer architecture. If a string,
            choose from 'small', 'base' and 'large'. If a dict, it should
            have the below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The number of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for the normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize the
            final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            It only works without an input mask. Defaults to ``"raw"``.
        with_cls_token (bool): Whether to concatenate the class token with the
            image tokens as transformer input. Defaults to True.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        use_abs_pos_emb (bool): Whether or not to use absolute position
            embedding. Defaults to False.
        use_rel_pos_bias (bool): Whether or not to use relative position bias.
            Defaults to False.
        use_shared_rel_pos_bias (bool): Whether or not to use shared relative
            position bias. Defaults to True.
        layer_scale_init_value (float): The initialization value for
            the learnable scaling of attention and FFN. Defaults to 0.1.
        interpolate_mode (str): Select the interpolation mode for resizing the
            position embedding vector. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in the
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 arch: str = 'base',
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 out_indices: int = -1,
                 drop_rate: float = 0,
                 drop_path_rate: float = 0,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
                 out_type: str = 'raw',
                 frozen_stages: int = -1,
                 use_abs_pos_emb: bool = False,
                 use_rel_pos_bias: bool = False,
                 use_shared_rel_pos_bias: bool = True,
                 layer_scale_init_value: float = 0.1,
                 interpolate_mode: str = 'bicubic',
                 patch_cfg: dict = dict(padding=0),
                 layer_cfgs: dict = dict(),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            out_type=out_type,
            with_cls_token=True,
            frozen_stages=frozen_stages,
            use_abs_pos_emb=use_abs_pos_emb,
            use_shared_rel_pos_bias=use_shared_rel_pos_bias,
            use_rel_pos_bias=use_rel_pos_bias,
            layer_scale_init_value=layer_scale_init_value,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            init_cfg=init_cfg)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

    def init_weights(self) -> None:
        """Initialize position embedding, patch embedding and cls token."""
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if a pretrained model is used.
            return

        trunc_normal_(self.cls_token, std=0.02)
        trunc_normal_(self.mask_token, std=0.02)
        self.rescale_init_weight()

    def rescale_init_weight(self) -> None:
        """Rescale the initialized weights."""

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.layers):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.ffn.layers[1].weight.data, layer_id + 1)

    def forward(self, x: torch.Tensor,
                mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor]:
        """The BEiT style forward function.

        The function supports two kinds of forward behavior. If ``mask`` is
        not ``None``, the forward function is executed as masked image
        modeling pre-training; if ``mask`` is ``None``, the forward function
        calls ``super().forward()``, which extracts features from images
        without a mask.

        Args:
            x (torch.Tensor): Input images, which is of shape (B x C x H x W).
            mask (torch.Tensor, optional): Mask for input, which is of shape
                (B x patch_resolution[0] x patch_resolution[1]).

        Returns:
            Tuple[torch.Tensor]: Hidden features.
        """
        if mask is None:
            return super().forward(x)

        else:
            x, patch_resolution = self.patch_embed(x)

            # replace the masked visual tokens by mask_token
            B, L, _ = x.shape
            mask_token = self.mask_token.expand(B, L, -1)
            w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
            x = x * (1. - w) + mask_token * w

            # stole cls_tokens impl from Phil Wang, thanks
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

            if self.pos_embed is not None:
                x = x + resize_pos_embed(
                    self.pos_embed,
                    self.patch_resolution,
                    patch_resolution,
                    mode=self.interpolate_mode,
                    num_extra_tokens=self.num_extra_tokens)
            x = self.drop_after_pos(x)

            self.shared_rel_pos_bias = self.rel_pos_bias().to(
                mask.device) if self.rel_pos_bias is not None else None

            outs = []
            for i, layer in enumerate(self.layers):
                x = layer(x, rel_pos_bias=self.shared_rel_pos_bias)

                if i == len(self.layers) - 1 and self.final_norm:
                    x = self.norm1(x)

                if i in self.out_indices:
                    outs.append(x)

            return tuple(outs)
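

# A sketch of the masked forward path above, assuming a base-size backbone:
# a boolean patch mask decides which visual tokens are replaced by the
# learnable ``mask_token`` before the transformer layers run. The random mask
# here is only a stand-in for the blockwise mask generator used in the real
# data pipeline.
def _example_masked_forward() -> torch.Tensor:
    backbone = BEiTPretrainViT(arch='base', img_size=224, patch_size=16)
    images = torch.rand(2, 3, 224, 224)
    # Mask roughly 40% of the 14x14 patch grid.
    mask = torch.rand(2, 14, 14) < 0.4
    feats = backbone(images, mask)[0]
    # ``out_type='raw'`` keeps the cls token, so L = 1 + 14 * 14 = 197.
    return feats  # shape: (2, 197, 768) for the 'base' arch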


@MODELS.register_module()
class BEiT(BaseSelfSupervisor):
    """BEiT v1/v2.

    Implementation of `BEiT: BERT Pre-Training of Image Transformers
    <https://arxiv.org/abs/2106.08254>`_ and `BEiT v2: Masked Image Modeling
    with Vector-Quantized Visual Tokenizers
    <https://arxiv.org/abs/2208.06366>`_.
    """

    def extract_feat(self, inputs: torch.Tensor):
        return self.backbone(inputs, mask=None)

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        mask = torch.stack([data_sample.mask for data_sample in data_samples])

        img_latent = self.backbone(inputs[0], mask)

        # inputs[1] is the target image
        with torch.no_grad():
            target = self.target_generator(inputs[1])
            target = target.detach()

        if self.with_neck:
            # BEiT v2
            feats, feats_cls_pt = self.neck(
                img_latent, rel_pos_bias=self.backbone.shared_rel_pos_bias)
            loss = self.head.loss(feats, feats_cls_pt, target, mask)
        else:
            # BEiT v1
            loss = self.head.loss(img_latent[0], target, mask)

        if isinstance(loss, torch.Tensor):
            losses = dict(loss=loss)
            return losses
        elif isinstance(loss, Tuple):
            # loss_1 and loss_2 are the general reconstruction loss (patch
            # feature vectors from the last layer of the backbone) and the
            # early-stage reconstruction loss (patch feature vectors from an
            # intermediate layer of the backbone)
            loss_1, loss_2 = loss[0], loss[1]
            losses = dict()
            # keys with the prefix 'loss', such as loss_1 and loss_2, will be
            # used as the final criterion
            losses['loss_1'] = loss_1
            losses['loss_2'] = loss_2
            return losses
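

# An end-to-end sketch of the single-loss (BEiT v1 style) branch of
# ``BEiT.loss``, with the VQKD tokenizer above standing in as the target
# generator. ``_ToyMIMHead`` is a hypothetical stand-in head, not an
# mmpretrain module; the real pipeline assembles a registered head, neck and
# target generator through the config system and loads pretrained tokenizer
# weights.
class _ToyMIMHead(nn.Module):

    def __init__(self, embed_dims: int, num_embed: int) -> None:
        super().__init__()
        self.cls_head = nn.Linear(embed_dims, num_embed)

    def loss(self, feats: torch.Tensor, target: torch.Tensor,
             mask: torch.Tensor) -> torch.Tensor:
        # Predict the tokenizer's codebook index for every masked patch.
        logits = self.cls_head(feats[:, 1:])  # drop the cls token
        mask = mask.flatten(1).bool()
        return nn.functional.cross_entropy(logits[mask], target[mask])


def _example_beit_style_loss() -> torch.Tensor:
    backbone = BEiTPretrainViT(arch='base', img_size=224, patch_size=16)
    tokenizer = VQKD(
        encoder_config=dict(
            arch='base', img_size=224, patch_size=16, out_type='featmap'))
    tokenizer.eval()
    head = _ToyMIMHead(embed_dims=768, num_embed=8192)

    images = torch.rand(2, 3, 224, 224)
    mask = torch.rand(2, 14, 14) < 0.4

    # Mirror ``BEiT.loss``: masked latent from the backbone, frozen targets
    # from the tokenizer, then a reconstruction loss on the masked patches.
    img_latent = backbone(images, mask)[0]
    with torch.no_grad():
        target = tokenizer(images)  # (2, 196) codebook indices
    return head.loss(img_latent, target, mask)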