Spaces:
Runtime error
Runtime error
File size: 15,753 Bytes
cc0dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 |
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import AdaptivePadding
from mmengine.model import BaseModule
from .helpers import to_2tuple
def resize_pos_embed(pos_embed,
                     src_shape,
                     dst_shape,
                     mode='bicubic',
                     num_extra_tokens=1):
    """Resize pos_embed weights.

    The first ``num_extra_tokens`` entries along the sequence dimension are
    kept unchanged; the remaining ``src_h * src_w`` entries are reshaped to a
    2-D grid, interpolated to ``dst_shape`` and flattened again.

    Args:
        pos_embed (torch.Tensor): Position embedding weights with shape
            [1, L, C].
        src_shape (tuple): The resolution of downsampled origin training
            image, in format (H, W).
        dst_shape (tuple): The resolution of downsampled new training
            image, in format (H, W).
        mode (str): Algorithm forwarded to ``F.interpolate``. The weights
            are interpolated as a 4-D tensor, so use a mode that accepts
            4-D input such as 'nearest', 'bilinear' or 'bicubic'.
            Defaults to 'bicubic'.
        num_extra_tokens (int): The number of extra tokens, such as cls_token.
            Defaults to 1.

    Returns:
        torch.Tensor: The resized pos_embed of shape [1, L_new, C]. When
        ``src_shape`` equals ``dst_shape`` the input tensor itself is
        returned.
    """
    if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
        return pos_embed
    assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
    _, L, C = pos_embed.shape
    src_h, src_w = src_shape
    assert L == src_h * src_w + num_extra_tokens, \
        f"The length of `pos_embed` ({L}) doesn't match the expected " \
        f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \
        '`img_size` argument.'

    extra_tokens = pos_embed[:, :num_extra_tokens]
    src_weight = pos_embed[:, num_extra_tokens:]
    # [1, H*W, C] -> [1, C, H, W] so that F.interpolate sees a 2-D grid.
    src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)

    # ``align_corners`` is only accepted by the interpolating modes; passing
    # it (even as False) together with e.g. 'nearest' raises a ValueError.
    interpolating_modes = ('linear', 'bilinear', 'bicubic', 'trilinear')
    align_corners = False if mode in interpolating_modes else None

    # The cubic interpolate algorithm only accepts float32
    dst_weight = F.interpolate(
        src_weight.float(),
        size=dst_shape,
        align_corners=align_corners,
        mode=mode)
    # [1, C, H_new, W_new] -> [1, H_new*W_new, C], back in the source dtype.
    dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
    dst_weight = dst_weight.to(src_weight.dtype)

    return torch.cat((extra_tokens, dst_weight), dim=1)
def resize_relative_position_bias_table(src_shape, dst_shape, table, num_head):
    """Resize relative position bias table.

    The source table is sampled on a geometrically spaced coordinate axis
    (denser near the centre) and re-evaluated with a bicubic spline on an
    integer axis spanning ``[-(dst_shape // 2), dst_shape // 2]``.

    Args:
        src_shape (int): Side length of the square relative-position grid of
            the pretrained model. ``table`` is viewed as
            ``(src_shape, src_shape)`` per head, and the sampling axis has
            ``2 * (src_shape // 2) + 1`` points, so it is expected to be odd.
        dst_shape (int): Side length of the square relative-position grid of
            the new model.
        table (tensor): The relative position bias of the pretrained model,
            indexed as ``table[:, head]``.
        num_head (int): Number of attention heads.

    Returns:
        torch.Tensor: The resized relative position bias table, a float32
        tensor of shape ``(L_new, num_head)`` on ``table.device`` where
        ``L_new = (2 * (dst_shape // 2) + 1) ** 2``.
    """
    # ``interp2d`` was removed in SciPy 1.14. For data on a regular grid
    # ``RectBivariateSpline`` is the documented replacement.
    from scipy.interpolate import RectBivariateSpline

    def geometric_progression(a, r, n):
        return a * (1.0 - r**n) / (1.0 - r)

    # Bisection for the ratio ``q`` whose geometric series over
    # ``src_shape // 2`` terms reaches ``dst_shape // 2``.
    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        gp = geometric_progression(1, q, src_shape // 2)
        if gp > dst_shape // 2:
            right = q
        else:
            left = q

    # Positive half of the source axis: cumulative sums 1, 1+q, 1+q+q^2, ...
    dis = []
    cur = 1
    for i in range(src_shape // 2):
        dis.append(cur)
        cur += q**(i + 1)

    # Symmetric source axis, shared by rows and columns.
    src_axis = np.asarray(
        [-d for d in reversed(dis)] + [0] + dis, dtype=np.float64)

    t = dst_shape // 2.0
    dst_axis = np.arange(-t, t + 0.1, 1.0)

    # ``.numpy()`` needs a CPU tensor that does not track gradients.
    src_table = table.detach().cpu().float()

    all_rel_pos_bias = []
    for i in range(num_head):
        z = src_table[:, i].reshape(src_shape, src_shape).numpy()
        # Rows and columns use the same axis, so fitting on (row, col) and
        # evaluating on (row, col) keeps the original row-major layout.
        spline = RectBivariateSpline(
            src_axis, src_axis, z.astype(np.float64), kx=3, ky=3)
        resized = spline(dst_axis, dst_axis)
        all_rel_pos_bias.append(
            torch.from_numpy(resized).float().reshape(-1, 1).to(table.device))

    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
    return new_rel_pos_bias
class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

    A single conv layer (16x16 kernel with stride 16 unless overridden by
    ``conv_cfg``) projects a fixed-size image onto a sequence of patch
    embeddings.

    Args:
        img_size (int | tuple): The size of input image. Default: 224
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None
        conv_cfg (dict, optional): The config dict for conv layers; its
            entries override the default conv settings. Default: None
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None
    """

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=768,
                 norm_cfg=None,
                 conv_cfg=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        warnings.warn('The `PatchEmbed` in mmpretrain will be deprecated. '
                      'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. '
                      "It's more general and supports dynamic input shape")

        # Normalise ``img_size`` to an (H, W) pair.
        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        self.img_size = img_size
        self.embed_dims = embed_dims

        # The embedding itself is a conv layer; user settings take priority
        # over the defaults below.
        merged_conv_cfg = dict(
            type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1)
        merged_conv_cfg.update(conv_cfg or dict())
        self.projection = build_conv_layer(merged_conv_cfg, in_channels,
                                           embed_dims)

        proj = self.projection

        def patches_along(axis):
            # Standard conv output-length formula for one spatial axis.
            span = proj.dilation[axis] * (proj.kernel_size[axis] - 1)
            padded = self.img_size[axis] + 2 * proj.padding[axis]
            return (padded - span - 1) // proj.stride[axis] + 1

        # How many patches an input image is split into.
        h_out, w_out = patches_along(0), patches_along(1)
        self.patches_resolution = (h_out, w_out)
        self.num_patches = h_out * w_out

        if norm_cfg is None:
            self.norm = None
        else:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self, x):
        """Embed a batch of images of exactly ``self.img_size``.

        Args:
            x (torch.Tensor): Input with four dimensions; the last two must
                equal ``self.img_size``.

        Returns:
            torch.Tensor: The projected features, flattened over the spatial
            dims and transposed to (B, N, D), then normalised if a norm layer
            was configured.
        """
        _, _, height, width = x.shape
        assert height == self.img_size[0] and width == self.img_size[1], \
            f"Input image size ({height}*{width}) doesn't " \
            f'match model ({self.img_size[0]}*{self.img_size[1]}).'
        # (B, D, H', W') -> (B, D, N) -> (B, N, D)
        tokens = self.projection(x)
        tokens = tokens.flatten(2).transpose(1, 2)

        if self.norm is None:
            return tokens
        return self.norm(tokens)
# Modified from pytorch-image-models
class HybridEmbed(BaseModule):
    """CNN Feature Map Embedding.

    Extract feature map from CNN, flatten,
    project to embedding dim.

    Args:
        backbone (nn.Module): CNN backbone
        img_size (int | tuple): The size of input image. Default: 224
        feature_size (int | tuple, optional): Size of feature map extracted by
            CNN backbone. When None, it is measured by running the backbone
            once on a zero image of ``img_size``. Default: None
        in_channels (int): The num of input channels. Default: 3
        embed_dims (int): The dimensions of embedding. Default: 768
        conv_cfg (dict, optional): The config dict for conv layers; its
            entries override the default 1x1 conv settings.
            Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 img_size=224,
                 feature_size=None,
                 in_channels=3,
                 embed_dims=768,
                 conv_cfg=None,
                 init_cfg=None):
        super(HybridEmbed, self).__init__(init_cfg)
        assert isinstance(backbone, nn.Module)
        if isinstance(img_size, int):
            img_size = to_2tuple(img_size)
        elif isinstance(img_size, tuple):
            if len(img_size) == 1:
                img_size = to_2tuple(img_size[0])
            assert len(img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(img_size)}'

        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            # FIXME this is hacky, but most reliable way of
            #  determining the exact dim of the output feature
            #  map for all networks, the feature metadata has
            #  reliable channel and stride info, but using
            #  stride to calc feature dim requires info about padding of
            #  each stage that isn't captured.

            # Build the probe on the same device / dtype as the backbone's
            # weights so the dry run also works when the backbone has already
            # been moved off the CPU or cast to another float type.
            probe_kwargs = dict()
            first_param = next(backbone.parameters(), None)
            if first_param is not None:
                probe_kwargs['device'] = first_param.device
                if first_param.is_floating_point():
                    probe_kwargs['dtype'] = first_param.dtype

            training = backbone.training
            if training:
                backbone.eval()
            try:
                with torch.no_grad():
                    o = self.backbone(
                        torch.zeros(1, in_channels, img_size[0], img_size[1],
                                    **probe_kwargs))
            finally:
                # Restore the original mode even if the dry run fails.
                backbone.train(training)

            if isinstance(o, (list, tuple)):
                # last feature if backbone outputs list/tuple of features
                o = o[-1]
            feature_size = o.shape[-2:]
            feature_dim = o.shape[1]
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.num_patches = feature_size[0] * feature_size[1]

        # Use conv layer to embed
        conv_cfg = conv_cfg or dict()
        _conv_cfg = dict(
            type='Conv2d', kernel_size=1, stride=1, padding=0, dilation=1)
        _conv_cfg.update(conv_cfg)
        self.projection = build_conv_layer(_conv_cfg, feature_dim, embed_dims)

    def forward(self, x):
        """Run the backbone and project its last feature map to tokens.

        Args:
            x (torch.Tensor): Input forwarded unchanged to the backbone.

        Returns:
            torch.Tensor: The projected feature map, flattened over the
            spatial dims and transposed so the token dim comes second.
        """
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            # last feature if backbone outputs list/tuple of features
            x = x[-1]
        x = self.projection(x).flatten(2).transpose(1, 2)
        return x
class PatchMerging(BaseModule):
    """Merge patch feature map.

    Modified from mmcv, and this module supports specifying whether to use
    post-norm.

    This layer groups the feature map by ``kernel_size`` and applies a norm
    layer and a linear layer to the grouped features (used in Swin
    Transformer). Patches are gathered with :class:`torch.nn.Unfold`, which
    is about 25% faster than the original implementation. However, we need to
    modify pretrained models for compatibility.

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in the
            unfold layer. Defaults to None, which means to be set as
            ``kernel_size``.
        padding (int | tuple | string ): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Defaults to "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Defaults to 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults to False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        use_post_norm (bool): Whether to use post normalization here.
            Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=2,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 use_post_norm=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_post_norm = use_post_norm

        # A missing (falsy) stride falls back to the kernel size.
        stride = to_2tuple(stride or kernel_size)
        kernel_size = to_2tuple(kernel_size)
        dilation = to_2tuple(dilation)

        if not isinstance(padding, str):
            self.adaptive_padding = None
        else:
            # A string selects an adaptive padding mode; the unfold layer
            # itself then pads nothing.
            self.adaptive_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            padding = 0
        padding = to_2tuple(padding)

        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        # Each unfolded patch concatenates kernel_h * kernel_w positions.
        sample_dim = kernel_size[0] * kernel_size[1] * in_channels
        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

        if norm_cfg is None:
            self.norm = None
        else:
            # Post-norm sees the reduced channels, pre-norm the unfolded ones.
            norm_dims = out_channels if self.use_post_norm else sample_dim
            self.norm = build_norm_layer(norm_cfg, norm_dims)[1]

    def forward(self, x, input_size):
        """Merge neighbouring patches and reduce their channels.

        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arrange as (H, W).

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
            - out_size (tuple[int]): Spatial shape of x, arrange as
              (Merged_H, Merged_W).
        """
        batch, length, channels = x.shape
        assert isinstance(input_size, Sequence), \
            f'Expect input_size is `Sequence` but get {input_size}'

        height, width = input_size
        assert length == height * width, 'input feature has wrong size'

        # (B, H*W, C) -> (B, C, H, W)
        feat = x.view(batch, height, width, channels).permute([0, 3, 1, 2])

        if self.adaptive_padding:
            feat = self.adaptive_padding(feat)
            height, width = feat.shape[-2:]

        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility
        # if kernel_size=2 and stride=2, feat has shape (B, 4*C, H/2*W/2)
        feat = self.sampler(feat)

        sampler = self.sampler

        def merged_len(size, axis):
            # Sliding-window output length along one spatial axis.
            span = sampler.dilation[axis] * (sampler.kernel_size[axis] - 1)
            padded = size + 2 * sampler.padding[axis]
            return (padded - span - 1) // sampler.stride[axis] + 1

        output_size = (merged_len(height, 0), merged_len(width, 1))

        feat = feat.transpose(1, 2)  # B, H/2*W/2, 4*C

        # The norm goes either before (pre-norm) or after (post-norm) the
        # linear reduction.
        if not self.use_post_norm and self.norm:
            feat = self.norm(feat)
        feat = self.reduction(feat)
        if self.use_post_norm and self.norm:
            feat = self.norm(feat)

        return feat, output_size
|