from typing import Literal
import torch
import torch.nn as nn
ConvMode = Literal["CNA", "NAC", "CNAC"]
def act(act_type: str, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1) -> nn.Module:
"""#### Get the activation layer.
#### Args:
- `act_type` (str): The type of activation.
- `inplace` (bool, optional): Whether to perform the operation in-place. Defaults to True.
- `neg_slope` (float, optional): The negative slope for LeakyReLU. Defaults to 0.2.
- `n_prelu` (int, optional): The number of PReLU parameters. Defaults to 1.
#### Returns:
- `nn.Module`: The activation layer.
"""
    # Select the activation by its (case-insensitive) name, as documented above.
    act_type = act_type.lower()
    if act_type == "relu":
        layer = nn.ReLU(inplace)
    elif act_type in ("leakyrelu", "lrelu"):
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == "prelu":
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError(
            "activation layer [{:s}] is not found".format(act_type)
        )
    return layer
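
# Usage sketch (assumed, not part of the original module): the factory is keyed
# by the lowercased name, e.g.
#
#     act("leakyrelu")            # -> nn.LeakyReLU(0.2, inplace=True)
#     act("prelu", n_prelu=8)     # -> nn.PReLU with 8 learnable slopes
#
# Unknown names raise NotImplementedError.
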
def get_valid_padding(kernel_size: int, dilation: int) -> int:
"""#### Get the valid padding for a convolutional layer.
#### Args:
- `kernel_size` (int): The size of the kernel.
- `dilation` (int): The dilation rate.
#### Returns:
- `int`: The valid padding.
"""
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
padding = (kernel_size - 1) // 2
return padding
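
# Worked example (assumed values): a 3x3 kernel with dilation 2 has an effective
# extent of 3 + (3 - 1) * (2 - 1) = 5, so "same" padding is (5 - 1) // 2 = 2:
#
#     get_valid_padding(kernel_size=3, dilation=2)  # -> 2
#     get_valid_padding(kernel_size=3, dilation=1)  # -> 1
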
def sequential(*args: nn.Module) -> nn.Sequential:
"""#### Create a sequential container.
#### Args:
- `*args` (nn.Module): The modules to include in the sequential container.
#### Returns:
- `nn.Sequential`: The sequential container.
"""
modules = []
for module in args:
if isinstance(module, nn.Sequential):
for submodule in module.children():
modules.append(submodule)
elif isinstance(module, nn.Module):
modules.append(module)
return nn.Sequential(*modules)
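
# Illustration (assumed, not in the original module): nested nn.Sequential
# containers are flattened one level and None entries are skipped, so
#
#     sequential(None, nn.Conv2d(3, 8, 3), nn.Sequential(nn.ReLU(), nn.Conv2d(8, 8, 3)))
#
# yields a flat nn.Sequential(Conv2d, ReLU, Conv2d).
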
def conv_block(
in_nc: int,
out_nc: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
pad_type: str = "zero",
norm_type: str | None = None,
act_type: str | None = "relu",
mode: ConvMode = "CNA",
c2x2: bool = False,
) -> nn.Sequential:
"""#### Create a convolutional block.
#### Args:
- `in_nc` (int): The number of input channels.
- `out_nc` (int): The number of output channels.
- `kernel_size` (int): The size of the kernel.
- `stride` (int, optional): The stride of the convolution. Defaults to 1.
- `dilation` (int, optional): The dilation rate. Defaults to 1.
- `groups` (int, optional): The number of groups. Defaults to 1.
- `bias` (bool, optional): Whether to include a bias term. Defaults to True.
- `pad_type` (str, optional): The type of padding. Defaults to "zero".
- `norm_type` (str | None, optional): The type of normalization. Defaults to None.
- `act_type` (str | None, optional): The type of activation. Defaults to "relu".
- `mode` (ConvMode, optional): The mode of the convolution. Defaults to "CNA".
- `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False.
#### Returns:
- `nn.Sequential`: The convolutional block.
"""
assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
padding = get_valid_padding(kernel_size, dilation)
padding = padding if pad_type == "zero" else 0
c = nn.Conv2d(
in_nc,
out_nc,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
groups=groups,
)
    if mode in ("CNA", "CNAC"):
        # Conv -> (Norm) -> Act ordering; normalization is not implemented in
        # this trimmed-down version, so None placeholders are passed.
        a = act(act_type) if act_type else None
        return sequential(None, c, None, a)
    # "NAC": the activation runs before the convolution. inplace is disabled so
    # the incoming tensor is left intact for any surrounding residual connection.
    a = act(act_type, inplace=False) if act_type else None
    return sequential(a, c)
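
# Usage sketch (assumed shapes, not part of the original module): a "CNA" block
# that keeps spatial resolution thanks to get_valid_padding.
#
#     block = conv_block(64, 64, kernel_size=3, act_type="leakyrelu")
#     block(torch.randn(1, 64, 32, 32)).shape   # -> torch.Size([1, 64, 32, 32])
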
def upconv_block(
in_nc: int,
out_nc: int,
upscale_factor: int = 2,
kernel_size: int = 3,
stride: int = 1,
bias: bool = True,
pad_type: str = "zero",
norm_type: str | None = None,
act_type: str = "relu",
mode: str = "nearest",
c2x2: bool = False,
) -> nn.Sequential:
"""#### Create an upsampling convolutional block.
#### Args:
- `in_nc` (int): The number of input channels.
- `out_nc` (int): The number of output channels.
- `upscale_factor` (int, optional): The upscale factor. Defaults to 2.
- `kernel_size` (int, optional): The size of the kernel. Defaults to 3.
- `stride` (int, optional): The stride of the convolution. Defaults to 1.
- `bias` (bool, optional): Whether to include a bias term. Defaults to True.
- `pad_type` (str, optional): The type of padding. Defaults to "zero".
- `norm_type` (str | None, optional): The type of normalization. Defaults to None.
- `act_type` (str, optional): The type of activation. Defaults to "relu".
- `mode` (str, optional): The mode of upsampling. Defaults to "nearest".
- `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False.
#### Returns:
- `nn.Sequential`: The upsampling convolutional block.
"""
upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
conv = conv_block(
in_nc,
out_nc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
c2x2=c2x2,
)
return sequential(upsample, conv)
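
# Usage sketch (assumed shapes): nearest-neighbour upsampling followed by a 3x3
# convolution, doubling spatial resolution while changing the channel count.
#
#     up = upconv_block(64, 32, upscale_factor=2, act_type="leakyrelu")
#     up(torch.randn(1, 64, 16, 16)).shape   # -> torch.Size([1, 32, 32, 32])
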
class ShortcutBlock(nn.Module):
"""#### Elementwise sum the output of a submodule to its input."""
def __init__(self, submodule: nn.Module):
"""#### Initialize the ShortcutBlock.
#### Args:
- `submodule` (nn.Module): The submodule to apply.
"""
super(ShortcutBlock, self).__init__()
self.sub = submodule
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""#### Forward pass.
#### Args:
- `x` (torch.Tensor): The input tensor.
#### Returns:
- `torch.Tensor`: The output tensor.
"""
output = x + self.sub(x)
return output
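
if __name__ == "__main__":
    # Minimal smoke test (assumed configuration, not part of the original
    # module): wrap a small conv block in a residual shortcut and check that
    # the output shape matches the input shape.
    body = conv_block(8, 8, kernel_size=3, act_type="leakyrelu")
    shortcut = ShortcutBlock(body)
    x = torch.randn(1, 8, 16, 16)
    y = shortcut(x)
    print(y.shape)  # torch.Size([1, 8, 16, 16])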