from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F


'''
# --------------------------------------------
# Advanced nn.Sequential
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''


def sequential(*args):
    """Advanced nn.Sequential.

    Args:
        *args: an arbitrary number of nn.Sequential and/or nn.Module
            instances; nested nn.Sequential containers are flattened.

    Returns:
        nn.Sequential
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)
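

# A minimal usage sketch (illustrative, not part of the original file):
# sequential() flattens nested nn.Sequential containers into one flat chain.
#
#   body = sequential(nn.Conv2d(3, 64, 3, padding=1),
#                     nn.Sequential(nn.ReLU(inplace=True),
#                                   nn.Conv2d(64, 64, 3, padding=1)))
#   assert len(body) == 3  # the inner nn.Sequential was flattened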


'''
# --------------------------------------------
# Useful blocks
# https://github.com/xinntao/BasicSR
# --------------------------------------------
# conv + normalization + relu (conv)
# (PixelUnShuffle)
# (ConditionalBatchNorm2d)
# concat (ConcatBlock)
# sum (ShortcutBlock)
# resblock (ResBlock)
# Channel Attention (CA) Layer (CALayer)
# Residual Channel Attention Block (RCABlock)
# Residual Channel Attention Group (RCAGroup)
# Residual Dense Block (ResidualDenseBlock_5C)
# Residual in Residual Dense Block (RRDB)
# --------------------------------------------
'''


# --------------------------------------------
# return nn.Sequential of (Conv + BN + ReLU)
# --------------------------------------------
def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', negative_slope=0.2):
    L = []
    for t in mode:
        if t == 'C':
            L.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
        elif t == 'T':
            L.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
        elif t == 'B':
            L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
        elif t == 'I':
            L.append(nn.InstanceNorm2d(out_channels, affine=True))
        elif t == 'R':
            L.append(nn.ReLU(inplace=True))
        elif t == 'r':
            L.append(nn.ReLU(inplace=False))
        elif t == 'L':
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
        elif t == 'l':
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
        elif t == '2':
            L.append(nn.PixelShuffle(upscale_factor=2))
        elif t == '3':
            L.append(nn.PixelShuffle(upscale_factor=3))
        elif t == '4':
            L.append(nn.PixelShuffle(upscale_factor=4))
        elif t == 'U':
            L.append(nn.Upsample(scale_factor=2, mode='nearest'))
        elif t == 'u':
            L.append(nn.Upsample(scale_factor=3, mode='nearest'))
        elif t == 'v':
            L.append(nn.Upsample(scale_factor=4, mode='nearest'))
        elif t == 'M':
            L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        elif t == 'A':
            L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        else:
            raise NotImplementedError('Undefined type: {}'.format(t))
    return sequential(*L)
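

# A minimal usage sketch (illustrative): each character of `mode` appends one
# layer, so 'CBR' builds Conv2d -> BatchNorm2d -> ReLU and 'C2L' builds
# Conv2d -> PixelShuffle(2) -> LeakyReLU.
#
#   block = conv(64, 64, mode='CBR')
#   y = block(torch.randn(1, 64, 32, 32))  # -> (1, 64, 32, 32)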


# --------------------------------------------
# inverse of pixel_shuffle
# --------------------------------------------
def pixel_unshuffle(input, upscale_factor):
    r"""Rearranges elements in a Tensor of shape :math:`(*, C, rH, rW)` to a
    tensor of shape :math:`(*, r^2C, H, W)`.

    Authors:
        Zhaoyi Yan, https://github.com/Zhaoyi-Yan
        Kai Zhang, https://github.com/cszn/FFDNet
    Date:
        01/Jan/2019
    """
    batch_size, channels, in_height, in_width = input.size()
    out_height = in_height // upscale_factor
    out_width = in_width // upscale_factor
    input_view = input.contiguous().view(
        batch_size, channels, out_height, upscale_factor,
        out_width, upscale_factor)
    channels *= upscale_factor ** 2
    unshuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()
    return unshuffle_out.view(batch_size, channels, out_height, out_width)


class PixelUnShuffle(nn.Module):
    r"""Rearranges elements in a Tensor of shape :math:`(*, C, rH, rW)` to a
    tensor of shape :math:`(*, r^2C, H, W)`.

    Authors:
        Zhaoyi Yan, https://github.com/Zhaoyi-Yan
        Kai Zhang, https://github.com/cszn/FFDNet
    Date:
        01/Jan/2019
    """
    def __init__(self, upscale_factor):
        super(PixelUnShuffle, self).__init__()
        self.upscale_factor = upscale_factor

    def forward(self, input):
        return pixel_unshuffle(input, self.upscale_factor)

    def extra_repr(self):
        return 'upscale_factor={}'.format(self.upscale_factor)
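

# A minimal round-trip sketch (illustrative): pixel_unshuffle inverts torch's
# pixel_shuffle, trading spatial resolution for channels.
#
#   x = torch.randn(1, 4, 8, 8)
#   y = F.pixel_shuffle(x, 2)   # -> (1, 1, 16, 16)
#   z = pixel_unshuffle(y, 2)   # -> (1, 4, 8, 8)
#   assert torch.equal(x, z)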


# --------------------------------------------
# conditional batch norm
# https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
# --------------------------------------------
class ConditionalBatchNorm2d(nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.embed = nn.Embedding(num_classes, num_features * 2)
        self.embed.weight.data[:, :num_features].normal_(1, 0.02)  # initialise scale at N(1, 0.02)
        self.embed.weight.data[:, num_features:].zero_()  # initialise bias at 0

    def forward(self, x, y):
        out = self.bn(x)
        gamma, beta = self.embed(y).chunk(2, 1)
        out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
        return out
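

# A minimal usage sketch (illustrative): normalisation statistics are shared,
# but each class id selects its own affine (gamma, beta) pair via the embedding.
#
#   cbn = ConditionalBatchNorm2d(num_features=64, num_classes=10)
#   x = torch.randn(4, 64, 16, 16)
#   y = torch.tensor([0, 3, 3, 9])  # one class label per sample
#   out = cbn(x, y)                 # -> (4, 64, 16, 16)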


# --------------------------------------------
# Concat the output of a submodule to its input
# --------------------------------------------
class ConcatBlock(nn.Module):
    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = torch.cat((x, self.sub(x)), dim=1)
        return output

    def __repr__(self):
        return self.sub.__repr__() + 'concat'


# --------------------------------------------
# sum the output of a submodule to its input
# --------------------------------------------
class ShortcutBlock(nn.Module):
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        tmpstr = 'Identity + \n|'
        modstr = self.sub.__repr__().replace('\n', '\n|')
        tmpstr = tmpstr + modstr
        return tmpstr
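

# A minimal usage sketch (illustrative): ConcatBlock widens the channel
# dimension, while ShortcutBlock keeps it fixed and adds a residual connection.
#
#   sub = conv(64, 64, mode='CRC')
#   cat = ConcatBlock(sub)(torch.randn(1, 64, 8, 8))    # -> (1, 128, 8, 8)
#   res = ShortcutBlock(sub)(torch.randn(1, 64, 8, 8))  # -> (1, 64, 8, 8)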


# --------------------------------------------
# Res Block: x + conv(relu(conv(x)))
# --------------------------------------------
class ResBlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2):
        super(ResBlock, self).__init__()
        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        if mode[0] in ['R', 'L']:
            # a leading activation must be out-of-place, otherwise it would
            # modify x in place and corrupt the residual sum below
            mode = mode[0].lower() + mode[1:]
        self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)

    def forward(self, x):
        res = self.res(x)
        return x + res
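

# A minimal usage sketch (illustrative):
#
#   block = ResBlock(64, 64, mode='CRC')   # conv -> relu -> conv, plus skip
#   y = block(torch.randn(1, 64, 32, 32))  # -> (1, 64, 32, 32)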


# --------------------------------------------
# simplified information multi-distillation block (IMDB)
# x + conv1x1(concat(split(relu(conv(x)))x3))
# --------------------------------------------
class IMDBlock(nn.Module):
    """
    @inproceedings{hui2019lightweight,
      title={Lightweight Image Super-Resolution with Information Multi-distillation Network},
      author={Hui, Zheng and Gao, Xinbo and Yang, Yunchu and Wang, Xiumei},
      booktitle={Proceedings of the 27th ACM International Conference on Multimedia (ACM MM)},
      pages={2024--2032},
      year={2019}
    }
    @inproceedings{zhang2019aim,
      title={AIM 2019 Challenge on Constrained Super-Resolution: Methods and Results},
      author={Kai Zhang and Shuhang Gu and Radu Timofte and others},
      booktitle={IEEE International Conference on Computer Vision Workshops},
      year={2019}
    }
    """
    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CL', d_rate=0.25, negative_slope=0.05):
        super(IMDBlock, self).__init__()
        self.d_nc = int(in_channels * d_rate)      # distilled channels, kept at each stage
        self.r_nc = int(in_channels - self.d_nc)   # remaining channels, refined further
        assert mode[0] == 'C', 'convolutional layer first'
        self.conv1 = conv(in_channels, in_channels, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv2 = conv(self.r_nc, in_channels, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv3 = conv(self.r_nc, in_channels, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv4 = conv(self.r_nc, self.d_nc, kernel_size, stride, padding, bias, mode[0], negative_slope)
        self.conv1x1 = conv(self.d_nc*4, out_channels, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0], negative_slope=negative_slope)

    def forward(self, x):
        d1, r1 = torch.split(self.conv1(x), (self.d_nc, self.r_nc), dim=1)
        d2, r2 = torch.split(self.conv2(r1), (self.d_nc, self.r_nc), dim=1)
        d3, r3 = torch.split(self.conv3(r2), (self.d_nc, self.r_nc), dim=1)
        d4 = self.conv4(r3)
        res = self.conv1x1(torch.cat((d1, d2, d3, d4), dim=1))
        return x + res
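

# A minimal shape sketch (illustrative): with in_channels=64 and d_rate=0.25,
# each stage distils d_nc=16 channels and refines r_nc=48, so the concat has
# 4*16=64 channels before the 1x1 fusion conv.
#
#   imdb = IMDBlock(64, 64, mode='CL', d_rate=0.25)
#   y = imdb(torch.randn(1, 64, 32, 32))  # -> (1, 64, 32, 32)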


# --------------------------------------------
# Enhanced Spatial Attention (ESA)
# --------------------------------------------
class ESA(nn.Module):
    def __init__(self, channel=64, reduction=4, bias=True):
        super(ESA, self).__init__()
        # -->conv3x3(conv21)-----------------------------------------------------------------------------------------+
        # conv1x1(conv1)-->conv3x3-2(conv2)-->maxpool7-3-->conv3x3(conv3)(relu)-->conv3x3(conv4)(relu)-->conv3x3(conv5)-->bilinear--->conv1x1(conv6)-->sigmoid
        self.r_nc = channel // reduction
        self.conv1 = nn.Conv2d(channel, self.r_nc, kernel_size=1)
        self.conv21 = nn.Conv2d(self.r_nc, self.r_nc, kernel_size=1)
        self.conv2 = nn.Conv2d(self.r_nc, self.r_nc, kernel_size=3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(self.r_nc, self.r_nc, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(self.r_nc, self.r_nc, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(self.r_nc, self.r_nc, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(self.r_nc, channel, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = F.max_pool2d(self.conv2(x1), kernel_size=7, stride=3)  # roughly 1/6 of the input resolution
        x2 = self.relu(self.conv3(x2))
        x2 = self.relu(self.conv4(x2))
        x2 = F.interpolate(self.conv5(x2), (x.size(2), x.size(3)), mode='bilinear', align_corners=False)
        x2 = self.conv6(x2 + self.conv21(x1))
        return x.mul(self.sigmoid(x2))
        # return x.mul_(self.sigmoid(x2))
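

# A minimal usage sketch (illustrative): ESA produces a per-pixel gate in
# [0, 1] that rescales the input feature map.
#
#   esa = ESA(channel=64, reduction=4)
#   y = esa(torch.randn(1, 64, 48, 48))  # -> (1, 64, 48, 48)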


# --------------------------------------------
# CFRB: feature-distillation block with ESA
# (1x1 distillation branches + residual 3x3 convs,
# fused by a 1x1 conv and gated by spatial attention)
# --------------------------------------------
class CFRB(nn.Module):
    def __init__(self, in_channels=50, out_channels=50, kernel_size=3, stride=1, padding=1, bias=True, mode='CL', d_rate=0.5, negative_slope=0.05):
        super(CFRB, self).__init__()
        self.d_nc = int(in_channels * d_rate)
        self.r_nc = in_channels  # int(in_channels - self.d_nc)
        assert mode[0] == 'C', 'convolutional layer first'
        self.conv1_d = conv(in_channels, self.d_nc, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0])
        self.conv1_r = conv(in_channels, self.r_nc, kernel_size, stride, padding, bias=bias, mode=mode[0])
        self.conv2_d = conv(self.r_nc, self.d_nc, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0])
        self.conv2_r = conv(self.r_nc, self.r_nc, kernel_size, stride, padding, bias=bias, mode=mode[0])
        self.conv3_d = conv(self.r_nc, self.d_nc, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0])
        self.conv3_r = conv(self.r_nc, self.r_nc, kernel_size, stride, padding, bias=bias, mode=mode[0])
        self.conv4_d = conv(self.r_nc, self.d_nc, kernel_size, stride, padding, bias=bias, mode=mode[0])
        self.conv1x1 = conv(self.d_nc*4, out_channels, kernel_size=1, stride=1, padding=0, bias=bias, mode=mode[0])
        self.act = conv(mode=mode[-1], negative_slope=negative_slope)
        self.esa = ESA(in_channels, reduction=4, bias=True)

    def forward(self, x):
        d1 = self.conv1_d(x)
        x = self.act(self.conv1_r(x) + x)
        d2 = self.conv2_d(x)
        x = self.act(self.conv2_r(x) + x)
        d3 = self.conv3_d(x)
        x = self.act(self.conv3_r(x) + x)
        x = self.conv4_d(x)
        x = self.act(torch.cat([d1, d2, d3, x], dim=1))
        x = self.esa(self.conv1x1(x))
        return x


# --------------------------------------------
# Channel Attention (CA) Layer
# --------------------------------------------
class CALayer(nn.Module):
    def __init__(self, channel=64, reduction=16):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_fc = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_fc(y)
        return x * y
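

# A minimal sketch of the squeeze-and-excitation idea behind CALayer
# (illustrative): global average pooling squeezes each channel to a scalar,
# the two 1x1 convs form a bottleneck, and the sigmoid output rescales the
# channels of x.
#
#   ca = CALayer(channel=64, reduction=16)
#   y = ca(torch.randn(1, 64, 32, 32))  # -> (1, 64, 32, 32)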


# --------------------------------------------
# Residual Channel Attention Block (RCAB)
# --------------------------------------------
class RCABlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', reduction=16, negative_slope=0.2):
        super(RCABlock, self).__init__()
        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]
        self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
        self.ca = CALayer(out_channels, reduction)

    def forward(self, x):
        res = self.res(x)
        res = self.ca(res)
        return res + x


# --------------------------------------------
# Residual Channel Attention Group (RG)
# --------------------------------------------
class RCAGroup(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', reduction=16, nb=12, negative_slope=0.2):
        super(RCAGroup, self).__init__()
        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]
        RG = [RCABlock(in_channels, out_channels, kernel_size, stride, padding, bias, mode, reduction, negative_slope) for _ in range(nb)]
        RG.append(conv(out_channels, out_channels, mode='C'))
        self.rg = nn.Sequential(*RG)  # self.rg = ShortcutBlock(nn.Sequential(*RG))

    def forward(self, x):
        res = self.rg(x)
        return res + x
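

# A minimal usage sketch (illustrative): a group stacks nb attention blocks
# plus a trailing conv, all wrapped in one long skip connection.
#
#   rg = RCAGroup(64, 64, mode='CRC', reduction=16, nb=12)
#   y = rg(torch.randn(1, 64, 32, 32))  # -> (1, 64, 32, 32)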


# --------------------------------------------
# Residual Dense Block
# style: 5 convs
# --------------------------------------------
class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nc=64, gc=32, kernel_size=3, stride=1, padding=1, bias=True, mode='CR', negative_slope=0.2):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channels, i.e. how many channels each conv adds to the dense concat
        self.conv1 = conv(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv2 = conv(nc+gc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv3 = conv(nc+2*gc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv4 = conv(nc+3*gc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
        self.conv5 = conv(nc+4*gc, nc, kernel_size, stride, padding, bias, mode[:-1], negative_slope)  # last conv has no activation

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5.mul_(0.2) + x  # residual scaling by 0.2


# --------------------------------------------
# Residual in Residual Dense Block
# 3x5c
# --------------------------------------------
class RRDB(nn.Module):
    def __init__(self, nc=64, gc=32, kernel_size=3, stride=1, padding=1, bias=True, mode='CR', negative_slope=0.2):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
        self.RDB2 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope)
        self.RDB3 = ResidualDenseBlock_5C(nc, gc, kernel_size, stride, padding, bias, mode, negative_slope)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out.mul_(0.2) + x  # residual scaling by 0.2
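

# A minimal usage sketch (illustrative): three dense blocks chained inside one
# outer residual connection, as used in ESRGAN-style generators.
#
#   rrdb = RRDB(nc=64, gc=32, mode='CR')
#   y = rrdb(torch.randn(1, 64, 32, 32))  # -> (1, 64, 32, 32)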
| """ | |
| # -------------------------------------------- | |
| # Upsampler | |
| # Kai Zhang, https://github.com/cszn/KAIR | |
| # -------------------------------------------- | |
| # upsample_pixelshuffle | |
| # upsample_upconv | |
| # upsample_convtranspose | |
| # -------------------------------------------- | |
| """ | |
| # -------------------------------------------- | |
| # conv + subp (+ relu) | |
| # -------------------------------------------- | |
| def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2): | |
| assert len(mode)<4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' | |
| up1 = conv(in_channels, out_channels * (int(mode[0]) ** 2), kernel_size, stride, padding, bias, mode='C'+mode, negative_slope=negative_slope) | |
| return up1 | |


# --------------------------------------------
# nearest upsample + conv (+ relu)
# --------------------------------------------
def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    if mode[0] == '2':
        uc = 'UC'
    elif mode[0] == '3':
        uc = 'uC'
    elif mode[0] == '4':
        uc = 'vC'
    mode = mode.replace(mode[0], uc)
    up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode, negative_slope=negative_slope)
    return up1


# --------------------------------------------
# convTranspose (+ relu)
# --------------------------------------------
def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], 'T')
    up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
    return up1
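

# A minimal comparison sketch (illustrative): all three upsamplers scale the
# spatial size by the factor encoded in mode[0].
#
#   x = torch.randn(1, 64, 16, 16)
#   upsample_pixelshuffle(64, 3, mode='2R')(x).shape   # -> (1, 3, 32, 32)
#   upsample_upconv(64, 3, mode='2R')(x).shape         # -> (1, 3, 32, 32)
#   upsample_convtranspose(64, 3, mode='2R')(x).shape  # -> (1, 3, 32, 32)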


'''
# --------------------------------------------
# Downsampler
# Kai Zhang, https://github.com/cszn/KAIR
# --------------------------------------------
# downsample_strideconv
# downsample_maxpool
# downsample_avgpool
# --------------------------------------------
'''


# --------------------------------------------
# strideconv (+ relu)
# --------------------------------------------
def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], 'C')
    down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
    return down1


# --------------------------------------------
# maxpooling + conv (+ relu)
# --------------------------------------------
def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode='2R', negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    kernel_size_pool = int(mode[0])
    stride_pool = int(mode[0])
    mode = mode.replace(mode[0], 'MC')
    pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
    pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope)
    return sequential(pool, pool_tail)


# --------------------------------------------
# averagepooling + conv (+ relu)
# --------------------------------------------
def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    kernel_size_pool = int(mode[0])
    stride_pool = int(mode[0])
    mode = mode.replace(mode[0], 'AC')
    pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
    pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], negative_slope=negative_slope)
    return sequential(pool, pool_tail)
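

# A minimal comparison sketch (illustrative): each downsampler halves the
# spatial size for mode='2R'; note the differing default tail-conv padding.
#
#   x = torch.randn(1, 64, 32, 32)
#   downsample_strideconv(64, 64, mode='2R')(x).shape  # -> (1, 64, 16, 16)
#   downsample_avgpool(64, 64, mode='2R')(x).shape     # -> (1, 64, 16, 16)
#   downsample_maxpool(64, 64, mode='2R')(x).shape     # -> (1, 64, 14, 14), tail conv uses padding=0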


'''
# --------------------------------------------
# NonLocalBlock2D:
# embedded Gaussian
# x + W(softmax(theta(x) . phi(x)) . g(x))
# --------------------------------------------
'''


# --------------------------------------------
# non-local block with embedded Gaussian
# https://github.com/AlexHex7/Non-local_pytorch
# --------------------------------------------
class NonLocalBlock2D(nn.Module):
    def __init__(self, nc=64, kernel_size=1, stride=1, padding=0, bias=True, act_mode='B', downsample=False, downsample_mode='maxpool', negative_slope=0.2):
        super(NonLocalBlock2D, self).__init__()
        inter_nc = nc // 2
        self.inter_nc = inter_nc
        self.W = conv(inter_nc, nc, kernel_size, stride, padding, bias, mode='C'+act_mode)
        self.theta = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C')
        if downsample:
            if downsample_mode == 'avgpool':
                downsample_block = downsample_avgpool
            elif downsample_mode == 'maxpool':
                downsample_block = downsample_maxpool
            elif downsample_mode == 'strideconv':
                downsample_block = downsample_strideconv
            else:
                raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
            self.phi = downsample_block(nc, inter_nc, kernel_size, stride, padding, bias, mode='2')
            self.g = downsample_block(nc, inter_nc, kernel_size, stride, padding, bias, mode='2')
        else:
            self.phi = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C')
            self.g = conv(nc, inter_nc, kernel_size, stride, padding, bias, mode='C')

    def forward(self, x):
        '''
        :param x: (b, c, h, w)
        :return: (b, c, h, w)
        '''
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_nc, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_nc, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_nc, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_nc, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z
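

# A minimal usage sketch (illustrative): attention is computed over all h*w
# spatial positions, so memory grows quadratically with resolution;
# downsample=True shrinks phi/g to reduce that cost.
#
#   nlb = NonLocalBlock2D(nc=64, downsample=False)
#   y = nlb(torch.randn(1, 64, 16, 16))  # -> (1, 64, 16, 16)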