# File size: 38,837 Bytes
# Revision: c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
from typing import Optional, Any

import torch
from torch import Tensor
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer

from .. import functional as F
from .. import init
from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module

# Names exported by ``from <this module> import *``; the underscore-prefixed
# base classes defined below are not listed.
__all__ = ['BatchNorm1d', 'LazyBatchNorm1d', 'BatchNorm2d', 'LazyBatchNorm2d', 'BatchNorm3d',
           'LazyBatchNorm3d', 'SyncBatchNorm']


class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm.

    Stores the normalization hyper-parameters, creates the optional affine
    parameters (``weight`` / ``bias``) and the optional statistics buffers
    (``running_mean`` / ``running_var`` / ``num_batches_tracked``), and offers
    helpers to reset them and to load checkpoints that predate the
    ``num_batches_tracked`` buffer.

    Args:
        num_features: length of every per-feature parameter and buffer.
        eps: stored as ``self.eps``.
        momentum: stored as ``self.momentum``; ``None`` is accepted.
        affine: if ``True``, ``weight`` and ``bias`` are created as
            parameters; otherwise both are registered as ``None``.
        track_running_stats: if ``True``, the three statistics buffers are
            created; otherwise all three are registered as ``None``.
        device: forwarded to the tensor factory calls.
        dtype: forwarded to the tensor factory calls, except for
            ``num_batches_tracked`` which is always ``torch.long``.
    """

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    # ``None`` is a legal value: _BatchNorm.forward branches on ``self.momentum is None``.
    momentum: Optional[float]
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            # Values are filled in by reset_parameters() at the end of __init__.
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            # The counter is an integer tensor, so only ``device`` is forwarded.
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long, device=device))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        """Reset the statistics buffers in place (mean 0, var 1, counter 0).

        Does nothing when ``track_running_stats`` is ``False``.
        """
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        """Reset the statistics buffers and, if affine, set weight to 1 and bias to 0."""
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        """Validate the dimensionality of ``input``; must be provided by subclasses."""
        raise NotImplementedError

    def extra_repr(self):
        """Return the hyper-parameter summary string built from the instance attributes."""
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load state, back-filling ``num_batches_tracked`` for pre-version-2 checkpoints.

        When the checkpoint metadata has no version or a version below 2, and
        this module tracks running stats, a missing ``num_batches_tracked``
        entry is inserted into ``state_dict`` before delegating to the parent
        implementation.
        """
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                tracked = self.num_batches_tracked
                if tracked is not None and tracked.device != torch.device('meta'):
                    # Reuse the module's current counter.
                    state_dict[num_batches_tracked_key] = tracked
                else:
                    # No usable counter on the module: fall back to a fresh zero.
                    state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class _BatchNorm(_NormBase):
    """Shared ``forward`` for the batch-normalization modules.

    Subclasses provide ``_check_input_dim``; everything else (parameters,
    buffers, resetting, state loading) comes from ``_NormBase``.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats,
            device=device, dtype=dtype,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Validate ``input`` and apply ``F.batch_norm`` with this module's state."""
        self._check_input_dim(input)

        # The averaging factor starts out as self.momentum (when one is set)
        # purely so that the value is updated in the ONNX graph when this
        # node is exported to ONNX.
        if self.momentum is None:
            avg_factor = 0.0
        else:
            avg_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:
                    # cumulative moving average
                    avg_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    # exponential moving average
                    avg_factor = self.momentum

        # Decide whether the mini-batch stats should be used for normalization
        # rather than the buffers. Mini-batch stats are used in training mode,
        # and in eval mode when buffers are None.
        use_batch_stats = self.training or (
            (self.running_mean is None) and (self.running_var is None)
        )

        # Buffers are only updated if they are to be tracked and we are in
        # training mode. Thus they only need to be passed when the update
        # should occur (i.e. in training mode when they are tracked), or when
        # buffer stats are used for normalization (i.e. in eval mode when
        # buffers are not None). Otherwise None is passed so that they cannot
        # be updated.
        hand_over_buffers = (not self.training) or self.track_running_stats

        return F.batch_norm(
            input,
            self.running_mean if hand_over_buffers else None,
            self.running_var if hand_over_buffers else None,
            self.weight,
            self.bias,
            use_batch_stats,
            avg_factor,
            self.eps,
        )


class _LazyNormBase(LazyModuleMixin, _NormBase):
    """Norm base whose per-feature tensors are created on the first input.

    ``num_features`` starts at 0 and is taken from ``input.shape[1]`` in
    ``initialize_parameters``; until then ``weight``, ``bias``,
    ``running_mean`` and ``running_var`` are uninitialized placeholders
    (each only when the corresponding flag is enabled).
    """

    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 device=None, dtype=None) -> None:
        tensor_opts = {'device': device, 'dtype': dtype}
        # The parent constructor is told affine=False and
        # track_running_stats=False so that it does not allocate tensors that
        # would be replaced immediately below.
        super().__init__(
            0,
            eps,
            momentum,
            False,
            False,
            **tensor_opts,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**tensor_opts)
            self.bias = UninitializedParameter(**tensor_opts)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**tensor_opts)
            self.running_var = UninitializedBuffer(**tensor_opts)
            # Integer counter: takes the device but not the dtype.
            self.num_batches_tracked = torch.tensor(0, dtype=torch.long, device=device)

    def reset_parameters(self) -> None:
        """Reset state, but only once everything is materialized and sized."""
        if self.has_uninitialized_params():
            return
        if self.num_features == 0:
            return
        super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        """Materialize the placeholders with ``input.shape[1]`` features, then reset."""
        if not self.has_uninitialized_params():
            return
        self.num_features = input.shape[1]
        feature_shape = (self.num_features,)
        if self.affine:
            assert isinstance(self.weight, UninitializedParameter)
            assert isinstance(self.bias, UninitializedParameter)
            self.weight.materialize(feature_shape)
            self.bias.materialize(feature_shape)
        if self.track_running_stats:
            self.running_mean.materialize(feature_shape)  # type:ignore[union-attr]
            self.running_var.materialize(feature_shape)  # type:ignore[union-attr]
        self.reset_parameters()


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input.

    Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default, the
    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
    At train time in the forward pass, the variance is calculated via the biased estimator,
    equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
    moving average of the variance is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        # Only (N, C) and (N, C, L) inputs are accepted.
        if not (input.dim() == 2 or input.dim() == 3):
            raise ValueError(
                f"expected 2D or 3D input (got {input.dim()}D input)"
            )


class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm1d`
    that is inferred from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm1d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        # Only (N, C) and (N, C, L) inputs are accepted.
        if not (input.dim() == 2 or input.dim() == 3):
            raise ValueError(
                f"expected 2D or 3D input (got {input.dim()}D input)"
            )


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input.

    4D is a mini-batch of 2D inputs
    with additional channel dimension. Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    variance is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    variance is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        # Only (N, C, H, W) inputs are accepted.
        if input.dim() == 4:
            return
        raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d`
    that is inferred from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        # Only (N, C, H, W) inputs are accepted.
        if input.dim() == 4:
            return
        raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input.

    5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    variance is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    variance is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        # Only (N, C, D, H, W) inputs are accepted.
        if input.dim() == 5:
            return
        raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d`
    that is inferred from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        # Only (N, C, D, H, W) inputs are accepted.
        if input.dim() == 5:
            return
        raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over a N-Dimensional input.

    The N-D input is a mini-batch of [N-2]D inputs with an additional channel dimension, as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are sampled from
    :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping
    Network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happen within each process group
            individually. Default behavior is synchronization across the whole
            world
        device: forwarded to the parent constructor as a factory keyword argument.
            Default: ``None``
        dtype: forwarded to the parent constructor as a factory keyword argument.
            Default: ``None``

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if torch.distributed.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Group used to scope statistics synchronization; a falsy value makes
        # forward() use torch.distributed.group.WORLD.
        self.process_group = process_group

    def _check_input_dim(self, input):
        """Raise ``ValueError`` if ``input`` has fewer than 2 dimensions."""
        if input.dim() < 2:
            raise ValueError(
                f"expected at least 2D input (got {input.dim()}D input)"
            )

    def _check_non_zero_input_channels(self, input):
        """Raise ``ValueError`` if dimension 1 of ``input`` has size zero."""
        if input.size(1) == 0:
            raise ValueError(
                "SyncBatchNorm number of input channels should be non-zero"
            )

    def forward(self, input: Tensor) -> Tensor:
        """Normalize ``input``, synchronizing batch statistics when required.

        Statistics are synchronized through ``sync_batch_norm`` only when the
        module is in training mode, ``torch.distributed`` is available and
        initialized, and the selected process group has a world size greater
        than one. In every other case the call falls back to ``F.batch_norm``.

        Raises:
            ValueError: if ``input`` has fewer than 2 dimensions, has zero
                channels, or (when the distributed path is considered) is not
                on a ``cuda`` or PrivateUse1 device.
        """
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            # otherwise keep the exponential moving average factor set above

        # Decide whether the mini-batch stats should be used for normalization
        # rather than the buffers. Mini-batch stats are used in training mode,
        # and in eval mode when buffers are None.
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        # Buffers are only updated if they are to be tracked and we are in
        # training mode. Thus they only need to be passed when the update
        # should occur (i.e. in training mode when they are tracked), or when
        # buffer stats are used for normalization (i.e. in eval mode when
        # buffers are not None).
        # If buffers are not to be tracked, ensure that they won't be updated.
        running_mean = (
            self.running_mean if not self.training or self.track_running_stats else None
        )
        running_var = (
            self.running_var if not self.training or self.track_running_stats else None
        )

        # Defaults describe the "no synchronization" case so both names are
        # always bound when the dispatch below is reached.
        process_group = None
        world_size = 1

        # Don't sync batchnorm stats in inference mode (model.eval()).
        if (
            bn_training
            and self.training
            and torch.distributed.is_available()
            and torch.distributed.is_initialized()
        ):
            # currently only GPU/PrivateUse1 input is supported
            if input.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
                raise ValueError("SyncBatchNorm expected input tensor to be on GPU or "
                                 f"{torch._C._get_privateuse1_backend_name()}")

            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)

        need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self.weight,
                self.bias,
                bn_training,
                exponential_average_factor,
                self.eps,
            )

        assert bn_training
        return sync_batch_norm.apply(
            input,
            self.weight,
            self.bias,
            running_mean,
            running_var,
            self.eps,
            exponential_average_factor,
            process_group,
            world_size,
        )

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100),
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> # xdoctest: +SKIP("distributed")
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if torch.distributed.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                # Reuse the source layer's parameter objects rather than the
                # freshly created ones.
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            # Carry over the statistics buffers and the train/eval flag.
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            module_output.training = module.training
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        # Recurse so that nested BatchNorm layers are converted as well.
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group)
            )
        return module_output