# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import xavier_init
from mmengine.registry import MODELS

# Both names resolve to nn.Upsample; build_upsample_layer uses the registered
# name as the interpolation ``mode``.
MODELS.register_module('nearest', module=nn.Upsample)
MODELS.register_module('bilinear', module=nn.Upsample)


@MODELS.register_module(name='pixel_shuffle')
class PixelShufflePack(nn.Module):
    """Pixel Shuffle upsample layer.

    This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
    achieve a simple upsampling with pixel shuffle.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Upsample ratio.
        upsample_kernel (int): Kernel size of the conv layer to expand the
            channels.
    """

    def __init__(self, in_channels: int, out_channels: int, scale_factor: int,
                 upsample_kernel: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel
        # The conv produces out_channels * scale_factor**2 channels so that
        # pixel_shuffle can fold them back into out_channels at
        # scale_factor x the spatial resolution. Padding of (k - 1) // 2
        # keeps the spatial size unchanged for odd kernel sizes.
        self.upsample_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels * scale_factor * scale_factor,
            self.upsample_kernel,
            padding=(self.upsample_kernel - 1) // 2)
        self.init_weights()

    def init_weights(self):
        """Initialize the expansion conv with uniform Xavier init."""
        xavier_init(self.upsample_conv, distribution='uniform')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand channels with the conv, then rearrange via pixel shuffle."""
        x = self.upsample_conv(x)
        x = F.pixel_shuffle(x, self.scale_factor)
        return x


def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
    """Build upsample layer.

    Args:
        cfg (dict): The upsample layer config, which should contain:

            - type (str | type): Layer type, either a registered name
              (e.g. ``'nearest'``, ``'bilinear'``, ``'pixel_shuffle'``) or
              a module class used directly.
            - scale_factor (int): Upsample ratio, which is not applicable to
              deconv.
            - layer args: Args needed to instantiate an upsample layer.
        args (argument list): Arguments passed to the ``__init__`` method of
            the corresponding upsample layer.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the corresponding upsample layer.

    Returns:
        nn.Module: Created upsample layer.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``cfg`` has no ``"type"`` key, or if a string ``type``
            cannot be found in the registry.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')
    # Work on a copy so popping 'type' / setting 'mode' never mutates the
    # caller's config.
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if inspect.isclass(layer_type):
        upsample = layer_type
    else:
        # Switch registry to the target scope. If `upsample` cannot be found
        # in the registry, fallback to search `upsample` in the
        # mmengine.MODELS.
        with MODELS.switch_scope_and_registry(None) as registry:
            upsample = registry.get(layer_type)
        if upsample is None:
            # Report the requested name (``upsample`` is None here).
            raise KeyError(f'Cannot find {layer_type} in registry under '
                           f'scope name {registry.scope}')
    if upsample is nn.Upsample and isinstance(layer_type, str):
        # 'nearest' and 'bilinear' are both registered as nn.Upsample, so the
        # registered name doubles as the interpolation mode. When the class
        # itself is passed as ``type``, keep whatever ``mode`` cfg specifies
        # instead of passing the class object as the mode.
        cfg_['mode'] = layer_type
    layer = upsample(*args, **kwargs, **cfg_)
    return layer