from collections import OrderedDict
import functools
import math
import re
from typing import Dict, Union

import torch
import torch.nn as nn

from modules.UltimateSDUpscale import USDU_util


class RRDB(nn.Module):
    """#### Residual in Residual Dense Block (RRDB) class.

    #### Args:
        - `nf` (int): Number of filters.
        - `kernel_size` (int, optional): Kernel size. Defaults to 3.
        - `gc` (int, optional): Growth channel. Defaults to 32.
        - `stride` (int, optional): Stride. Defaults to 1.
        - `bias` (bool, optional): Whether to use bias. Defaults to True.
        - `pad_type` (str, optional): Padding type. Defaults to "zero".
        - `norm_type` (str, optional): Normalization type. Defaults to None.
        - `act_type` (str, optional): Activation type. Defaults to "leakyrelu".
        - `mode` (USDU_util.ConvMode, optional): Convolution mode. Defaults to "CNA".
        - `_convtype` (str, optional): Convolution type. Defaults to "Conv2D".
        - `_spectral_norm` (bool, optional): Whether to use spectral normalization. Defaults to False.
        - `plus` (bool, optional): Whether to use the plus variant. Defaults to False.
        - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False.
    """

    def __init__(
        self,
        nf: int,
        kernel_size: int = 3,
        gc: int = 32,
        stride: int = 1,
        bias: bool = True,
        pad_type: str = "zero",
        norm_type: str = None,
        act_type: str = "leakyrelu",
        mode: USDU_util.ConvMode = "CNA",
        _convtype: str = "Conv2D",
        _spectral_norm: bool = False,
        plus: bool = False,
        c2x2: bool = False,
    ) -> None:
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(
            nf, kernel_size, gc, stride, bias, pad_type,
            norm_type, act_type, mode, plus=plus, c2x2=c2x2,
        )
        self.RDB2 = ResidualDenseBlock_5C(
            nf, kernel_size, gc, stride, bias, pad_type,
            norm_type, act_type, mode, plus=plus, c2x2=c2x2,
        )
        self.RDB3 = ResidualDenseBlock_5C(
            nf, kernel_size, gc, stride, bias, pad_type,
            norm_type, act_type, mode, plus=plus, c2x2=c2x2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass of the RRDB.

        #### Args:
            - `x` (torch.Tensor): Input tensor.

        #### Returns:
            - `torch.Tensor`: Output tensor.
        """
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x
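
# A minimal usage sketch (illustrative, not part of the original module): the
# residual-in-residual structure above chains three dense blocks and scales
# their combined output by 0.2 before the outer skip connection, so an RRDB
# preserves both spatial size and channel count.
#
#   >>> block = RRDB(nf=64)
#   >>> x = torch.randn(1, 64, 32, 32)
#   >>> tuple(block(x).shape)
#   (1, 64, 32, 32)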
""" def __init__( self, nf: int = 64, kernel_size: int = 3, gc: int = 32, stride: int = 1, bias: bool = True, pad_type: str = "zero", norm_type: str = None, act_type: str = "leakyrelu", mode: USDU_util.ConvMode = "CNA", plus: bool = False, c2x2: bool = False, ) -> None: super(ResidualDenseBlock_5C, self).__init__() self.conv1x1 = None self.conv1 = USDU_util.conv_block( nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode, c2x2=c2x2, ) self.conv2 = USDU_util.conv_block( nf + gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode, c2x2=c2x2, ) self.conv3 = USDU_util.conv_block( nf + 2 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode, c2x2=c2x2, ) self.conv4 = USDU_util.conv_block( nf + 3 * gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode, c2x2=c2x2, ) last_act = None self.conv5 = USDU_util.conv_block( nf + 4 * gc, nf, 3, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=last_act, mode=mode, c2x2=c2x2, ) def forward(self, x: torch.Tensor) -> torch.Tensor: """#### Forward pass of the ResidualDenseBlock_5C. #### Args: - `x` (torch.Tensor): Input tensor. #### Returns: - `torch.Tensor`: Output tensor. """ x1 = self.conv1(x) x2 = self.conv2(torch.cat((x, x1), 1)) x3 = self.conv3(torch.cat((x, x1, x2), 1)) x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 * 0.2 + x class RRDBNet(nn.Module): """#### Residual in Residual Dense Block Network (RRDBNet) class. #### Args: - `state_dict` (dict): State dictionary. - `norm` (str, optional): Normalization type. Defaults to None. - `act` (str, optional): Activation type. Defaults to "leakyrelu". - `upsampler` (str, optional): Upsampler type. Defaults to "upconv". - `mode` (USDU_util.ConvMode, optional): Convolution mode. Defaults to "CNA". 
""" def __init__( self, state_dict: Dict[str, torch.Tensor], norm: str = None, act: str = "leakyrelu", upsampler: str = "upconv", mode: USDU_util.ConvMode = "CNA", ) -> None: super(RRDBNet, self).__init__() self.model_arch = "ESRGAN" self.sub_type = "SR" self.state = state_dict self.norm = norm self.act = act self.upsampler = upsampler self.mode = mode self.state_map = { # currently supports old, new, and newer RRDBNet arch _internal # ESRGAN, BSRGAN/RealSR, Real-ESRGAN "model.0.weight": ("conv_first.weight",), "model.0.bias": ("conv_first.bias",), "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"), "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"), r"model.1.sub.\1.RDB\2.conv\3.0.\4": ( r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)", r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)", ), } self.num_blocks = self.get_num_blocks() self.plus = any("conv1x1" in k for k in self.state.keys()) self.state = self.new_to_old_arch(self.state) self.key_arr = list(self.state.keys()) self.in_nc: int = self.state[self.key_arr[0]].shape[1] self.out_nc: int = self.state[self.key_arr[-1]].shape[0] self.scale: int = self.get_scale() self.num_filters: int = self.state[self.key_arr[0]].shape[0] c2x2 = False self.supports_fp16 = True self.supports_bfp16 = True self.min_size_restriction = None self.shuffle_factor = None upsample_block = { "upconv": USDU_util.upconv_block, }.get(self.upsampler) upsample_blocks = [ upsample_block( in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act, c2x2=c2x2, ) for _ in range(int(math.log(self.scale, 2))) ] self.model = USDU_util.sequential( # fea conv USDU_util.conv_block( in_nc=self.in_nc, out_nc=self.num_filters, kernel_size=3, norm_type=None, act_type=None, c2x2=c2x2, ), USDU_util.ShortcutBlock( USDU_util.sequential( # rrdb blocks *[ RRDB( nf=self.num_filters, kernel_size=3, gc=32, stride=1, bias=True, pad_type="zero", norm_type=self.norm, act_type=self.act, mode="CNA", plus=self.plus, c2x2=c2x2, ) for _ in range(self.num_blocks) ], # lr conv USDU_util.conv_block( in_nc=self.num_filters, out_nc=self.num_filters, kernel_size=3, norm_type=self.norm, act_type=None, mode=self.mode, c2x2=c2x2, ), ) ), *upsample_blocks, # hr_conv0 USDU_util.conv_block( in_nc=self.num_filters, out_nc=self.num_filters, kernel_size=3, norm_type=None, act_type=self.act, c2x2=c2x2, ), # hr_conv1 USDU_util.conv_block( in_nc=self.num_filters, out_nc=self.out_nc, kernel_size=3, norm_type=None, act_type=None, c2x2=c2x2, ), ) self.load_state_dict(self.state, strict=False) def new_to_old_arch(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """#### Convert new architecture state dictionary to old architecture. #### Args: - `state` (dict): State dictionary. #### Returns: - `dict`: Converted state dictionary. 
""" # add nb to state keys for kind in ("weight", "bias"): self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[ f"model.1.sub./NB/.{kind}" ] del self.state_map[f"model.1.sub./NB/.{kind}"] old_state = OrderedDict() for old_key, new_keys in self.state_map.items(): for new_key in new_keys: if r"\1" in old_key: for k, v in state.items(): sub = re.sub(new_key, old_key, k) if sub != k: old_state[sub] = v else: if new_key in state: old_state[old_key] = state[new_key] # upconv layers max_upconv = 0 for key in state.keys(): match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key) if match is not None: _, key_num, key_type = match.groups() old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key] max_upconv = max(max_upconv, int(key_num) * 3) # final layers for key in state.keys(): if key in ("HRconv.weight", "conv_hr.weight"): old_state[f"model.{max_upconv + 2}.weight"] = state[key] elif key in ("HRconv.bias", "conv_hr.bias"): old_state[f"model.{max_upconv + 2}.bias"] = state[key] elif key in ("conv_last.weight",): old_state[f"model.{max_upconv + 4}.weight"] = state[key] elif key in ("conv_last.bias",): old_state[f"model.{max_upconv + 4}.bias"] = state[key] # Sort by first numeric value of each layer def compare(item1: str, item2: str) -> int: parts1 = item1.split(".") parts2 = item2.split(".") int1 = int(parts1[1]) int2 = int(parts2[1]) return int1 - int2 sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare)) # Rebuild the output dict in the right order out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys) return out_dict def get_scale(self, min_part: int = 6) -> int: """#### Get the scale factor. #### Args: - `min_part` (int, optional): Minimum part. Defaults to 6. #### Returns: - `int`: Scale factor. """ n = 0 for part in list(self.state): parts = part.split(".")[1:] if len(parts) == 2: part_num = int(parts[0]) if part_num > min_part and parts[1] == "weight": n += 1 return 2**n def get_num_blocks(self) -> int: """#### Get the number of blocks. #### Returns: - `int`: Number of blocks. """ nbs = [] state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + ( r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)", ) for state_key in state_keys: for k in self.state: m = re.search(state_key, k) if m: nbs.append(int(m.group(1))) if nbs: break return max(*nbs) + 1 def forward(self, x: torch.Tensor) -> torch.Tensor: """#### Forward pass of the RRDBNet. #### Args: - `x` (torch.Tensor): Input tensor. #### Returns: - `torch.Tensor`: Output tensor. """ return self.model(x) PyTorchSRModels = (RRDBNet,) PyTorchSRModel = Union[RRDBNet,] PyTorchModels = (*PyTorchSRModels,) PyTorchModel = Union[PyTorchSRModel]