# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """Sum over the first and last dimension."""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    return tensor.unsqueeze(0).unsqueeze(-1)
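

# Shape sketch (illustrative doctest-style example): for an input reshaped to
# (B, C, L), the two helpers above reduce to, and expand from, per-channel vectors:
#
#     >>> x = torch.randn(4, 8, 16)          # (B, C, L)
#     >>> _sum_ft(x).shape                   # per-channel sums
#     torch.Size([8])
#     >>> _unsqueeze_ft(_sum_ft(x)).shape    # broadcastable against x again
#     torch.Size([1, 8, 1])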


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If this module is not running in parallel, or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)
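
    # Note on the fused affine branch above: since
    #     gamma * (x - mean) / std + beta == (x - mean) * (gamma / std) + beta,
    # multiplying inv_std by the per-channel weight first means the expensive
    # full-tensor multiplication happens once instead of twice.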

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)
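
    # How this hook is reached (sketch): plain nn.DataParallel does not call
    # __data_parallel_replicate__; it is expected to be invoked once per replica by a
    # DataParallel variant with a replication callback, passing a shared context `ctx`
    # so that replica 0 exposes its SyncMaster and the other replicas register
    # themselves as slaves.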

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast them."""

        # Always using the same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation from the sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5
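
    # Worked form of the statistics above: with s = sum(x) and q = sum(x**2) over
    # n elements per channel,
    #     mean         = s / n
    #     sumvar       = q - s * mean             (= sum((x - mean)**2))
    #     biased var   = (q - s * mean) / n       (used for normalization)
    #     unbiased var = (q - s * mean) / (n - 1) (used for the running_var estimate)
    # and the returned inverse std is max(biased var, eps) ** -0.5, which is what
    # bias_var.clamp(self.eps) ** -0.5 computes.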


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{\sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{\sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{\sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
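

# Usage sketch (illustrative): these modules only synchronize statistics when the model
# has been replicated by a data-parallel wrapper that calls __data_parallel_replicate__
# on each copy. On a single GPU or CPU they behave like the standard BatchNorm layers:
#
#     >>> import torch.nn as nn
#     >>> net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), SynchronizedBatchNorm2d(16), nn.ReLU())
#     >>> y = net(torch.randn(8, 3, 32, 32))
#
# For multi-GPU training, wrap the model with a DataParallel variant that supports a
# replication callback (this repository provides one alongside this file).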