# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version
from mmdet.registry import MODELS

if torch.__version__ == 'parrots':
    TORCH_VERSION = torch.__version__
else:
    # torch.__version__ could be 1.3.1+cu92, we only need the first two
    # for comparison
    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
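# e.g. a build string like '1.13.1+cu117' parses to (1, 13), which mmcv's
# obsolete_torch_version() can then compare against a (major, minor) floor.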


def adaptive_avg_pool2d(input, output_size):
    """Handle empty batch dimension for adaptive_avg_pool2d.

    Args:
        input (tensor): 4D tensor.
        output_size (int, tuple[int, int]): the target output size.
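
    Example (an illustrative doctest, my addition; on recent torch the call
    falls through to F.adaptive_avg_pool2d, while on torch < 1.9 the
    empty-tensor branch produces the same shape):
        >>> x = torch.empty(0, 256, 7, 7)
        >>> adaptive_avg_pool2d(x, 1).shape
        torch.Size([0, 256, 1, 1])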
"""
if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
if isinstance(output_size, int):
output_size = [output_size, output_size]
output_size = [*input.shape[:2], *output_size]
empty = NewEmptyTensorOp.apply(input, output_size)
return empty
else:
return F.adaptive_avg_pool2d(input, output_size)


class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    """Handle empty batch dimension for AdaptiveAvgPool2d."""

    def forward(self, x):
        # PyTorch 1.9 does not support empty tensor inference yet
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
            output_size = self.output_size
            if isinstance(output_size, int):
                output_size = [output_size, output_size]
            else:
                output_size = [
                    v if v is not None else d
                    for v, d in zip(output_size, x.size()[-2:])
                ]
            output_size = [*x.shape[:2], *output_size]
            empty = NewEmptyTensorOp.apply(x, output_size)
            return empty

        return super().forward(x)
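

# Illustrative sketch (my addition, not part of the upstream file):
# nn.AdaptiveAvgPool2d accepts None entries in output_size, meaning "keep
# that input dimension", and the empty-batch branch above mirrors this by
# substituting the input's spatial size for any None entry:
#
#   pool = AdaptiveAvgPool2d((None, 1))    # keep H, pool W down to 1
#   pool(torch.empty(0, 64, 13, 9)).shape  # -> torch.Size([0, 64, 13, 1])
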
# Modified from
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py#L13 # noqa
@MODELS.register_module('FrozenBN')
class FrozenBatchNorm2d(nn.Module):
    """BatchNorm2d where the batch statistics and the affine parameters are
    fixed.

    It contains non-trainable buffers called "weight", "bias",
    "running_mean" and "running_var", initialized to perform an identity
    transformation.

    Args:
        num_features (int): :math:`C` from an expected input of size
            :math:`(N, C, H, W)`.
        eps (float): a value added to the denominator for numerical
            stability. Default: 1e-5
    """

    def __init__(self, num_features, eps=1e-5, **kwargs):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer('weight', torch.ones(num_features))
        self.register_buffer('bias', torch.zeros(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        # running_var is initialized to 1 - eps so that, with eps added back
        # in forward(), the freshly constructed module is an exact identity.
        self.register_buffer('running_var', torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias
            # as well. Instead, fold the frozen statistics into a single
            # per-channel scale and shift:
            #   y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
            #     = x * scale + (bias - running_mean * scale)
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused
            # op and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def __repr__(self):
        return 'FrozenBatchNorm2d(num_features={}, eps={})'.format(
            self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """Convert all BatchNorm/SyncBatchNorm layers in module into
        FrozenBatchNorm2d.

        Args:
            module (torch.nn.Module): the module to convert.

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new
            FrozenBatchNorm2d module; otherwise, converts module's
            BatchNorm children in place and returns module itself.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
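

# A minimal smoke test (my addition, not part of the upstream file): it
# freezes the BN layer of a small model and checks that the frozen module
# reproduces the original eval-mode output. Run this file directly, with
# torch, mmcv and mmdet installed (the imports above require them).
if __name__ == '__main__':
    model = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )
    model.eval()
    x = torch.randn(2, 3, 16, 16)
    with torch.no_grad():
        ref = model(x)
    frozen = FrozenBatchNorm2d.convert_frozen_batchnorm(model)
    assert isinstance(frozen[1], FrozenBatchNorm2d)
    with torch.no_grad():
        out = frozen(x)
    # Frozen BN applies the same fixed statistics, so outputs should match.
    assert torch.allclose(ref, out, atol=1e-6)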