File size: 2,356 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version
from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
                             MomentumAnnealingEMA, StochasticWeightAverage)
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .test_time_aug import BaseTTAModel
from .utils import (convert_sync_batchnorm, detect_anomalous_params,
                    merge_dict, revert_sync_batchnorm, stack_batch)
from .weight_init import (BaseInit, Caffe2XavierInit, ConstantInit,
                          KaimingInit, NormalInit, PretrainedInit,
                          TruncNormalInit, UniformInit, XavierInit,
                          bias_init_with_prob, caffe2_xavier_init,
                          constant_init, initialize, kaiming_init, normal_init,
                          trunc_normal_init, uniform_init, update_init_info,
                          xavier_init)
from .wrappers import (MMDistributedDataParallel,
                       MMSeparateDistributedDataParallel, is_model_wrapper)

# Names re-exported as this package's public API (what ``import *`` binds).
# The order below is kept identical to the historical ordering.
__all__ = [
    'MMDistributedDataParallel',
    'is_model_wrapper',
    'BaseAveragedModel',
    'StochasticWeightAverage',
    'ExponentialMovingAverage',
    'MomentumAnnealingEMA',
    'BaseModel',
    'BaseDataPreprocessor',
    'ImgDataPreprocessor',
    'MMSeparateDistributedDataParallel',
    'BaseModule',
    'stack_batch',
    'merge_dict',
    'detect_anomalous_params',
    'ModuleList',
    'ModuleDict',
    'Sequential',
    'revert_sync_batchnorm',
    'update_init_info',
    'constant_init',
    'xavier_init',
    'normal_init',
    'trunc_normal_init',
    'uniform_init',
    'kaiming_init',
    'caffe2_xavier_init',
    'bias_init_with_prob',
    'BaseInit',
    'ConstantInit',
    'XavierInit',
    'NormalInit',
    'TruncNormalInit',
    'UniformInit',
    'KaimingInit',
    'Caffe2XavierInit',
    'PretrainedInit',
    'initialize',
    'convert_sync_batchnorm',
    'BaseTTAModel',
]

# The fully-sharded wrapper is only imported and exported when the installed
# PyTorch reports version 2.0.0 or newer; on older versions the name is absent
# from both the namespace and ``__all__``.
# NOTE(review): presumably because the wrapper relies on FSDP APIs that are
# only stable from torch 2.0 — confirm against ``.wrappers``.
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
    from .wrappers import MMFullyShardedDataParallel  # noqa:F401
    __all__ += ['MMFullyShardedDataParallel']