# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Dict, Optional

from mmengine.registry import MODELS
from torch import nn

# Expose the stock PyTorch convolution layers through the MODELS registry so
# configs can refer to them by name. 'Conv' is an alias for nn.Conv2d.
MODELS.register_module('Conv1d', module=nn.Conv1d)
MODELS.register_module('Conv2d', module=nn.Conv2d)
MODELS.register_module('Conv3d', module=nn.Conv3d)
MODELS.register_module('Conv', module=nn.Conv2d)


def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
    """Build convolution layer.

    Args:
        cfg (None or dict): The conv layer config, which should contain:

            - type (str | type): Layer type, either a name looked up in the
              registry or a class that is instantiated directly.
            - layer args: Args needed to instantiate a conv layer.

            If ``None``, ``dict(type='Conv2d')`` is used.
        args (argument list): Arguments passed to the `__init__`
            method of the corresponding conv layer.
        kwargs (keyword arguments): Keyword arguments passed to the `__init__`
            method of the corresponding conv layer.

    Returns:
        nn.Module: Created conv layer.

    Raises:
        TypeError: If ``cfg`` is neither ``None`` nor a dict.
        KeyError: If ``cfg`` lacks the ``"type"`` key, or if the given type
            name cannot be found in the registry.
    """
    if cfg is None:
        cfg_ = dict(type='Conv2d')
    else:
        if not isinstance(cfg, dict):
            raise TypeError('cfg must be a dict')
        if 'type' not in cfg:
            raise KeyError('the cfg dict must contain the key "type"')
        # Copy so popping 'type' below leaves the caller's dict intact.
        cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')

    # A class given directly as `type` skips the registry lookup entirely.
    if inspect.isclass(layer_type):
        return layer_type(*args, **kwargs, **cfg_)  # type: ignore

    # Switch registry to the target scope. If `layer_type` cannot be found
    # in the registry, fallback to search it in the mmengine.MODELS.
    with MODELS.switch_scope_and_registry(None) as registry:
        conv_layer = registry.get(layer_type)
    if conv_layer is None:
        # Report the requested name; `conv_layer` itself is None here.
        raise KeyError(f'Cannot find {layer_type} in registry under scope '
                       f'name {registry.scope}')

    layer = conv_layer(*args, **kwargs, **cfg_)
    return layer