from collections import OrderedDict
import math
import re
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from modules.UltimateSDUpscale import USDU_util


class RRDB(nn.Module):
"""#### Residual in Residual Dense Block (RRDB) class.
#### Args:
- `nf` (int): Number of filters.
- `kernel_size` (int, optional): Kernel size. Defaults to 3.
- `gc` (int, optional): Growth channel. Defaults to 32.
- `stride` (int, optional): Stride. Defaults to 1.
- `bias` (bool, optional): Whether to use bias. Defaults to True.
- `pad_type` (str, optional): Padding type. Defaults to "zero".
- `norm_type` (str, optional): Normalization type. Defaults to None.
- `act_type` (str, optional): Activation type. Defaults to "leakyrelu".
- `mode` (USDU_util.ConvMode, optional): Convolution mode. Defaults to "CNA".
- `_convtype` (str, optional): Convolution type. Defaults to "Conv2D".
- `_spectral_norm` (bool, optional): Whether to use spectral normalization. Defaults to False.
- `plus` (bool, optional): Whether to use the plus variant. Defaults to False.
- `c2x2` (bool, optional): Whether to use 2x2 convolution. Defaults to False.
"""
def __init__(
self,
nf: int,
kernel_size: int = 3,
gc: int = 32,
stride: int = 1,
bias: bool = True,
pad_type: str = "zero",
        norm_type: Optional[str] = None,
act_type: str = "leakyrelu",
mode: USDU_util.ConvMode = "CNA",
_convtype: str = "Conv2D",
_spectral_norm: bool = False,
plus: bool = False,
c2x2: bool = False,
) -> None:
        super().__init__()
self.RDB1 = ResidualDenseBlock_5C(
nf,
kernel_size,
gc,
stride,
bias,
pad_type,
norm_type,
act_type,
mode,
plus=plus,
c2x2=c2x2,
)
self.RDB2 = ResidualDenseBlock_5C(
nf,
kernel_size,
gc,
stride,
bias,
pad_type,
norm_type,
act_type,
mode,
plus=plus,
c2x2=c2x2,
)
self.RDB3 = ResidualDenseBlock_5C(
nf,
kernel_size,
gc,
stride,
bias,
pad_type,
norm_type,
act_type,
mode,
plus=plus,
c2x2=c2x2,
)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""#### Forward pass of the RRDB.
#### Args:
- `x` (torch.Tensor): Input tensor.
#### Returns:
- `torch.Tensor`: Output tensor.
"""
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
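        # Residual scaling: multiply the block output by 0.2 before the skip
        # connection (as in the ESRGAN paper) to stabilize the deep trunk.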
return out * 0.2 + x


class ResidualDenseBlock_5C(nn.Module):
"""#### Residual Dense Block with 5 Convolutions (ResidualDenseBlock_5C) class.
#### Args:
- `nf` (int, optional): Number of filters. Defaults to 64.
- `kernel_size` (int, optional): Kernel size. Defaults to 3.
- `gc` (int, optional): Growth channel. Defaults to 32.
- `stride` (int, optional): Stride. Defaults to 1.
- `bias` (bool, optional): Whether to use bias. Defaults to True.
- `pad_type` (str, optional): Padding type. Defaults to "zero".
- `norm_type` (str, optional): Normalization type. Defaults to None.
- `act_type` (str, optional): Activation type. Defaults to "leakyrelu".
- `mode` (USDU_util.ConvMode, optional): Convolution mode. Defaults to "CNA".
- `plus` (bool, optional): Whether to use the plus variant. Defaults to False.
- `c2x2` (bool, optional): Whether to use 2x2 convolution. Defaults to False.
"""
def __init__(
self,
nf: int = 64,
kernel_size: int = 3,
gc: int = 32,
stride: int = 1,
bias: bool = True,
pad_type: str = "zero",
        norm_type: Optional[str] = None,
act_type: str = "leakyrelu",
mode: USDU_util.ConvMode = "CNA",
plus: bool = False,
c2x2: bool = False,
) -> None:
        super().__init__()
        # ESRGAN+ ("plus") variant: a 1x1 conv provides partial residual
        # connections inside the dense block; plain ESRGAN leaves this None.
        self.conv1x1 = nn.Conv2d(nf, gc, kernel_size=1, bias=False) if plus else None
self.conv1 = USDU_util.conv_block(
nf,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv2 = USDU_util.conv_block(
nf + gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv3 = USDU_util.conv_block(
nf + 2 * gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv4 = USDU_util.conv_block(
nf + 3 * gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
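        # The final 3x3 fusion conv is linear: no activation before the
        # residual addition in forward().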
last_act = None
self.conv5 = USDU_util.conv_block(
nf + 4 * gc,
nf,
3,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=last_act,
mode=mode,
c2x2=c2x2,
)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""#### Forward pass of the ResidualDenseBlock_5C.
#### Args:
- `x` (torch.Tensor): Input tensor.
#### Returns:
- `torch.Tensor`: Output tensor.
"""
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1 is not None:
            # ESRGAN+ partial residual: project the block input onto x2.
            x2 = x2 + self.conv1x1(x)
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1 is not None:
            x4 = x4 + x2
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDBNet(nn.Module):
"""#### Residual in Residual Dense Block Network (RRDBNet) class.
#### Args:
- `state_dict` (dict): State dictionary.
- `norm` (str, optional): Normalization type. Defaults to None.
- `act` (str, optional): Activation type. Defaults to "leakyrelu".
- `upsampler` (str, optional): Upsampler type. Defaults to "upconv".
- `mode` (USDU_util.ConvMode, optional): Convolution mode. Defaults to "CNA".
"""
def __init__(
self,
state_dict: Dict[str, torch.Tensor],
        norm: Optional[str] = None,
act: str = "leakyrelu",
upsampler: str = "upconv",
mode: USDU_util.ConvMode = "CNA",
) -> None:
        super().__init__()
self.model_arch = "ESRGAN"
self.sub_type = "SR"
        self.state = state_dict
        if "params_ema" in self.state:
            # Real-ESRGAN training checkpoints nest the weights under "params_ema".
            self.state = self.state["params_ema"]
self.norm = norm
self.act = act
self.upsampler = upsampler
self.mode = mode
self.state_map = {
# currently supports old, new, and newer RRDBNet arch _internal
# ESRGAN, BSRGAN/RealSR, Real-ESRGAN
"model.0.weight": ("conv_first.weight",),
"model.0.bias": ("conv_first.bias",),
"model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
"model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
),
}
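        # "/NB/" is a placeholder for the trunk-conv index, filled in once the
        # block count is known (see new_to_old_arch). The regex pair maps
        # new-arch RRDB keys back onto the old sequential naming.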
self.num_blocks = self.get_num_blocks()
self.plus = any("conv1x1" in k for k in self.state.keys())
self.state = self.new_to_old_arch(self.state)
self.key_arr = list(self.state.keys())
self.in_nc: int = self.state[self.key_arr[0]].shape[1]
self.out_nc: int = self.state[self.key_arr[-1]].shape[0]
self.scale: int = self.get_scale()
self.num_filters: int = self.state[self.key_arr[0]].shape[0]
c2x2 = False
self.supports_fp16 = True
self.supports_bfp16 = True
self.min_size_restriction = None
self.shuffle_factor = None
        upsample_block = {
            "upconv": USDU_util.upconv_block,
        }.get(self.upsampler)
        if upsample_block is None:
            raise NotImplementedError(f"Upsample mode not supported: {self.upsampler}")
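        # A scale factor of 2**n is realized by stacking n 2x upsampling
        # blocks, hence the log2(scale) repetitions below.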
upsample_blocks = [
upsample_block(
in_nc=self.num_filters,
out_nc=self.num_filters,
act_type=self.act,
c2x2=c2x2,
)
for _ in range(int(math.log(self.scale, 2)))
]
self.model = USDU_util.sequential(
# fea conv
USDU_util.conv_block(
in_nc=self.in_nc,
out_nc=self.num_filters,
kernel_size=3,
norm_type=None,
act_type=None,
c2x2=c2x2,
),
USDU_util.ShortcutBlock(
USDU_util.sequential(
# rrdb blocks
*[
RRDB(
nf=self.num_filters,
kernel_size=3,
gc=32,
stride=1,
bias=True,
pad_type="zero",
norm_type=self.norm,
act_type=self.act,
mode="CNA",
plus=self.plus,
c2x2=c2x2,
)
for _ in range(self.num_blocks)
],
# lr conv
USDU_util.conv_block(
in_nc=self.num_filters,
out_nc=self.num_filters,
kernel_size=3,
norm_type=self.norm,
act_type=None,
mode=self.mode,
c2x2=c2x2,
),
)
),
*upsample_blocks,
# hr_conv0
USDU_util.conv_block(
in_nc=self.num_filters,
out_nc=self.num_filters,
kernel_size=3,
norm_type=None,
act_type=self.act,
c2x2=c2x2,
),
# hr_conv1
USDU_util.conv_block(
in_nc=self.num_filters,
out_nc=self.out_nc,
kernel_size=3,
norm_type=None,
act_type=None,
c2x2=c2x2,
),
)
self.load_state_dict(self.state, strict=False)

    def new_to_old_arch(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""#### Convert new architecture state dictionary to old architecture.
#### Args:
- `state` (dict): State dictionary.
#### Returns:
- `dict`: Converted state dictionary.
"""
        if "conv_first.weight" not in state:
            # Keys are already in the old sequential (ESRGAN) layout.
            return state
        # Fill the now-known block count into the "/NB/" placeholder keys.
for kind in ("weight", "bias"):
self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
f"model.1.sub./NB/.{kind}"
]
del self.state_map[f"model.1.sub./NB/.{kind}"]
old_state = OrderedDict()
for old_key, new_keys in self.state_map.items():
for new_key in new_keys:
if r"\1" in old_key:
for k, v in state.items():
sub = re.sub(new_key, old_key, k)
if sub != k:
old_state[sub] = v
else:
if new_key in state:
old_state[old_key] = state[new_key]
# upconv layers
max_upconv = 0
for key in state.keys():
match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
if match is not None:
_, key_num, key_type = match.groups()
old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
max_upconv = max(max_upconv, int(key_num) * 3)
# final layers
for key in state.keys():
if key in ("HRconv.weight", "conv_hr.weight"):
old_state[f"model.{max_upconv + 2}.weight"] = state[key]
elif key in ("HRconv.bias", "conv_hr.bias"):
old_state[f"model.{max_upconv + 2}.bias"] = state[key]
elif key in ("conv_last.weight",):
old_state[f"model.{max_upconv + 4}.weight"] = state[key]
elif key in ("conv_last.bias",):
old_state[f"model.{max_upconv + 4}.bias"] = state[key]
        # Sort by the top-level layer index ("model.<n>. ...") so the rebuilt
        # state dict follows the sequential module order.
        sorted_keys = sorted(old_state.keys(), key=lambda k: int(k.split(".")[1]))
# Rebuild the output dict in the right order
out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)
return out_dict

    def get_scale(self, min_part: int = 6) -> int:
"""#### Get the scale factor.
#### Args:
- `min_part` (int, optional): Minimum part. Defaults to 6.
#### Returns:
- `int`: Scale factor.
"""
n = 0
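        # Each 2x upsampling stage pushes one more top-level weight tensor
        # past index `min_part` in the old-arch layout, so counting them
        # recovers log2(scale).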
for part in list(self.state):
parts = part.split(".")[1:]
if len(parts) == 2:
part_num = int(parts[0])
if part_num > min_part and parts[1] == "weight":
n += 1
return 2**n

    def get_num_blocks(self) -> int:
"""#### Get the number of blocks.
#### Returns:
- `int`: Number of blocks.
"""
nbs = []
state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
)
for state_key in state_keys:
for k in self.state:
m = re.search(state_key, k)
if m:
nbs.append(int(m.group(1)))
if nbs:
break
        # The largest block index seen in the keys, plus one, is the count.
        return max(nbs) + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""#### Forward pass of the RRDBNet.
#### Args:
- `x` (torch.Tensor): Input tensor.
#### Returns:
- `torch.Tensor`: Output tensor.
"""
return self.model(x)


PyTorchSRModels = (RRDBNet,)
PyTorchSRModel = Union[RRDBNet,]
PyTorchModels = (*PyTorchSRModels,)
PyTorchModel = Union[PyTorchSRModel]
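

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the library. The checkpoint
    # filename below is a placeholder; point it at any ESRGAN / BSRGAN /
    # Real-ESRGAN .pth state dict.
    checkpoint = torch.load("esrgan_model.pth", map_location="cpu")
    model = RRDBNet(checkpoint).eval()
    with torch.no_grad():
        lr = torch.rand(1, model.in_nc, 64, 64)  # NCHW input in [0, 1]
        sr = model(lr)
    print(f"scale={model.scale} blocks={model.num_blocks} out={tuple(sr.shape)}")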