from typing import Literal import torch import torch.nn as nn ConvMode = Literal["CNA", "NAC", "CNAC"] def act(act_type: str, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1) -> nn.Module: """#### Get the activation layer. #### Args: - `act_type` (str): The type of activation. - `inplace` (bool, optional): Whether to perform the operation in-place. Defaults to True. - `neg_slope` (float, optional): The negative slope for LeakyReLU. Defaults to 0.2. - `n_prelu` (int, optional): The number of PReLU parameters. Defaults to 1. #### Returns: - `nn.Module`: The activation layer. """ act_type = act_type.lower() layer = nn.LeakyReLU(neg_slope, inplace) return layer def get_valid_padding(kernel_size: int, dilation: int) -> int: """#### Get the valid padding for a convolutional layer. #### Args: - `kernel_size` (int): The size of the kernel. - `dilation` (int): The dilation rate. #### Returns: - `int`: The valid padding. """ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) padding = (kernel_size - 1) // 2 return padding def sequential(*args: nn.Module) -> nn.Sequential: """#### Create a sequential container. #### Args: - `*args` (nn.Module): The modules to include in the sequential container. #### Returns: - `nn.Sequential`: The sequential container. """ modules = [] for module in args: if isinstance(module, nn.Sequential): for submodule in module.children(): modules.append(submodule) elif isinstance(module, nn.Module): modules.append(module) return nn.Sequential(*modules) def conv_block( in_nc: int, out_nc: int, kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1, bias: bool = True, pad_type: str = "zero", norm_type: str | None = None, act_type: str | None = "relu", mode: ConvMode = "CNA", c2x2: bool = False, ) -> nn.Sequential: """#### Create a convolutional block. #### Args: - `in_nc` (int): The number of input channels. - `out_nc` (int): The number of output channels. - `kernel_size` (int): The size of the kernel. - `stride` (int, optional): The stride of the convolution. Defaults to 1. - `dilation` (int, optional): The dilation rate. Defaults to 1. - `groups` (int, optional): The number of groups. Defaults to 1. - `bias` (bool, optional): Whether to include a bias term. Defaults to True. - `pad_type` (str, optional): The type of padding. Defaults to "zero". - `norm_type` (str | None, optional): The type of normalization. Defaults to None. - `act_type` (str | None, optional): The type of activation. Defaults to "relu". - `mode` (ConvMode, optional): The mode of the convolution. Defaults to "CNA". - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False. #### Returns: - `nn.Sequential`: The convolutional block. """ assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) padding = get_valid_padding(kernel_size, dilation) padding = padding if pad_type == "zero" else 0 c = nn.Conv2d( in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=groups, ) a = act(act_type) if act_type else None if mode in ("CNA", "CNAC"): return sequential(None, c, None, a) def upconv_block( in_nc: int, out_nc: int, upscale_factor: int = 2, kernel_size: int = 3, stride: int = 1, bias: bool = True, pad_type: str = "zero", norm_type: str | None = None, act_type: str = "relu", mode: str = "nearest", c2x2: bool = False, ) -> nn.Sequential: """#### Create an upsampling convolutional block. #### Args: - `in_nc` (int): The number of input channels. - `out_nc` (int): The number of output channels. - `upscale_factor` (int, optional): The upscale factor. Defaults to 2. - `kernel_size` (int, optional): The size of the kernel. Defaults to 3. - `stride` (int, optional): The stride of the convolution. Defaults to 1. - `bias` (bool, optional): Whether to include a bias term. Defaults to True. - `pad_type` (str, optional): The type of padding. Defaults to "zero". - `norm_type` (str | None, optional): The type of normalization. Defaults to None. - `act_type` (str, optional): The type of activation. Defaults to "relu". - `mode` (str, optional): The mode of upsampling. Defaults to "nearest". - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False. #### Returns: - `nn.Sequential`: The upsampling convolutional block. """ upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode) conv = conv_block( in_nc, out_nc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, c2x2=c2x2, ) return sequential(upsample, conv) class ShortcutBlock(nn.Module): """#### Elementwise sum the output of a submodule to its input.""" def __init__(self, submodule: nn.Module): """#### Initialize the ShortcutBlock. #### Args: - `submodule` (nn.Module): The submodule to apply. """ super(ShortcutBlock, self).__init__() self.sub = submodule def forward(self, x: torch.Tensor) -> torch.Tensor: """#### Forward pass. #### Args: - `x` (torch.Tensor): The input tensor. #### Returns: - `torch.Tensor`: The output tensor. """ output = x + self.sub(x) return output