Spaces: Running on Zero
| from collections import OrderedDict | |
| import functools | |
| import math | |
| import re | |
| from typing import Union, Dict | |
| import torch | |
| import torch.nn as nn | |
| from modules.UltimateSDUpscale import USDU_util | |
class RRDB(nn.Module):
    """#### Residual in Residual Dense Block (RRDB) class.

    Stacks three identically-configured dense blocks and applies a scaled
    (0.2) residual connection around the whole stack.

    #### Args:
        - `nf` (int): Number of filters.
        - `kernel_size` (int, optional): Kernel size. Defaults to 3.
        - `gc` (int, optional): Growth channel. Defaults to 32.
        - `stride` (int, optional): Stride. Defaults to 1.
        - `bias` (bool, optional): Whether to use bias. Defaults to True.
        - `pad_type` (str, optional): Padding type. Defaults to "zero".
        - `norm_type` (str, optional): Normalization type. Defaults to None.
        - `act_type` (str, optional): Activation type. Defaults to "leakyrelu".
        - `mode` (USDU_util.ConvMode, optional): Convolution mode. Defaults to "CNA".
        - `_convtype` (str, optional): Convolution type (unused). Defaults to "Conv2D".
        - `_spectral_norm` (bool, optional): Spectral normalization flag (unused). Defaults to False.
        - `plus` (bool, optional): Whether to use the ESRGAN+ variant. Defaults to False.
        - `c2x2` (bool, optional): Whether to use 2x2 convolution. Defaults to False.
    """

    def __init__(
        self,
        nf: int,
        kernel_size: int = 3,
        gc: int = 32,
        stride: int = 1,
        bias: bool = True,
        pad_type: str = "zero",
        norm_type: str = None,
        act_type: str = "leakyrelu",
        mode: USDU_util.ConvMode = "CNA",
        _convtype: str = "Conv2D",
        _spectral_norm: bool = False,
        plus: bool = False,
        c2x2: bool = False,
    ) -> None:
        super(RRDB, self).__init__()

        def _make_dense_block() -> "ResidualDenseBlock_5C":
            # All three dense blocks share the exact same hyper-parameters;
            # each call builds a fresh, independently-parameterized instance.
            return ResidualDenseBlock_5C(
                nf,
                kernel_size,
                gc,
                stride,
                bias,
                pad_type,
                norm_type,
                act_type,
                mode,
                plus=plus,
                c2x2=c2x2,
            )

        # Attribute names RDB1/RDB2/RDB3 must stay fixed: they are part of the
        # state-dict key layout that checkpoints are loaded against.
        self.RDB1 = _make_dense_block()
        self.RDB2 = _make_dense_block()
        self.RDB3 = _make_dense_block()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass of the RRDB.

        #### Args:
            - `x` (torch.Tensor): Input tensor.

        #### Returns:
            - `torch.Tensor`: Output tensor (scaled residual added to input).
        """
        trunk = self.RDB3(self.RDB2(self.RDB1(x)))
        # Residual scaling by 0.2 stabilizes training of deep ESRGAN stacks.
        return trunk * 0.2 + x
class ResidualDenseBlock_5C(nn.Module):
    """#### Residual Dense Block with 5 Convolutions (ResidualDenseBlock_5C) class.

    Five convolutions with dense (concatenated) connections; the block output
    is a 0.2-scaled residual added to the input. When `plus` is True, the
    ESRGAN+ variant is used: a 1x1 convolution on the block input provides an
    extra residual path into conv2's output, and that partial result is also
    re-added after conv4.

    #### Args:
        - `nf` (int, optional): Number of filters. Defaults to 64.
        - `kernel_size` (int, optional): Kernel size. Defaults to 3.
        - `gc` (int, optional): Growth channel. Defaults to 32.
        - `stride` (int, optional): Stride. Defaults to 1.
        - `bias` (bool, optional): Whether to use bias. Defaults to True.
        - `pad_type` (str, optional): Padding type. Defaults to "zero".
        - `norm_type` (str, optional): Normalization type. Defaults to None.
        - `act_type` (str, optional): Activation type. Defaults to "leakyrelu".
        - `mode` (USDU_util.ConvMode, optional): Convolution mode. Defaults to "CNA".
        - `plus` (bool, optional): Whether to use the ESRGAN+ variant. Defaults to False.
        - `c2x2` (bool, optional): Whether to use 2x2 convolution. Defaults to False.
    """

    def __init__(
        self,
        nf: int = 64,
        kernel_size: int = 3,
        gc: int = 32,
        stride: int = 1,
        bias: bool = True,
        pad_type: str = "zero",
        norm_type: str = None,
        act_type: str = "leakyrelu",
        mode: USDU_util.ConvMode = "CNA",
        plus: bool = False,
        c2x2: bool = False,
    ) -> None:
        super(ResidualDenseBlock_5C, self).__init__()

        # BUG FIX: the original hard-coded `self.conv1x1 = None`, so the
        # `plus` flag (ESRGAN+) was accepted but ignored and ESRGAN+
        # checkpoints silently dropped their conv1x1 weights under
        # load_state_dict(strict=False). Build the layer when requested;
        # the module name "conv1x1" matches the checkpoint key layout.
        self.conv1x1 = nn.Conv2d(nf, gc, kernel_size=1) if plus else None

        self.conv1 = USDU_util.conv_block(
            nf,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv2 = USDU_util.conv_block(
            nf + gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv3 = USDU_util.conv_block(
            nf + 2 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv4 = USDU_util.conv_block(
            nf + 3 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        # The final fusion conv has no activation (linear output before the
        # residual add).
        last_act = None
        self.conv5 = USDU_util.conv_block(
            nf + 4 * gc,
            nf,
            3,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=last_act,
            mode=mode,
            c2x2=c2x2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass of the ResidualDenseBlock_5C.

        #### Args:
            - `x` (torch.Tensor): Input tensor.

        #### Returns:
            - `torch.Tensor`: Output tensor.
        """
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1 is not None:
            # ESRGAN+ shortcut: 1x1-projected input added into conv2 output.
            x2 = x2 + self.conv1x1(x)
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1 is not None:
            # ESRGAN+ second shortcut: re-add the partial residual.
            x4 = x4 + x2
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x
class RRDBNet(nn.Module):
    """#### Residual in Residual Dense Block Network (RRDBNet) class.

    Reconstructs an ESRGAN-style network from a checkpoint state dict. The
    checkpoint may use any of the three known key layouts (old ESRGAN,
    BSRGAN/RealSR, Real-ESRGAN); new-style keys are first remapped to the old
    `model.N` layout, then architecture hyper-parameters (channels, scale,
    block count, "plus" variant) are inferred from the remapped keys.

    #### Args:
        - `state_dict` (dict): State dictionary.
        - `norm` (str, optional): Normalization type. Defaults to None.
        - `act` (str, optional): Activation type. Defaults to "leakyrelu".
        - `upsampler` (str, optional): Upsampler type. Defaults to "upconv".
        - `mode` (USDU_util.ConvMode, optional): Convolution mode. Defaults to "CNA".
    """

    def __init__(
        self,
        state_dict: Dict[str, torch.Tensor],
        norm: str = None,
        act: str = "leakyrelu",
        upsampler: str = "upconv",
        mode: USDU_util.ConvMode = "CNA",
    ) -> None:
        super(RRDBNet, self).__init__()
        self.model_arch = "ESRGAN"
        self.sub_type = "SR"

        self.state = state_dict
        self.norm = norm
        self.act = act
        self.upsampler = upsampler
        self.mode = mode

        self.state_map = {
            # currently supports old, new, and newer RRDBNet arch _internal
            # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
            "model.0.weight": ("conv_first.weight",),
            "model.0.bias": ("conv_first.bias",),
            "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
            "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
            r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
                r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
                r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
            ),
        }
        self.num_blocks = self.get_num_blocks()
        # ESRGAN+ checkpoints carry "conv1x1" layers inside each dense block.
        self.plus = any("conv1x1" in k for k in self.state.keys())

        self.state = self.new_to_old_arch(self.state)

        self.key_arr = list(self.state.keys())

        # First conv's weight is (num_filters, in_nc, k, k); last layer's is
        # (out_nc, ..., k, k).
        self.in_nc: int = self.state[self.key_arr[0]].shape[1]
        self.out_nc: int = self.state[self.key_arr[-1]].shape[0]

        self.scale: int = self.get_scale()
        self.num_filters: int = self.state[self.key_arr[0]].shape[0]

        c2x2 = False

        self.supports_fp16 = True
        self.supports_bfp16 = True
        self.min_size_restriction = None
        self.shuffle_factor = None

        upsample_block = {
            "upconv": USDU_util.upconv_block,
        }.get(self.upsampler)
        if upsample_block is None:
            # Fail with a clear message instead of the opaque
            # "'NoneType' object is not callable" the lookup would otherwise
            # produce below.
            raise NotImplementedError(
                f"Upsample mode [{self.upsampler}] is not found"
            )
        # One x2 upsample block per factor of two in the total scale.
        upsample_blocks = [
            upsample_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                act_type=self.act,
                c2x2=c2x2,
            )
            for _ in range(int(math.log(self.scale, 2)))
        ]

        self.model = USDU_util.sequential(
            # fea conv
            USDU_util.conv_block(
                in_nc=self.in_nc,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
            USDU_util.ShortcutBlock(
                USDU_util.sequential(
                    # rrdb blocks
                    *[
                        RRDB(
                            nf=self.num_filters,
                            kernel_size=3,
                            gc=32,
                            stride=1,
                            bias=True,
                            pad_type="zero",
                            norm_type=self.norm,
                            act_type=self.act,
                            mode="CNA",
                            plus=self.plus,
                            c2x2=c2x2,
                        )
                        for _ in range(self.num_blocks)
                    ],
                    # lr conv
                    USDU_util.conv_block(
                        in_nc=self.num_filters,
                        out_nc=self.num_filters,
                        kernel_size=3,
                        norm_type=self.norm,
                        act_type=None,
                        mode=self.mode,
                        c2x2=c2x2,
                    ),
                )
            ),
            *upsample_blocks,
            # hr_conv0
            USDU_util.conv_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=self.act,
                c2x2=c2x2,
            ),
            # hr_conv1
            USDU_util.conv_block(
                in_nc=self.num_filters,
                out_nc=self.out_nc,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
        )

        # strict=False: remapped checkpoints may lack some optional keys.
        self.load_state_dict(self.state, strict=False)

    def new_to_old_arch(
        self, state: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """#### Convert new architecture state dictionary to old architecture.

        #### Args:
            - `state` (dict): State dictionary.

        #### Returns:
            - `dict`: Converted state dictionary, sorted by layer index.
        """
        # add nb to state keys
        for kind in ("weight", "bias"):
            self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
                f"model.1.sub./NB/.{kind}"
            ]
            del self.state_map[f"model.1.sub./NB/.{kind}"]

        old_state = OrderedDict()
        for old_key, new_keys in self.state_map.items():
            for new_key in new_keys:
                if r"\1" in old_key:
                    # Regex mapping: rewrite every matching key.
                    for k, v in state.items():
                        sub = re.sub(new_key, old_key, k)
                        if sub != k:
                            old_state[sub] = v
                else:
                    # Direct 1:1 mapping.
                    if new_key in state:
                        old_state[old_key] = state[new_key]

        # upconv layers
        max_upconv = 0
        for key in state.keys():
            match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
            if match is not None:
                _, key_num, key_type = match.groups()
                old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
                max_upconv = max(max_upconv, int(key_num) * 3)

        # final layers
        for key in state.keys():
            if key in ("HRconv.weight", "conv_hr.weight"):
                old_state[f"model.{max_upconv + 2}.weight"] = state[key]
            elif key in ("HRconv.bias", "conv_hr.bias"):
                old_state[f"model.{max_upconv + 2}.bias"] = state[key]
            elif key in ("conv_last.weight",):
                old_state[f"model.{max_upconv + 4}.weight"] = state[key]
            elif key in ("conv_last.bias",):
                old_state[f"model.{max_upconv + 4}.bias"] = state[key]

        # Sort by first numeric value of each layer
        def compare(item1: str, item2: str) -> int:
            parts1 = item1.split(".")
            parts2 = item2.split(".")
            int1 = int(parts1[1])
            int2 = int(parts2[1])
            return int1 - int2

        sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))

        # Rebuild the output dict in the right order
        out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)

        return out_dict

    def get_scale(self, min_part: int = 6) -> int:
        """#### Get the scale factor.

        Counts weight layers past the trunk (index > `min_part`); each one
        corresponds to an x2 upsample stage.

        #### Args:
            - `min_part` (int, optional): Minimum part. Defaults to 6.

        #### Returns:
            - `int`: Scale factor (a power of two).
        """
        n = 0
        for part in list(self.state):
            parts = part.split(".")[1:]
            if len(parts) == 2:
                part_num = int(parts[0])
                if part_num > min_part and parts[1] == "weight":
                    n += 1
        return 2**n

    def get_num_blocks(self) -> int:
        """#### Get the number of RRDB blocks encoded in the state dict.

        #### Returns:
            - `int`: Number of blocks (highest block index + 1).
        """
        nbs = []
        state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
            r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
        )
        for state_key in state_keys:
            for k in self.state:
                m = re.search(state_key, k)
                if m:
                    nbs.append(int(m.group(1)))
            if nbs:
                break
        # BUG FIX: the original used `max(*nbs)`, which unpacks the list and
        # raises TypeError when exactly one block index was collected
        # (max(0) -> "int is not iterable"). `max(nbs)` handles any
        # non-empty list.
        return max(nbs) + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass of the RRDBNet.

        #### Args:
            - `x` (torch.Tensor): Input tensor.

        #### Returns:
            - `torch.Tensor`: Output tensor.
        """
        return self.model(x)
# Tuple of supported super-resolution architectures (used for isinstance checks).
PyTorchSRModels = (RRDBNet,)
# Union over SR architectures; a single-member Union collapses to the type itself.
PyTorchSRModel = Union[RRDBNet,]

# All supported architectures (currently just the SR ones).
PyTorchModels = tuple(PyTorchSRModels)
PyTorchModel = Union[PyTorchSRModel]