|
import torch.nn as nn |
|
|
|
|
|
class BatchNorm1d(nn.Module):
    """Applies 1d batch normalization to the input tensor.

    Arguments
    ---------
    input_shape : tuple
        The expected shape of the input. Alternatively, use ``input_size``.
    input_size : int
        The expected size of the input. Alternatively, use ``input_shape``.
    eps : float
        This value is added to the variance estimate (inside the square
        root of the denominator) to improve numerical stability.
    momentum : float
        It is a value used for the running_mean and running_var computation.
    affine : bool
        When set to True, the affine parameters are learned.
    track_running_stats : bool
        When set to True, this module tracks the running mean and variance,
        and when set to False, this module does not track such statistics.
    combine_batch_time : bool
        When True, the batch and time axes are merged before normalization,
        so statistics are computed over both of them.
    skip_transpose : bool
        When True, the input is assumed to already be channels-first
        (batch, channels[, time]) and is passed to the normalization
        without transposition; the size is then read from
        ``input_shape[1]``.

    Example
    -------
    >>> import torch
    >>> inp = torch.randn(100, 10)
    >>> norm = BatchNorm1d(input_shape=inp.shape)
    >>> output = norm(inp)
    >>> output.shape
    torch.Size([100, 10])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        combine_batch_time=False,
        skip_transpose=False,
    ):
        super().__init__()
        self.combine_batch_time = combine_batch_time
        self.skip_transpose = skip_transpose

        if input_size is None:
            if input_shape is None:
                raise ValueError(
                    "BatchNorm1d expects one of input_shape or input_size."
                )
            # Channels-first inputs carry the features on axis 1; otherwise
            # the features live on the last axis.
            input_size = input_shape[1] if skip_transpose else input_shape[-1]

        self.norm = nn.BatchNorm1d(
            input_size,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x, *args, **kwargs):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, [channels])
            Input to normalize. 2d or 3d tensors are expected in input;
            4d tensors (batch, time, features, channels) can be used when
            ``combine_batch_time=True``.
        *args, **kwargs
            Accepted for interface compatibility; ignored.

        Returns
        -------
        x_n : torch.Tensor
            The normalized outputs, with the same shape as ``x``.

        Raises
        ------
        ValueError
            If ``combine_batch_time=True`` and ``x`` is neither 3d nor 4d.
        """
        shape_or = x.shape

        if self.combine_batch_time:
            if x.ndim == 3:
                # (batch, time, ch) -> (batch * time, ch)
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
            elif x.ndim == 4:
                # (batch, time, fea, ch) -> (batch * time, ch, fea).
                # The channel axis must be moved with a transpose; a plain
                # reshape to (.., ch, fea) would reinterpret memory and mix
                # values from different channels into each statistic.
                x = x.reshape(
                    shape_or[0] * shape_or[1], shape_or[2], shape_or[3]
                ).transpose(1, 2)
            else:
                raise ValueError(
                    "combine_batch_time expects a 3d or 4d input, "
                    f"got a {x.ndim}d tensor."
                )
        elif not self.skip_transpose:
            # nn.BatchNorm1d normalizes axis 1, so bring the last
            # (feature) axis there.
            x = x.transpose(-1, 1)

        x_n = self.norm(x)

        if self.combine_batch_time:
            if len(shape_or) == 4:
                # Undo the channel/feature swap before restoring the shape.
                x_n = x_n.transpose(1, 2)
            x_n = x_n.reshape(shape_or)
        elif not self.skip_transpose:
            x_n = x_n.transpose(1, -1)

        return x_n