rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Dict, Optional
from mmengine.registry import MODELS
from torch import nn
# Register PyTorch's built-in convolution layers in the MODELS registry so
# that configs can refer to them by name (e.g. ``dict(type='Conv2d')``).
MODELS.register_module('Conv1d', module=nn.Conv1d)
MODELS.register_module('Conv2d', module=nn.Conv2d)
MODELS.register_module('Conv3d', module=nn.Conv3d)
# 'Conv' is a short alias that also resolves to nn.Conv2d.
MODELS.register_module('Conv', module=nn.Conv2d)
def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
    """Build convolution layer.

    Args:
        cfg (None or dict): The conv layer config, which should contain:

            - type (str or class): Layer type. A string is looked up in the
              ``MODELS`` registry; a class is instantiated directly.
            - layer args: Args needed to instantiate a conv layer.

            If ``None``, ``dict(type='Conv2d')`` is used.
        args (argument list): Arguments passed to the ``__init__``
            method of the corresponding conv layer.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the corresponding conv layer.

    Returns:
        nn.Module: Created conv layer.

    Raises:
        TypeError: If ``cfg`` is neither ``None`` nor a dict.
        KeyError: If ``cfg`` has no ``"type"`` key, or if ``type`` is a
            name that cannot be found in the registry.
    """
    if cfg is None:
        cfg_ = dict(type='Conv2d')
    else:
        if not isinstance(cfg, dict):
            raise TypeError('cfg must be a dict')
        if 'type' not in cfg:
            raise KeyError('the cfg dict must contain the key "type"')
        # Shallow copy so popping 'type' below does not mutate the
        # caller's config.
        cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if inspect.isclass(layer_type):
        # A class was passed directly; no registry lookup is needed.
        return layer_type(*args, **kwargs, **cfg_)  # type: ignore

    # Switch registry to the target scope. If `layer_type` cannot be found
    # in the registry, fall back to searching it in mmengine.MODELS.
    with MODELS.switch_scope_and_registry(None) as registry:
        conv_layer = registry.get(layer_type)
    if conv_layer is None:
        # Report the requested name: `conv_layer` is always None here, so
        # interpolating it would produce a useless "Cannot find None".
        raise KeyError(f'Cannot find {layer_type} in registry under scope '
                       f'name {registry.scope}')

    layer = conv_layer(*args, **kwargs, **cfg_)
    return layer