import torch.nn as nn
import torch.nn.functional as F
import torch as th
import datetime
import os
import time
import timeit
import copy
import numpy as np
from torch.nn import ModuleList
from torch.nn import Conv2d
from torch.nn import LeakyReLU
# PixelwiseNorm is used in place of BatchNorm
class PixelwiseNorm(th.nn.Module):
    def __init__(self):
        super(PixelwiseNorm, self).__init__()

    def forward(self, x, alpha=1e-8):
        """
        forward pass of the module
        :param x: input activations volume
        :param alpha: small number for numerical stability
        :return: y => pixel normalized activations
        """
        y = x.pow(2.).mean(dim=1, keepdim=True).add(alpha).sqrt()  # [N1HW]
        y = x / y  # normalize the input x volume
        return y
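
# A minimal sanity-check sketch (assumption: not part of the original file):
# pixel-wise normalization keeps the tensor shape and rescales every spatial
# location to roughly unit RMS across channels.
if __name__ == "__main__":
    _x = th.randn(2, 8, 16, 16)
    _y = PixelwiseNorm()(_x)
    print(_y.shape)                                      # torch.Size([2, 8, 16, 16])
    print(_y.pow(2.).mean(dim=1).sqrt().mean().item())   # ~1.0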

class MinibatchStdDev(th.nn.Module):
    """
    Minibatch standard deviation layer for the discriminator
    """

    def __init__(self):
        """
        derived class constructor
        """
        super().__init__()

    def forward(self, x, alpha=1e-8):
        """
        forward pass of the layer
        :param x: input activation volume
        :param alpha: small number for numerical stability
        :return: y => x appended with standard deviation constant map
        """
        batch_size, _, height, width = x.shape

        # [B x C x H x W] subtract the mean over the batch
        y = x - x.mean(dim=0, keepdim=True)

        # [C x H x W] compute the standard deviation over the batch
        y = th.sqrt(y.pow(2.).mean(dim=0, keepdim=False) + alpha)

        # [1] take the average over feature maps and pixels
        y = y.mean().view(1, 1, 1, 1)

        # [B x 1 x H x W] replicate over the batch and pixels
        y = y.repeat(batch_size, 1, height, width)

        # [B x (C + 1) x H x W] append as a new feature map
        y = th.cat([x, y], 1)

        # return the computed values
        return y
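
# A minimal sanity-check sketch (assumption: not part of the original file):
# the layer appends one constant feature map holding the batch statistic,
# so the channel dimension grows from C to C + 1.
if __name__ == "__main__":
    _x = th.randn(4, 16, 8, 8)
    _y = MinibatchStdDev()(_x)
    print(_y.shape)  # torch.Size([4, 17, 8, 8])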

# ==========================================================
# Equalized learning rate blocks:
# extending Conv2D and Deconv2D layers for equalized learning rate logic
# ==========================================================
class _equalized_conv2d(th.nn.Module):
    """ conv2d with the concept of equalized learning rate
        Args:
            :param c_in: input channels
            :param c_out: output channels
            :param k_size: kernel size (h, w), a tuple or a single integer
            :param stride: stride for conv
            :param pad: padding
            :param bias: whether to use bias or not
    """

    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
        """ constructor for the class """
        from torch.nn.modules.utils import _pair
        from numpy import sqrt, prod

        super().__init__()

        # define the weight and bias if to be used
        self.weight = th.nn.Parameter(th.nn.init.normal_(
            th.empty(c_out, c_in, *_pair(k_size))
        ))
        self.use_bias = bias
        self.stride = stride
        self.pad = pad

        if self.use_bias:
            self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))

        fan_in = prod(_pair(k_size)) * c_in  # value of fan_in
        self.scale = sqrt(2) / sqrt(fan_in)

    def forward(self, x):
        """
        forward pass of the network
        :param x: input
        :return: y => output
        """
        from torch.nn.functional import conv2d

        return conv2d(input=x,
                      weight=self.weight * self.scale,  # scale the weight at runtime
                      bias=self.bias if self.use_bias else None,
                      stride=self.stride,
                      padding=self.pad)

    def extra_repr(self):
        return ", ".join(map(str, self.weight.shape))

class _equalized_deconv2d(th.nn.Module):
    """ Transpose convolution using the equalized learning rate
        Args:
            :param c_in: input channels
            :param c_out: output channels
            :param k_size: kernel size
            :param stride: stride for convolution transpose
            :param pad: padding
            :param bias: whether to use bias or not
    """

    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, bias=True):
        """ constructor for the class """
        from torch.nn.modules.utils import _pair
        from numpy import sqrt

        super().__init__()

        # define the weight and bias if to be used
        self.weight = th.nn.Parameter(th.nn.init.normal_(
            th.empty(c_in, c_out, *_pair(k_size))
        ))
        self.use_bias = bias
        self.stride = stride
        self.pad = pad

        if self.use_bias:
            self.bias = th.nn.Parameter(th.FloatTensor(c_out).fill_(0))

        fan_in = c_in  # value of fan_in for deconv
        self.scale = sqrt(2) / sqrt(fan_in)

    def forward(self, x):
        """
        forward pass of the layer
        :param x: input
        :return: y => output
        """
        from torch.nn.functional import conv_transpose2d

        return conv_transpose2d(input=x,
                                weight=self.weight * self.scale,  # scale the weight at runtime
                                bias=self.bias if self.use_bias else None,
                                stride=self.stride,
                                padding=self.pad)

    def extra_repr(self):
        return ", ".join(map(str, self.weight.shape))

# basic convolution block of the encoding part of the generator
class conv_block(nn.Module):
    """
    Convolution Block
    with two convolution layers
    """

    def __init__(self, in_ch, out_ch, use_eql=True):
        super(conv_block, self).__init__()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_ch, out_ch, (1, 1),
                                            pad=0, bias=True)
            self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
                                            pad=1, bias=True)
            self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
                                            pad=1, bias=True)
        else:
            self.conv_1 = Conv2d(in_ch, out_ch, (3, 3),
                                 padding=1, bias=True)
            self.conv_2 = Conv2d(out_ch, out_ch, (3, 3),
                                 padding=1, bias=True)
            # conv_3 was missing in this branch although forward() uses it,
            # so it is defined here as well
            self.conv_3 = Conv2d(out_ch, out_ch, (3, 3),
                                 padding=1, bias=True)

        # pixel_wise feature normalizer:
        self.pixNorm = PixelwiseNorm()

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        from torch.nn.functional import interpolate
        # y = interpolate(x, scale_factor=2)
        y = self.conv_1(self.lrelu(self.pixNorm(x)))
        residual = y
        y = self.conv_2(self.lrelu(self.pixNorm(y)))
        y = self.conv_3(self.lrelu(self.pixNorm(y)))
        y = y + residual
        return y
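
# A minimal sanity-check sketch (assumption: not part of the original file):
# the block keeps the spatial resolution and changes the channel count, with
# a residual connection around the two 3x3 convolutions.
if __name__ == "__main__":
    _block = conv_block(16, 32)
    print(_block(th.randn(1, 16, 64, 64)).shape)  # torch.Size([1, 32, 64, 64])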

# basic up-convolution block of the decoding part of the generator
class up_conv(nn.Module):
    """
    Up Convolution Block
    """

    def __init__(self, in_ch, out_ch, use_eql=True):
        super(up_conv, self).__init__()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_ch, out_ch, (1, 1),
                                            pad=0, bias=True)
            self.conv_2 = _equalized_conv2d(out_ch, out_ch, (3, 3),
                                            pad=1, bias=True)
            self.conv_3 = _equalized_conv2d(out_ch, out_ch, (3, 3),
                                            pad=1, bias=True)
        else:
            self.conv_1 = Conv2d(in_ch, out_ch, (3, 3),
                                 padding=1, bias=True)
            self.conv_2 = Conv2d(out_ch, out_ch, (3, 3),
                                 padding=1, bias=True)
            # conv_3 was missing in this branch although forward() uses it,
            # so it is defined here as well
            self.conv_3 = Conv2d(out_ch, out_ch, (3, 3),
                                 padding=1, bias=True)

        # pixel_wise feature normalizer:
        self.pixNorm = PixelwiseNorm()

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        from torch.nn.functional import interpolate

        x = interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        y = self.conv_1(self.lrelu(self.pixNorm(x)))
        residual = y
        y = self.conv_2(self.lrelu(self.pixNorm(y)))
        y = self.conv_3(self.lrelu(self.pixNorm(y)))
        y = y + residual
        return y
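
# A minimal sanity-check sketch (assumption: not part of the original file):
# the block first upsamples bilinearly by a factor of 2, then applies the
# same residual convolution pattern as conv_block.
if __name__ == "__main__":
    _up = up_conv(32, 16)
    print(_up(th.randn(1, 32, 8, 8)).shape)  # torch.Size([1, 16, 16, 16])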

# final block of the discriminator
class DisFinalBlock(th.nn.Module):
    """ Final block for the Discriminator """

    def __init__(self, in_channels, use_eql=True):
        """
        constructor of the class
        :param in_channels: number of input channels
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import LeakyReLU
        from torch.nn import Conv2d

        super().__init__()

        # declare the required modules for forward pass
        self.batch_discriminator = MinibatchStdDev()
        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels + 1, in_channels, (3, 3),
                                            pad=1, bias=True)
            self.conv_2 = _equalized_conv2d(in_channels, in_channels, (4, 4),
                                            stride=2, pad=1, bias=True)
            # final layer emulates the fully connected layer
            self.conv_3 = _equalized_conv2d(in_channels, 1, (1, 1), bias=True)
        else:
            # modules required (kept consistent with the equalized branch,
            # which originally differed in stride and padding):
            self.conv_1 = Conv2d(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
            self.conv_2 = Conv2d(in_channels, in_channels, (4, 4), stride=2, padding=1, bias=True)
            # final conv layer emulates a fully connected layer
            self.conv_3 = Conv2d(in_channels, 1, (1, 1), bias=True)

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the FinalBlock
        :param x: input
        :return: y => output
        """
        # minibatch_std_dev layer
        y = self.batch_discriminator(x)

        # define the computations
        y = self.lrelu(self.conv_1(y))
        y = self.lrelu(self.conv_2(y))

        # fully connected emulation; this layer has linear activation
        y = self.conv_3(y)

        # return the raw discriminator scores
        return y
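
# A minimal sanity-check sketch (assumption: not part of the original file):
# MinibatchStdDev adds one channel, the strided 4x4 convolution halves the
# spatial size, and the 1x1 convolution produces a single-channel score map.
if __name__ == "__main__":
    _final = DisFinalBlock(64)
    print(_final(th.randn(2, 64, 4, 4)).shape)  # torch.Size([2, 1, 2, 2])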

# basic convolution block of the discriminator
class DisGeneralConvBlock(th.nn.Module):
    """ General block in the discriminator """

    def __init__(self, in_channels, out_channels, use_eql=True):
        """
        constructor of the class
        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param use_eql: whether to use equalized learning rate
        """
        from torch.nn import AvgPool2d, LeakyReLU
        from torch.nn import Conv2d

        super().__init__()

        if use_eql:
            self.conv_1 = _equalized_conv2d(in_channels, in_channels, (3, 3),
                                            pad=1, bias=True)
            self.conv_2 = _equalized_conv2d(in_channels, out_channels, (3, 3),
                                            pad=1, bias=True)
        else:
            # convolutional modules
            self.conv_1 = Conv2d(in_channels, in_channels, (3, 3),
                                 padding=1, bias=True)
            self.conv_2 = Conv2d(in_channels, out_channels, (3, 3),
                                 padding=1, bias=True)

        self.downSampler = AvgPool2d(2)  # downsampler

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the module
        :param x: input
        :return: y => output
        """
        # define the computations
        y = self.lrelu(self.conv_1(x))
        y = self.lrelu(self.conv_2(y))
        y = self.downSampler(y)
        return y
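
# A minimal sanity-check sketch (assumption: not part of the original file):
# two 3x3 convolutions followed by 2x2 average pooling halve the resolution.
if __name__ == "__main__":
    _dblock = DisGeneralConvBlock(32, 64)
    print(_dblock(th.randn(1, 32, 16, 16)).shape)  # torch.Size([1, 64, 8, 8])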

class from_rgb(nn.Module):
    """
    Converts an RGB image into a multi-channel feature map so that it can be
    concatenated with the feature map that has the same number of channels
    in the network.
    """

    def __init__(self, outchannels, use_eql=True):
        super(from_rgb, self).__init__()
        if use_eql:
            self.conv_1 = _equalized_conv2d(3, outchannels, (1, 1), bias=True)
        else:
            self.conv_1 = nn.Conv2d(3, outchannels, (1, 1), bias=True)

        # pixel_wise feature normalizer:
        self.pixNorm = PixelwiseNorm()

        # leaky_relu:
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        y = self.pixNorm(self.lrelu(self.conv_1(x)))
        return y

class to_rgb(nn.Module):
    """
    Converts a multi-channel feature map into a 3-channel RGB image so that it
    can be fed to the discriminator.
    """

    def __init__(self, inchannels, use_eql=True):
        super(to_rgb, self).__init__()
        if use_eql:
            self.conv_1 = _equalized_conv2d(inchannels, 3, (1, 1), bias=True)
        else:
            self.conv_1 = nn.Conv2d(inchannels, 3, (1, 1), bias=True)

    def forward(self, x):
        """
        forward pass of the block
        :param x: input
        :return: y => output
        """
        y = self.conv_1(x)
        return y
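
# A minimal sanity-check sketch (assumption: not part of the original file):
# from_rgb and to_rgb are 1x1 convolutions that move between the 3-channel
# image space and the network's feature space.
if __name__ == "__main__":
    _img = th.randn(1, 3, 64, 64)
    _feat = from_rgb(32)(_img)
    print(_feat.shape)              # torch.Size([1, 32, 64, 64])
    print(to_rgb(32)(_feat).shape)  # torch.Size([1, 3, 64, 64])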

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class CCA(nn.Module):
    """
    CCA Block
    """

    def __init__(self, F_g, F_x):
        super().__init__()
        self.mlp_x = nn.Sequential(
            Flatten(),
            nn.Linear(F_x, F_x))
        self.mlp_g = nn.Sequential(
            Flatten(),
            nn.Linear(F_g, F_x))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # channel-wise attention
        avg_pool_x = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
        channel_att_x = self.mlp_x(avg_pool_x)
        avg_pool_g = F.avg_pool2d(g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3)))
        channel_att_g = self.mlp_g(avg_pool_g)
        channel_att_sum = (channel_att_x + channel_att_g) / 2.0
        scale = th.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        x_after_channel = x * scale
        out = self.relu(x_after_channel)
        return out
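
# A minimal sanity-check sketch (assumption: not part of the original file):
# the gating signal g and the feature map x are average-pooled to vectors,
# mapped to a shared channel dimension, and the averaged attention rescales
# x channel-wise, so the output has the same shape as x.
if __name__ == "__main__":
    _cca = CCA(F_g=64, F_x=32)
    _g = th.randn(2, 64, 16, 16)
    _x = th.randn(2, 32, 16, 16)
    print(_cca(_g, _x).shape)  # torch.Size([2, 32, 16, 16])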