Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| import functools | |
| import math | |
| import re | |
| from collections import OrderedDict | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from . import block as B | |
# Parameter names of a standard old-arch ESRGAN checkpoint with 23 RRDB
# blocks, in state-dict order:
#   * the first conv (model.0),
#   * five convs in each of the three RDBs of every RRDB block
#     (model.1.sub.0 .. model.1.sub.22),
#   * the trunk conv closing the RRDB trunk (model.1.sub.23),
#   * the trailing convs at model.3, model.6, model.8 and model.10
#     (presumably the upsampling / HR convs of a 4x model).
# Every conv contributes a ".weight" entry immediately followed by ".bias".
esrgan_safetensors_keys = [
    "model.0.weight",
    "model.0.bias",
    *(
        f"model.1.sub.{block}.RDB{rdb}.conv{conv}.0.{param}"
        for block in range(23)
        for rdb in range(1, 4)
        for conv in range(1, 6)
        for param in ("weight", "bias")
    ),
    "model.1.sub.23.weight",
    "model.1.sub.23.bias",
    *(
        f"model.{layer}.{param}"
        for layer in (3, 6, 8, 10)
        for param in ("weight", "bias")
    ),
]
# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py,
# which in turn extended the implementation that was previously here.
| class RRDBNet(nn.Module): | |
    def __init__(
        self,
        state_dict,
        norm=None,
        act: str = "leakyrelu",
        upsampler: str = "upconv",
        mode: B.ConvMode = "CNA",
    ) -> None:
        """
        ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
        By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
        and Chen Change Loy.

        This is the old-arch Residual in Residual Dense Block Network, not the
        newest revision available at github.com/xinntao/ESRGAN. This is on
        purpose: the newest network severely limits the potential uses of the
        network with no benefits. Model files from both the new and the old
        arch are supported; new-arch weights are converted to old-arch key
        names (via ``new_to_old_arch``) before the network is built.

        The network hyper-parameters (input/output channels, scale, number of
        filters, number of RRDB blocks, and the ESRGAN+, 2x2-conv and
        pixel-unshuffle variants) are inferred from the key names and tensor
        shapes of ``state_dict``. The weights are then loaded into the freshly
        built network.

        Args:
            state_dict: Model weights keyed by parameter name. If it contains a
                "params_ema" key, the weights nested under that key are used.
            norm: Normalization layer type, passed as ``norm_type`` to the RRDB
                blocks and the LR conv.
            act: Activation layer type.
            upsampler: Upsample layer, either "upconv" or "pixel_shuffle".
            mode: Convolution mode used for the LR conv. The RRDB blocks are
                always built with "CNA" regardless of this value.

        Raises:
            NotImplementedError: If ``upsampler`` is not a supported mode.
        """
        super(RRDBNet, self).__init__()
        self.model_arch = "ESRGAN"
        self.sub_type = "SR"

        self.state = state_dict
        self.norm = norm
        self.act = act
        self.upsampler = upsampler
        self.mode = mode

        # Maps each old-arch key to the new-arch key(s) it may appear as.
        # "/NB/" is a placeholder for the block count; new_to_old_arch()
        # replaces it with self.num_blocks. The raw-string entry is a regex
        # rewrite: groups \1..\4 capture the block index, RDB index, conv
        # index and weight/bias of each per-block RDB conv.
        self.state_map = {
            # currently supports old, new, and newer RRDBNet arch models
            # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
            "model.0.weight": ("conv_first.weight",),
            "model.0.bias": ("conv_first.bias",),
            "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
            "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
            r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
                r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
                r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
            ),
        }

        # Some checkpoints nest the weights under "params_ema" (presumably
        # Real-ESRGAN, per the commented-out line below).
        if "params_ema" in self.state:
            self.state = self.state["params_ema"]
            # self.model_arch = "RealESRGAN"

        # Must be computed before new_to_old_arch(), which uses num_blocks
        # to name the trunk-conv keys.
        self.num_blocks = self.get_num_blocks()

        # ESRGAN+ checkpoints are recognised by their extra "conv1x1" layers.
        self.plus = any("conv1x1" in k for k in self.state.keys())
        if self.plus:
            self.model_arch = "ESRGAN+"

        self.state = self.new_to_old_arch(self.state)

        self.key_arr = list(self.state.keys())

        # Channel counts come from the first and last tensors in key order,
        # using the PyTorch conv weight layout (out_channels, in_channels, kH, kW).
        # NOTE(review): this assumes the first key is the first conv's weight
        # and the last key belongs to the final conv — confirm for all formats.
        self.in_nc: int = self.state[self.key_arr[0]].shape[1]
        self.out_nc: int = self.state[self.key_arr[-1]].shape[0]
        self.scale: int = self.get_scale()
        self.num_filters: int = self.state[self.key_arr[0]].shape[0]

        # A first conv with a kernel height of 2 marks the "2c2" variant; its
        # scale is recomputed from the value get_scale() returned.
        # NOTE(review): the sqrt(scale / 4) correction depends on how
        # get_scale() counts layers for this variant — confirm there.
        c2x2 = False
        if self.state["model.0.weight"].shape[-2] == 2:
            c2x2 = True
            self.scale = round(math.sqrt(self.scale / 4))
            self.model_arch = "ESRGAN-2c2"

        self.supports_fp16 = True
        self.supports_bfp16 = True
        self.min_size_restriction = None

        # Detect if pixelunshuffle was used (Real-ESRGAN)
        # in_nc of 4x or 16x out_nc means the input was pixel-unshuffled by a
        # factor of 2 or 4 respectively (factor = sqrt(in_nc / out_nc)).
        if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
            self.in_nc / 4,
            self.in_nc / 16,
        ):
            self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
        else:
            self.shuffle_factor = None

        upsample_block = {
            "upconv": B.upconv_block,
            "pixel_shuffle": B.pixelshuffle_block,
        }.get(self.upsampler)
        if upsample_block is None:
            raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")

        # x3 uses a single stage built with upscale_factor=3. Otherwise one
        # stage (default upscale factor, presumably 2) is stacked per power of
        # two in the scale.
        # NOTE(review): in the x3 branch upsample_blocks is a single block,
        # not a list. The `*upsample_blocks` below presumably unpacks its
        # child modules — confirm against B.upconv_block / B.pixelshuffle_block.
        if self.scale == 3:
            upsample_blocks = upsample_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                upscale_factor=3,
                act_type=self.act,
                c2x2=c2x2,
            )
        else:
            upsample_blocks = [
                upsample_block(
                    in_nc=self.num_filters,
                    out_nc=self.num_filters,
                    act_type=self.act,
                    c2x2=c2x2,
                )
                for _ in range(int(math.log(self.scale, 2)))
            ]

        # Do not reorder these layers. B.sequential presumably numbers its
        # children positionally, and the resulting "model.N" names must line
        # up with the old-arch state-dict keys: "model.0.*" for the first
        # conv and "model.1.sub.*" for the RRDB trunk.
        self.model = B.sequential(
            # fea conv
            B.conv_block(
                in_nc=self.in_nc,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
            B.ShortcutBlock(
                B.sequential(
                    # rrdb blocks
                    *[
                        B.RRDB(
                            nf=self.num_filters,
                            kernel_size=3,
                            gc=32,
                            stride=1,
                            bias=True,
                            pad_type="zero",
                            norm_type=self.norm,
                            act_type=self.act,
                            # always "CNA" here; self.mode only affects the lr conv
                            mode="CNA",
                            plus=self.plus,
                            c2x2=c2x2,
                        )
                        for _ in range(self.num_blocks)
                    ],
                    # lr conv
                    B.conv_block(
                        in_nc=self.num_filters,
                        out_nc=self.num_filters,
                        kernel_size=3,
                        norm_type=self.norm,
                        act_type=None,
                        mode=self.mode,
                        c2x2=c2x2,
                    ),
                )
            ),
            *upsample_blocks,
            # hr_conv0
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=self.act,
                c2x2=c2x2,
            ),
            # hr_conv1
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.out_nc,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
        )

        # Adjust these properties for calculations outside of the model:
        # after construction, report the channel count and scale as seen
        # before the pixel unshuffle rather than the network's own input.
        if self.shuffle_factor:
            self.in_nc //= self.shuffle_factor ** 2
            self.scale //= self.shuffle_factor

        # strict=False: missing or unexpected keys do not raise an error.
        self.load_state_dict(self.state, strict=False)
def new_to_old_arch(self, state):
    """Convert a new-arch model state dictionary to an old-arch dictionary.

    A top-level ``"params_ema"`` entry is unwrapped first. If the (unwrapped)
    dict has no ``"conv_first.weight"`` key it is treated as already being in
    the old ``model.<N>...`` layout and is returned unchanged.

    Otherwise keys are renamed using ``self.state_map`` (old key -> tuple of
    candidate new keys, or regexes when the old key is a ``re.sub`` template).
    ``upconv<i>`` / ``conv_up<i>`` layers are placed at ``model.<3*i>``. The HR
    conv and the final conv are placed two and four indices after the last
    upconv.

    The ``/NB/`` placeholder in ``self.state_map`` is resolved to
    ``self.num_blocks`` on a private copy. As a result ``self.state_map`` is not
    modified and the method can safely be called more than once.

    Args:
        state: checkpoint state dict, optionally wrapped in ``"params_ema"``.

    Returns:
        The (unwrapped) input dict itself when no conversion is needed;
        otherwise an ``OrderedDict`` sorted (stably) by the integer that
        follows ``model.`` in each key.

    Raises:
        KeyError: if ``self.state_map`` lacks the ``/NB/`` placeholder entries.
    """
    if "params_ema" in state:
        state = state["params_ema"]
    if "conv_first.weight" not in state:
        # model is already old arch, this is a loose check, but should be sufficient
        return state

    # Resolve the trunk-conv placeholder on a copy rather than deleting keys
    # from self.state_map (which made a second call raise KeyError).
    # Pop-then-insert yields the same key order as assign-then-del.
    key_map = dict(self.state_map)
    for kind in ("weight", "bias"):
        key_map[f"model.1.sub.{self.num_blocks}.{kind}"] = key_map.pop(
            f"model.1.sub./NB/.{kind}"
        )

    old_state = OrderedDict()
    for old_key, new_keys in key_map.items():
        for new_key in new_keys:
            if r"\1" in old_key:
                # old_key is a replacement template and new_key the regex to apply
                for k, v in state.items():
                    sub = re.sub(new_key, old_key, k)
                    if sub != k:
                        old_state[sub] = v
            elif new_key in state:
                old_state[old_key] = state[new_key]

    # upconv layers: upconv<i> / conv_up<i> -> model.<3*i>; \d+ so that
    # multi-digit indices are not silently dropped
    max_upconv = 0
    for key, value in state.items():
        match = re.match(r"(upconv|conv_up)(\d+)\.(weight|bias)", key)
        if match is not None:
            _, key_num, key_type = match.groups()
            index = int(key_num) * 3
            old_state[f"model.{index}.{key_type}"] = value
            max_upconv = max(max_upconv, index)

    # final layers, positioned relative to the last upconv index
    final_layers = {
        "HRconv.weight": (2, "weight"),
        "conv_hr.weight": (2, "weight"),
        "HRconv.bias": (2, "bias"),
        "conv_hr.bias": (2, "bias"),
        "conv_last.weight": (4, "weight"),
        "conv_last.bias": (4, "bias"),
    }
    for key, value in state.items():
        if key in final_layers:
            offset, param = final_layers[key]
            old_state[f"model.{max_upconv + offset}.{param}"] = value

    # Sort by first numeric value of each layer; sorted() is stable, so keys
    # sharing a layer index keep their insertion order.
    sorted_keys = sorted(old_state, key=lambda k: int(k.split(".")[1]))
    # Rebuild the output dict in the right order
    return OrderedDict((k, old_state[k]) for k in sorted_keys)
def get_scale(self, min_part: int = 6) -> int:
    """Estimate the upscale factor from the layer indices present in ``self.state``.

    Every three-part key ``<prefix>.<index>.weight`` whose integer ``index``
    exceeds ``min_part`` contributes one factor of two. The result is
    ``2 ** count``.

    For the index layout that ``new_to_old_arch`` produces (``upconv<i>`` at
    ``model.<3*i>``, then the HR conv and final conv two and four indices after
    the last upconv), that count equals the number of upconv layers.

    The middle segment of any three-part key is passed to ``int()`` before the
    ``"weight"`` check. A non-numeric segment therefore raises ``ValueError``.
    """
    exponent = 0
    for key in self.state:
        segments = key.split(".")
        if len(segments) != 3:
            continue
        layer_index = int(segments[1])
        if layer_index > min_part and segments[2] == "weight":
            exponent += 1
    return 2 ** exponent
def get_num_blocks(self) -> int:
    """Return the number of RRDB blocks implied by the keys of ``self.state``.

    Candidate regexes are tried in order:

    1. the new-arch patterns stored in ``self.state_map`` under the RDB key
       template;
    2. an old-arch ``model.<N>.sub.<i>.RDB<j>.conv<k>.0.<weight|bias>``
       pattern.

    The first pattern that matches at least one key wins. The result is the
    largest captured block index (group 1) plus one.

    Raises:
        ValueError: if no key in ``self.state`` matches any pattern.
    """
    patterns = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
        r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
    )
    for pattern in patterns:
        indices = [
            int(match.group(1))
            for match in (re.search(pattern, key) for key in self.state)
            if match
        ]
        if indices:
            # max(iterable), not max(*indices): the latter raises TypeError
            # when exactly one key matched
            return max(indices) + 1
    raise ValueError(
        "Could not determine the number of RRDB blocks: no state dict key "
        "matches a known RDB layer pattern"
    )
def forward(self, x):
    """Apply the network to ``x``.

    If ``self.shuffle_factor`` is falsy, this returns ``self.model(x)``
    directly. Otherwise ``x`` must be 4-D, and:

    1. Its last dimension is reflect-padded on the right, and its
       second-to-last dimension on the bottom, up to multiples of the factor.
    2. It is passed through ``torch.pixel_unshuffle`` and then ``self.model``.
    3. The output is cropped to the unpadded height/width times ``self.scale``.
    """
    factor = self.shuffle_factor
    if not factor:
        return self.model(x)
    # Four-way unpacking keeps the requirement that x is a 4-D tensor.
    _, _, height, width = x.size()
    # Smallest non-negative padding that makes each size divisible by factor.
    pad_bottom = -height % factor
    pad_right = -width % factor
    padded = F.pad(x, (0, pad_right, 0, pad_bottom), "reflect")
    unshuffled = torch.pixel_unshuffle(padded, downscale_factor=factor)
    upscaled = self.model(unshuffled)
    # Crop away whatever the padding contributed to the output.
    return upscaled[:, :, : height * self.scale, : width * self.scale]