from typing import Optional, Any

import torch
from torch import Tensor
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer

from .. import functional as F
from .. import init
from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module

__all__ = ['BatchNorm1d', 'LazyBatchNorm1d', 'BatchNorm2d', 'LazyBatchNorm2d', 'BatchNorm3d',
           'LazyBatchNorm3d', 'SyncBatchNorm']


class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm."""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: Optional[float]
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # on whether self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = (
                    self.num_batches_tracked
                    if self.num_batches_tracked is not None
                    and self.num_batches_tracked.device != torch.device('meta')
                    else torch.tensor(0, dtype=torch.long)
                )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )


class _LazyNormBase(LazyModuleMixin, _NormBase):

    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features,))  # type:ignore[union-attr]
                self.running_var.materialize((self.num_features,))  # type:ignore[union-attr]
            self.reset_parameters()


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input.

    Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default, the
    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
    At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
    equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
    moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
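        >>> # A minimal sketch of the running-stat update with the default momentum of 0.1:
        >>> # starting from the default running_mean of 0, one training-mode pass over a
        >>> # constant input of 2.0 gives (1 - 0.1) * 0 + 0.1 * 2.0 = 0.2 per channel.
        >>> m = nn.BatchNorm1d(3)
        >>> output = m(torch.ones(4, 3) * 2.0)
        >>> m.running_mean
        tensor([0.2000, 0.2000, 0.2000])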
""" | |
def _check_input_dim(self, input): | |
if input.dim() != 2 and input.dim() != 3: | |
raise ValueError( | |
f"expected 2D or 3D input (got {input.dim()}D input)" | |
) | |


class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
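
    Examples::

        >>> # A minimal usage sketch: ``num_features`` is inferred from the channel
        >>> # dimension of the first input seen in the forward pass.
        >>> m = nn.LazyBatchNorm1d()
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
        >>> m.num_features
        100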
""" | |
cls_to_become = BatchNorm1d # type: ignore[assignment] | |
def _check_input_dim(self, input): | |
if input.dim() != 2 and input.dim() != 3: | |
raise ValueError( | |
f"expected 2D or 3D input (got {input.dim()}D input)" | |
) | |


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input.

    4D is a mini-batch of 2D inputs
    with additional channel dimension. Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
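        >>> # A minimal sketch of the cumulative-moving-average option: with
        >>> # ``momentum=None`` the running stats are a simple average over all
        >>> # batches seen so far instead of an exponential moving average.
        >>> m = nn.BatchNorm2d(100, momentum=None)
        >>> output = m(torch.randn(20, 100, 35, 45))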
""" | |
def _check_input_dim(self, input): | |
if input.dim() != 4: | |
raise ValueError(f"expected 4D input (got {input.dim()}D input)") | |


class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input.

    5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
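        >>> # A minimal sketch without running statistics: with
        >>> # ``track_running_stats=False`` the layer normalizes with batch
        >>> # statistics in both training and eval modes.
        >>> m = nn.BatchNorm3d(100, track_running_stats=False)
        >>> output = m(torch.randn(20, 100, 35, 45, 10))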
""" | |
def _check_input_dim(self, input): | |
if input.dim() != 5: | |
raise ValueError(f"expected 5D input (got {input.dim()}D input)") | |


class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over an N-dimensional input.

    The N-D input is a mini-batch of [N-2]D inputs with an additional channel dimension, as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are set to 1 and the elements of
    :math:`\beta` are set to 0. The standard-deviation is calculated via the
    biased estimator, equivalent to ``torch.var(input, unbiased=False)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with a single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layers to :class:`SyncBatchNorm` before wrapping
    the network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happens within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        self.process_group = process_group

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError(
                f"expected at least 2D input (got {input.dim()}D input)"
            )

    def _check_non_zero_input_channels(self, input):
        if input.size(1) == 0:
            raise ValueError(
                "SyncBatchNorm number of input channels should be non-zero"
            )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        # If buffers are not to be tracked, ensure that they won't be updated
        running_mean = (
            self.running_mean if not self.training or self.track_running_stats else None
        )
        running_var = (
            self.running_var if not self.training or self.track_running_stats else None
        )

        # Don't sync batchnorm stats in inference mode (model.eval()).
        need_sync = (bn_training and self.training and
                     torch.distributed.is_available() and torch.distributed.is_initialized())
        if need_sync:
            # currently only GPU/PrivateUse1 input is supported
            if input.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
                raise ValueError("SyncBatchNorm expected input tensor to be on GPU or "
                                 f"{torch._C._get_privateuse1_backend_name()}")

            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self.weight,
                self.bias,
                bn_training,
                exponential_average_factor,
                self.eps,
            )
        else:
            assert bn_training
            return sync_batch_norm.apply(
                input,
                self.weight,
                self.bias,
                running_mean,
                running_var,
                self.eps,
                exponential_average_factor,
                process_group,  # type: ignore[possibly-undefined]
                world_size,  # type: ignore[possibly-undefined]
            )

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100),
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> # xdoctest: +SKIP("distributed")
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            module_output.training = module.training
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group)
            )
        del module
        return module_output