# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version
from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
                             MomentumAnnealingEMA, StochasticWeightAverage)
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .test_time_aug import BaseTTAModel
from .utils import (convert_sync_batchnorm, detect_anomalous_params,
                    merge_dict, revert_sync_batchnorm, stack_batch)
from .weight_init import (BaseInit, Caffe2XavierInit, ConstantInit,
                          KaimingInit, NormalInit, PretrainedInit,
                          TruncNormalInit, UniformInit, XavierInit,
                          bias_init_with_prob, caffe2_xavier_init,
                          constant_init, initialize, kaiming_init, normal_init,
                          trunc_normal_init, uniform_init, update_init_info,
                          xavier_init)
from .wrappers import (MMDistributedDataParallel,
                       MMSeparateDistributedDataParallel, is_model_wrapper)
# Names exported by ``from <this package> import *``. Each entry must match a
# name bound at module level by the imports in this file.
__all__ = [
    'MMDistributedDataParallel', 'is_model_wrapper', 'BaseAveragedModel',
    'StochasticWeightAverage', 'ExponentialMovingAverage',
    'MomentumAnnealingEMA', 'BaseModel', 'BaseDataPreprocessor',
    'ImgDataPreprocessor', 'MMSeparateDistributedDataParallel', 'BaseModule',
    'stack_batch', 'merge_dict', 'detect_anomalous_params', 'ModuleList',
    'ModuleDict', 'Sequential', 'revert_sync_batchnorm', 'update_init_info',
    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
    'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit',
    'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit',
    'Caffe2XavierInit', 'PretrainedInit', 'initialize',
    'convert_sync_batchnorm', 'BaseTTAModel'
]
# MMFullyShardedDataParallel is imported and exported only when the installed
# torch version is at least 2.0.0; on older versions the name is absent from
# this package. Both statements stay inside the ``if`` so that ``__all__``
# never lists a name that was not imported.
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
    from .wrappers import MMFullyShardedDataParallel  # noqa:F401
    __all__.append('MMFullyShardedDataParallel')