# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

import torch
import torch.nn as nn
from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
                             build_norm_layer)
from mmengine.utils import digit_version

from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone


class Residual(nn.Module):
    """Wrap a module with a residual (skip) connection: ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


@MODELS.register_module()
class ConvMixer(BaseBackbone):
    """ConvMixer backbone.

    A PyTorch implementation of: `Patches Are All You Need?
    <https://arxiv.org/pdf/2201.09792.pdf>`_

    Modified from the `official repo
    <https://github.com/locuslab/convmixer/blob/main/convmixer.py>`_
    and `timm
    <https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convmixer.py>`_.

    Args:
        arch (str | dict): The model's architecture. If a string, it should
            be one of the architectures in ``ConvMixer.arch_settings``. If a
            dict, it should include the following four keys:

            - embed_dims (int): The dimensions of patch embedding.
            - depth (int): Number of repetitions of the ConvMixer layer.
            - patch_size (int): The patch size.
            - kernel_size (int): The kernel size of depthwise conv layers.

            Defaults to '768/32'.
        in_channels (int): Number of input image channels. Defaults to 3.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for activation after each convolution.
            Defaults to ``dict(type='GELU')``.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
        frozen_stages (int): Stages to be frozen (all params fixed).
            Defaults to 0, which means not freezing any parameters.
        init_cfg (dict, optional): Initialization config dict.
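
    Examples:
        A minimal usage sketch; the output shape assumes a 224x224 input,
        for which ``patch_size=7`` yields a 32x32 feature map:

        >>> import torch
        >>> from mmpretrain.models import ConvMixer
        >>> model = ConvMixer(arch='768/32')
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> for out in outputs:
        ...     print(out.shape)
        torch.Size([1, 768, 32, 32])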
| """ | |
    arch_settings = {
        '768/32': {
            'embed_dims': 768,
            'depth': 32,
            'patch_size': 7,
            'kernel_size': 7
        },
        '1024/20': {
            'embed_dims': 1024,
            'depth': 20,
            'patch_size': 14,
            'kernel_size': 9
        },
        '1536/20': {
            'embed_dims': 1536,
            'depth': 20,
            'patch_size': 7,
            'kernel_size': 9
        },
    }
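
    # Besides the presets above, a custom architecture can be passed as a
    # dict. A hypothetical example (not one of the released variants):
    #
    #     ConvMixer(arch=dict(
    #         embed_dims=512, depth=16, patch_size=7, kernel_size=9))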

    def __init__(self,
                 arch='768/32',
                 in_channels=3,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 out_indices=-1,
                 frozen_stages=0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            essential_keys = {
                'embed_dims', 'depth', 'patch_size', 'kernel_size'
            }
            assert essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
        else:
            raise TypeError('Expect "arch" to be either a string or a dict, '
                            f'got {type(arch)} instead.')

        self.embed_dims = arch['embed_dims']
        self.depth = arch['depth']
        self.patch_size = arch['patch_size']
        self.kernel_size = arch['kernel_size']
        self.act = build_activation_layer(act_cfg)

        # Check out indices and frozen stages. Negative indices are
        # normalized against the total number of stages (``depth``).
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.depth + index
            assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # Set stem layers: patch embedding implemented as a strided conv.
        self.stem = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.embed_dims,
                kernel_size=self.patch_size,
                stride=self.patch_size), self.act,
            build_norm_layer(norm_cfg, self.embed_dims)[1])

        # Choose the conv implementation according to the torch version:
        # ``nn.Conv2d`` only supports ``padding='same'`` since torch 1.9.0,
        # so fall back to ``Conv2dAdaptivePadding`` on older versions.
        convfunc = nn.Conv2d
        if digit_version(torch.__version__) < digit_version('1.9.0'):
            convfunc = Conv2dAdaptivePadding

        # Repetitions of the ConvMixer layer: a residual depthwise conv
        # (spatial mixing) followed by a pointwise conv (channel mixing).
        self.stages = nn.Sequential(*[
            nn.Sequential(
                Residual(
                    nn.Sequential(
                        convfunc(
                            self.embed_dims,
                            self.embed_dims,
                            self.kernel_size,
                            groups=self.embed_dims,
                            padding='same'), self.act,
                        build_norm_layer(norm_cfg, self.embed_dims)[1])),
                nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=1),
                self.act,
                build_norm_layer(norm_cfg, self.embed_dims)[1])
            for _ in range(self.depth)
        ])

        self._freeze_stages()

    def forward(self, x):
        x = self.stem(x)
        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def train(self, mode=True):
        super(ConvMixer, self).train(mode)
        self._freeze_stages()

    def _freeze_stages(self):
        # Put the first ``frozen_stages`` layers into eval mode (fixing BN
        # statistics) and stop gradients for their parameters.
        for i in range(self.frozen_stages):
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False
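
# A quick sanity-check sketch; it assumes the mmpretrain package is
# importable. ``frozen_stages`` keeps the first stages in eval mode with
# gradients disabled, and ``patch_size=14`` maps a 224x224 input to a
# 16x16 feature map:
#
#     >>> import torch
#     >>> from mmpretrain.models import ConvMixer
#     >>> model = ConvMixer(arch='1024/20', frozen_stages=2)
#     >>> _ = model.train()
#     >>> assert not any(p.requires_grad
#     ...                for p in model.stages[0].parameters())
#     >>> feats = model(torch.rand(1, 3, 224, 224))
#     >>> feats[-1].shape
#     torch.Size([1, 1024, 16, 16])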