from typing import Literal
import torch
import torch.nn as nn
ConvMode = Literal["CNA", "NAC", "CNAC"]
def act(act_type: str, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1) -> nn.Module:
    """#### Get the activation layer.

    #### Args:
        - `act_type` (str): The type of activation ("relu", "leakyrelu", or "prelu").
        - `inplace` (bool, optional): Whether to perform the operation in-place. Defaults to True.
        - `neg_slope` (float, optional): The negative slope for LeakyReLU and the initial value for PReLU. Defaults to 0.2.
        - `n_prelu` (int, optional): The number of PReLU parameters. Defaults to 1.

    #### Returns:
        - `nn.Module`: The activation layer.
    """
    act_type = act_type.lower()
    if act_type == "relu":
        layer = nn.ReLU(inplace)
    elif act_type == "leakyrelu":
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == "prelu":
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError(f"activation layer [{act_type}] is not found")
    return layer
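
# Quick reference (illustrative): act("relu") -> nn.ReLU, act("leakyrelu") ->
# nn.LeakyReLU(0.2), and act("prelu") -> nn.PReLU(num_parameters=1, init=0.2).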

def get_valid_padding(kernel_size: int, dilation: int) -> int:
    """#### Get the valid padding for a convolutional layer.

    #### Args:
        - `kernel_size` (int): The size of the kernel.
        - `dilation` (int): The dilation rate.

    #### Returns:
        - `int`: The valid padding.
    """
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding
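
# Worked example (illustrative): for kernel_size=3 and dilation=2 the
# effective kernel size is 3 + (3 - 1) * (2 - 1) = 5, so the "same" padding
# for a stride-1 convolution is (5 - 1) // 2 = 2, i.e.
# get_valid_padding(3, 2) == 2.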

def sequential(*args: nn.Module) -> nn.Sequential:
    """#### Create a sequential container.

    #### Args:
        - `*args` (nn.Module): The modules to include in the sequential container.

    #### Returns:
        - `nn.Sequential`: The sequential container.
    """
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            # Flatten nested containers into a single level.
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            # Anything that is not an nn.Module (e.g. a None placeholder
            # from conv_block) is silently skipped.
            modules.append(module)
    return nn.Sequential(*modules)
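
# Usage sketch (illustrative): sequential(None, nn.Conv2d(3, 8, 3), nn.ReLU())
# and sequential(nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())) both yield a
# flat two-module nn.Sequential.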

def conv_block(
    in_nc: int,
    out_nc: int,
    kernel_size: int,
    stride: int = 1,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    pad_type: str = "zero",
    norm_type: str | None = None,
    act_type: str | None = "relu",
    mode: ConvMode = "CNA",
    c2x2: bool = False,
) -> nn.Sequential:
    """#### Create a convolutional block.

    #### Args:
        - `in_nc` (int): The number of input channels.
        - `out_nc` (int): The number of output channels.
        - `kernel_size` (int): The size of the kernel.
        - `stride` (int, optional): The stride of the convolution. Defaults to 1.
        - `dilation` (int, optional): The dilation rate. Defaults to 1.
        - `groups` (int, optional): The number of groups. Defaults to 1.
        - `bias` (bool, optional): Whether to include a bias term. Defaults to True.
        - `pad_type` (str, optional): The type of padding. Defaults to "zero".
        - `norm_type` (str | None, optional): The type of normalization. Defaults to None.
        - `act_type` (str | None, optional): The type of activation. Defaults to "relu".
        - `mode` (ConvMode, optional): The layer order: "CNA" (Conv-Norm-Act), "NAC" (Norm-Act-Conv), or "CNAC". Defaults to "CNA".
        - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Unused in this simplified variant. Defaults to False.

    #### Returns:
        - `nn.Sequential`: The convolutional block.
    """
    assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    padding = padding if pad_type == "zero" else 0
    c = nn.Conv2d(
        in_nc,
        out_nc,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        groups=groups,
    )
    if mode in ("CNA", "CNAC"):
        # Conv -> Norm -> Act. Normalization is omitted in this simplified
        # variant; the None placeholders are filtered out by sequential().
        a = act(act_type) if act_type else None
        return sequential(None, c, None, a)
    # "NAC": Norm -> Act -> Conv. Use a non-inplace activation so the block's
    # input is left unmodified.
    a = act(act_type, inplace=False) if act_type else None
    return sequential(None, a, c)
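
# Example (illustrative): a 3x3 Conv followed by LeakyReLU in "CNA" mode.
# block = conv_block(64, 64, kernel_size=3, act_type="leakyrelu")
# y = block(torch.randn(1, 64, 32, 32))  # -> shape (1, 64, 32, 32)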

def upconv_block(
    in_nc: int,
    out_nc: int,
    upscale_factor: int = 2,
    kernel_size: int = 3,
    stride: int = 1,
    bias: bool = True,
    pad_type: str = "zero",
    norm_type: str | None = None,
    act_type: str = "relu",
    mode: str = "nearest",
    c2x2: bool = False,
) -> nn.Sequential:
    """#### Create an upsampling convolutional block.

    #### Args:
        - `in_nc` (int): The number of input channels.
        - `out_nc` (int): The number of output channels.
        - `upscale_factor` (int, optional): The upscale factor. Defaults to 2.
        - `kernel_size` (int, optional): The size of the kernel. Defaults to 3.
        - `stride` (int, optional): The stride of the convolution. Defaults to 1.
        - `bias` (bool, optional): Whether to include a bias term. Defaults to True.
        - `pad_type` (str, optional): The type of padding. Defaults to "zero".
        - `norm_type` (str | None, optional): The type of normalization. Defaults to None.
        - `act_type` (str, optional): The type of activation. Defaults to "relu".
        - `mode` (str, optional): The interpolation mode for upsampling. Defaults to "nearest".
        - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False.

    #### Returns:
        - `nn.Sequential`: The upsampling convolutional block.
    """
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(
        in_nc,
        out_nc,
        kernel_size,
        stride,
        bias=bias,
        pad_type=pad_type,
        norm_type=norm_type,
        act_type=act_type,
        c2x2=c2x2,
    )
    return sequential(upsample, conv)
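
# Example (illustrative): nearest-neighbour 2x upsampling followed by a 3x3 conv.
# up = upconv_block(64, 32, upscale_factor=2, act_type="leakyrelu")
# y = up(torch.randn(1, 64, 16, 16))  # -> shape (1, 32, 32, 32)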

class ShortcutBlock(nn.Module):
    """#### Elementwise sum the output of a submodule to its input."""

    def __init__(self, submodule: nn.Module):
        """#### Initialize the ShortcutBlock.

        #### Args:
            - `submodule` (nn.Module): The submodule to apply.
        """
        super().__init__()
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        return x + self.sub(x)
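

if __name__ == "__main__":
    # Minimal smoke test (a sketch; the channel counts and input size are
    # arbitrary, not taken from any particular model): chain the helpers
    # into a tiny residual upscaler and check the output shape.
    body = ShortcutBlock(
        sequential(
            conv_block(16, 16, kernel_size=3, act_type="leakyrelu"),
            conv_block(16, 16, kernel_size=3, act_type=None),
        )
    )
    model = sequential(body, upconv_block(16, 8, upscale_factor=2))
    y = model(torch.randn(1, 16, 24, 24))
    print(y.shape)  # expected: torch.Size([1, 8, 48, 48])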