from typing import Literal

import torch
import torch.nn as nn

ConvMode = Literal["CNA", "NAC", "CNAC"]

def act(act_type: str, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1) -> nn.Module:
    """#### Get the activation layer.
    #### Args:
        - `act_type` (str): The type of activation ("relu", "leakyrelu"/"lrelu", or "prelu").
        - `inplace` (bool, optional): Whether to perform the operation in-place. Defaults to True.
        - `neg_slope` (float, optional): The negative slope for LeakyReLU (also the PReLU init). Defaults to 0.2.
        - `n_prelu` (int, optional): The number of PReLU parameters. Defaults to 1.
    #### Returns:
        - `nn.Module`: The activation layer.
    """
    act_type = act_type.lower()
    if act_type == "relu":
        layer = nn.ReLU(inplace)
    elif act_type in ("leakyrelu", "lrelu"):
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == "prelu":
        # PReLU learns its slope; it has no in-place variant.
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError(f"activation layer [{act_type}] is not found")
    return layer
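
# Usage sketch (illustrative, not part of the original file):
#   >>> act("leakyrelu", neg_slope=0.2)
#   LeakyReLU(negative_slope=0.2, inplace=True)
#   >>> act("prelu", n_prelu=8)
#   PReLU(num_parameters=8)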

def get_valid_padding(kernel_size: int, dilation: int) -> int:
    """#### Get the valid padding for a convolutional layer.
    #### Args:
        - `kernel_size` (int): The size of the kernel.
        - `dilation` (int): The dilation rate.
    #### Returns:
        - `int`: The valid padding.
    """
    # Effective kernel size once dilation spreads the taps apart.
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    # "Same" padding for a stride-1 convolution with an odd kernel.
    padding = (kernel_size - 1) // 2
    return padding
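
# Worked example (illustrative): a 3x3 kernel with dilation 2 covers an
# effective 5x5 window, so padding 2 preserves spatial size at stride 1.
#   >>> get_valid_padding(3, 1)
#   1
#   >>> get_valid_padding(3, 2)
#   2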

def sequential(*args: nn.Module) -> nn.Sequential:
    """#### Create a sequential container.
    #### Args:
        - `*args` (nn.Module): The modules to include in the sequential container.
    #### Returns:
        - `nn.Sequential`: The sequential container.
    """
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            # Flatten nested nn.Sequential containers into a single level.
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
        # Anything else (e.g. a None placeholder) is silently skipped.
    return nn.Sequential(*modules)
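
# Usage sketch (illustrative): None placeholders are dropped and nested
# containers are flattened into one level.
#   >>> sequential(None, nn.Conv2d(3, 8, 3), nn.Sequential(nn.ReLU()))
#   Sequential(
#     (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1))
#     (1): ReLU()
#   )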

def conv_block(
    in_nc: int,
    out_nc: int,
    kernel_size: int,
    stride: int = 1,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    pad_type: str = "zero",
    norm_type: str | None = None,
    act_type: str | None = "relu",
    mode: ConvMode = "CNA",
    c2x2: bool = False,
) -> nn.Sequential:
    """#### Create a convolutional block.
    #### Args:
        - `in_nc` (int): The number of input channels.
        - `out_nc` (int): The number of output channels.
        - `kernel_size` (int): The size of the kernel.
        - `stride` (int, optional): The stride of the convolution. Defaults to 1.
        - `dilation` (int, optional): The dilation rate. Defaults to 1.
        - `groups` (int, optional): The number of groups. Defaults to 1.
        - `bias` (bool, optional): Whether to include a bias term. Defaults to True.
        - `pad_type` (str, optional): The type of padding; only "zero" padding is implemented here. Defaults to "zero".
        - `norm_type` (str | None, optional): The type of normalization. Accepted for API compatibility but unused in this trimmed version. Defaults to None.
        - `act_type` (str | None, optional): The type of activation. Defaults to "relu".
        - `mode` (ConvMode, optional): The layer order: "CNA" (Conv-Norm-Act), "NAC" (Norm-Act-Conv), or "CNAC". Defaults to "CNA".
        - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Accepted for API compatibility but unused in this trimmed version. Defaults to False.
    #### Returns:
        - `nn.Sequential`: The convolutional block.
    """
    assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    padding = padding if pad_type == "zero" else 0
    c = nn.Conv2d(
        in_nc,
        out_nc,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        groups=groups,
    )
    if mode in ("CNA", "CNAC"):
        # Conv -> Act (normalization layers are not implemented here).
        a = act(act_type) if act_type else None
        return sequential(c, a)
    # "NAC": Act -> Conv. The activation must not run in-place, since its
    # output feeds the convolution that follows.
    a = act(act_type, inplace=False) if act_type else None
    return sequential(a, c)
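
# Usage sketch (illustrative): a 3x3 Conv-Act block that preserves spatial size.
#   >>> block = conv_block(64, 64, kernel_size=3, act_type="leakyrelu")
#   >>> block(torch.zeros(1, 64, 16, 16)).shape
#   torch.Size([1, 64, 16, 16])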

def upconv_block(
    in_nc: int,
    out_nc: int,
    upscale_factor: int = 2,
    kernel_size: int = 3,
    stride: int = 1,
    bias: bool = True,
    pad_type: str = "zero",
    norm_type: str | None = None,
    act_type: str = "relu",
    mode: str = "nearest",
    c2x2: bool = False,
) -> nn.Sequential:
    """#### Create an upsampling convolutional block.
    #### Args:
        - `in_nc` (int): The number of input channels.
        - `out_nc` (int): The number of output channels.
        - `upscale_factor` (int, optional): The upscale factor. Defaults to 2.
        - `kernel_size` (int, optional): The size of the kernel. Defaults to 3.
        - `stride` (int, optional): The stride of the convolution. Defaults to 1.
        - `bias` (bool, optional): Whether to include a bias term. Defaults to True.
        - `pad_type` (str, optional): The type of padding. Defaults to "zero".
        - `norm_type` (str | None, optional): The type of normalization. Defaults to None.
        - `act_type` (str, optional): The type of activation. Defaults to "relu".
        - `mode` (str, optional): The interpolation mode for `nn.Upsample` (not a ConvMode). Defaults to "nearest".
        - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False.
    #### Returns:
        - `nn.Sequential`: The upsampling convolutional block.
    """
    # Interpolate first, then convolve: a common alternative to transposed
    # convolutions that avoids checkerboard artifacts.
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(
        in_nc,
        out_nc,
        kernel_size,
        stride,
        bias=bias,
        pad_type=pad_type,
        norm_type=norm_type,
        act_type=act_type,
        c2x2=c2x2,
    )
    return sequential(upsample, conv)
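
# Usage sketch (illustrative): doubling spatial resolution with
# nearest-neighbor upsampling followed by a 3x3 convolution.
#   >>> up = upconv_block(64, 32, upscale_factor=2)
#   >>> up(torch.zeros(1, 64, 16, 16)).shape
#   torch.Size([1, 32, 32, 32])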

class ShortcutBlock(nn.Module):
    """#### Elementwise sum the output of a submodule to its input."""

    def __init__(self, submodule: nn.Module):
        """#### Initialize the ShortcutBlock.
        #### Args:
            - `submodule` (nn.Module): The submodule to apply.
        """
        super().__init__()
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass.
        #### Args:
            - `x` (torch.Tensor): The input tensor.
        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        # Residual connection: x + f(x).
        output = x + self.sub(x)
        return output
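
# Usage sketch (illustrative): wrapping two conv blocks in a residual
# shortcut; input and output channel counts must match for the sum.
#   >>> body = sequential(conv_block(64, 64, 3), conv_block(64, 64, 3, act_type=None))
#   >>> res = ShortcutBlock(body)
#   >>> res(torch.zeros(1, 64, 8, 8)).shape
#   torch.Size([1, 64, 8, 8])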