Spaces:
Running
Running
| from typing import Tuple, List, Union | |
| import torch | |
| from torch import nn | |
| from torch.utils.checkpoint import checkpoint | |
| import torch.nn.functional as F | |
| from timm.models.layers import trunc_normal_ | |
| from sam_extension.distillation_models.fastervit import FasterViTLayer | |
| from segment_anything.mobile_encoder.tiny_vit_sam import PatchEmbed, Conv2d_BN, LayerNorm2d, MBConv | |
class PatchMerging(nn.Module):
    """Channel-projecting block built from three ``Conv2d_BN`` layers.

    Pipeline: 1x1 ``Conv2d_BN`` -> act -> grouped 3x3 ``Conv2d_BN``
    (``groups=out_dim``) -> act -> 1x1 ``Conv2d_BN``.

    The grouped conv is built with stride 2, except when ``out_dim`` is one of
    320, 448 or 576, where stride 1 is used instead.

    Args:
        input_resolution: ``(H, W)`` used to reshape 3-D token input into a
            4-D feature map before the convolutions.
        dim: input channel count.
        out_dim: output channel count of every conv in the block.
        activation: activation *class*; one instance is created and shared by
            the two activation points.
    """

    def __init__(self, input_resolution, dim, out_dim, activation):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim
        self.act = activation()
        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
        # Widths that use stride 1 in the grouped conv. The original code
        # tagged the 576 entry "handongshen"; the reason is not recorded here.
        depthwise_stride = 1 if out_dim in (320, 448, 576) else 2
        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, depthwise_stride, 1,
                               groups=out_dim)
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        """Apply the conv chain; 3-D input is first reshaped to (B, C, H, W)."""
        if x.ndim == 3:
            # Presumably token layout (B, H*W, C) -> feature map (B, C, H, W).
            height, width = self.input_resolution
            x = x.view(len(x), height, width, -1).permute(0, 3, 1, 2)
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        return self.conv3(x)
class ConvLayer(nn.Module):
    """A stage of ``depth`` MBConv blocks followed by optional downsampling.

    Args:
        dim: channel count passed to every MBConv block (as both its input
            and output width) and to ``downsample``.
        input_resolution: forwarded to ``downsample`` as its first argument.
        depth: number of MBConv blocks.
        activation: activation class forwarded to each block and to
            ``downsample``.
        drop_path: one rate shared by all blocks, or a list/tuple holding one
            rate per block.
        downsample: optional module class, constructed as
            ``downsample(input_resolution, dim=dim, out_dim=out_dim,
            activation=activation)`` and applied after the blocks.
        use_checkpoint: if True, each block is run through
            ``torch.utils.checkpoint.checkpoint`` so its activations are
            recomputed during backward instead of being stored.
        out_dim: output width handed to ``downsample``.
        conv_expand_ratio: forwarded to each MBConv block (presumably its
            expansion ratio).
    """

    def __init__(self, dim, input_resolution, depth,
                 activation,
                 drop_path=0., downsample=None, use_checkpoint=False,
                 out_dim=None,
                 conv_expand_ratio=4.,
                 ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Build blocks; a sequence of rates gives each block its own rate.
        per_block = isinstance(drop_path, (list, tuple))
        self.blocks = nn.ModuleList([
            MBConv(dim, dim, conv_expand_ratio, activation,
                   drop_path[i] if per_block else drop_path,
                   )
            for i in range(depth)])

        # Patch merging layer.
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, out_dim=out_dim, activation=activation)
        else:
            self.downsample = None

    def forward(self, x):
        """Run the blocks in order, then ``downsample`` if one was configured."""
        for blk in self.blocks:
            if self.use_checkpoint:
                # `checkpoint` is the function imported from
                # torch.utils.checkpoint (not the module), so call it
                # directly. The non-reentrant variant is the one PyTorch
                # recommends, and it works even when `x` does not require
                # grad.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
class FasterTinyViT(nn.Module):
    """Image encoder: patch embedding, a conv stage, FasterViT stages, and a neck.

    Stage 0 is a :class:`ConvLayer` (downsampling with :class:`PatchMerging`
    unless it is the only stage). Every later stage is a ``FasterViTLayer``.
    The neck is 1x1 conv -> LayerNorm2d -> 3x3 conv -> LayerNorm2d and
    outputs ``out_chans`` channels.

    Args:
        img_size: resolution passed to ``PatchEmbed``.
        in_chans: number of input image channels.
        out_chans: number of channels produced by the neck.
        embed_dims: channel width of each stage.
        depths: number of blocks in each stage; its length sets the stage count.
        num_heads: forwarded per stage to ``FasterViTLayer`` (entry 0 unused).
        window_sizes: forwarded per stage to ``FasterViTLayer`` (entry 0 unused).
        mlp_ratio: forwarded to every ``FasterViTLayer``.
        drop_rate: forwarded to every ``FasterViTLayer`` as ``drop``.
        drop_path_rate: largest stochastic-depth rate; per-block rates rise
            linearly from 0 to this value across all blocks of all stages.
        use_checkpoint: forwarded to the stage-0 ``ConvLayer`` only.
        mbconv_expand_ratio: forwarded to the stage-0 ``ConvLayer``.
        ct_size: forwarded to every ``FasterViTLayer``.
        conv: forwarded to every ``FasterViTLayer``.
        multi_scale: when True *and* ``output_shape`` is set, the
            patch-embedding output and every stage output are bilinearly
            resized to ``output_shape`` and concatenated before the neck.
        output_shape: target ``(H, W)`` for multi-scale fusion; a falsy value
            (``None``, ``''``) disables fusion.
    """

    def __init__(self, img_size=224,
                 in_chans=3,
                 out_chans=256,
                 embed_dims=(96, 192, 384, 768), depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 window_sizes=(7, 7, 14, 7),
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 ct_size=2,
                 conv=False,
                 multi_scale=False,
                 output_shape=None,
                 ):
        super().__init__()
        self.img_size = img_size
        self.depths = depths
        self.num_layers = len(depths)
        self.mlp_ratio = mlp_ratio
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None
        activation = nn.GELU

        self.patch_embed = PatchEmbed(in_chans=in_chans,
                                      embed_dim=embed_dims[0],
                                      resolution=img_size,
                                      activation=activation)
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # Stochastic depth decay rule: one drop-path rate per block, rising
        # linearly from 0 to drop_path_rate over all stages.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Build layers.
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            is_last = i_layer == self.num_layers - 1
            stage_dpr = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])]
            if i_layer == 0:
                # Conv stage at the full patch-embedding resolution.
                layer = ConvLayer(
                    dim=embed_dims[0],
                    input_resolution=(patches_resolution[0], patches_resolution[1]),
                    depth=depths[0],
                    drop_path=stage_dpr,
                    downsample=None if is_last else PatchMerging,
                    use_checkpoint=use_checkpoint,
                    out_dim=embed_dims[min(1, len(embed_dims) - 1)],
                    activation=activation,
                    conv_expand_ratio=mbconv_expand_ratio,
                )
            else:
                # Stage i sees the resolution halved i times; every stage but
                # the last downsamples and widens to the next stage's dim.
                layer = FasterViTLayer(
                    dim=embed_dims[i_layer],
                    out_dim=embed_dims[i_layer] if is_last else embed_dims[i_layer + 1],
                    input_resolution=patches_resolution[0] // (2 ** i_layer),
                    depth=depths[i_layer],
                    drop_path=stage_dpr,
                    downsample=not is_last,
                    ct_size=ct_size,
                    conv=conv,
                    num_heads=num_heads[i_layer],
                    window_size=window_sizes[i_layer],
                    mlp_ratio=self.mlp_ratio,
                    drop=drop_rate,
                )
            self.layers.append(layer)

        # Init weights. The neck is built afterwards, as before, so seeded
        # initialisation is unchanged.
        self.apply(self._init_weights)

        # In multi-scale mode the neck input is the concatenation of the
        # patch-embedding output and every stage output. The channel sum
        # below assumes each stage emits the next stage's width and the last
        # stage keeps its own width (TODO confirm against FasterViTLayer).
        neck_in = (sum(embed_dims) + embed_dims[-1]
                   if self.multi_scale and self.output_shape else embed_dims[-1])
        self.neck = nn.Sequential(
            nn.Conv2d(
                neck_in,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    def _init_weights(self, m):
        """Init Linear (truncated normal std .02, zero bias) and LayerNorm (1, 0).

        Missing parameters are skipped: Linear layers without a bias, and
        LayerNorm with ``elementwise_affine=False`` or ``bias=False``.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def no_weight_decay_keywords(self):
        """Return parameter-name keywords to exclude from weight decay."""
        return {'attention_biases'}

    def forward_features(self, x):
        """Embed patches, run every stage, and project through the neck.

        When multi-scale fusion is active, each intermediate feature map is
        resized to ``self.output_shape`` and all of them are concatenated
        along the channel dim before the neck.
        """
        fuse = self.multi_scale and self.output_shape
        x = self.patch_embed(x)
        scales = [F.interpolate(x, size=self.output_shape, mode='bilinear')] if fuse else None
        for layer in self.layers:
            x = layer(x)
            if fuse:
                scales.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
        return self.neck(torch.cat(scales, dim=1) if fuse else x)

    def forward(self, x):
        """Alias for :meth:`forward_features`."""
        return self.forward_features(x)
if __name__ == '__main__':
    # Smoke test: build a small 3-stage model, push one 1024x1024 image
    # through it, and report the output shape and parameter count.
    from distillation.utils import get_parameter_number

    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1, 3, 1024, 1024, device=device)
    fastertinyvit = FasterTinyViT(img_size=1024, in_chans=3,
                                  embed_dims=[64, 128, 256],
                                  depths=[1, 2, 1],
                                  num_heads=[2, 4, 8],
                                  window_sizes=[8, 8, 8],
                                  mlp_ratio=4.,
                                  drop_rate=0.,
                                  drop_path_rate=0.0,
                                  use_checkpoint=False,
                                  mbconv_expand_ratio=4.0,
                                  multi_scale=False,
                                  output_shape=None).to(device)
    # Only the output shape is needed, so skip building the autograd graph.
    with torch.no_grad():
        print(fastertinyvit(x).shape)
    print(get_parameter_number(fastertinyvit))
    # torch.save(fastertinyvit, 'fastertinyvit.pt')