from contextlib import contextmanager
from math import sqrt, log

import torch
import torch.nn as nn

# import warnings
# warnings.simplefilter('ignore')


class BaseModule(nn.Module):
    def __init__(self):
        self.act_fn = None  # subclasses may replace this with their activation module
        super(BaseModule, self).__init__()

    def selu_init_params(self):
        # scaled normal init intended for SELU networks (std = 1 / sqrt(number of weights))
        for m in self.modules():
            if isinstance(m, nn.Conv2d) and m.weight.requires_grad:
                m.weight.data.normal_(0.0, 1.0 / sqrt(m.weight.numel()))
                if m.bias is not None:
                    m.bias.data.fill_(0)
            elif isinstance(m, nn.BatchNorm2d) and m.weight.requires_grad:
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear) and m.weight.requires_grad:
                m.weight.data.normal_(0, 1.0 / sqrt(m.weight.numel()))
                m.bias.data.zero_()

    def initialize_weights_xavier_uniform(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) and m.weight.requires_grad:
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d) and m.weight.requires_grad:
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def load_state_dict(self, state_dict, strict=True, self_state=False):
        own_state = self_state if self_state else self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                try:
                    own_state[name].copy_(param.data)
                except Exception as e:
                    print("Parameter {} fails to load.".format(name))
                    print("-----------------------------------------")
                    print(e)
            else:
                print("Parameter {} is not in the model.".format(name))

    @contextmanager
    def set_activation_inplace(self):
        # temporarily switch the activation to in-place mode to save memory
        if hasattr(self, "act_fn") and hasattr(self.act_fn, "inplace"):
            self.act_fn.inplace = True
            yield
            self.act_fn.inplace = False
        else:
            yield

    def total_parameters(self):
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("Total parameters : {}. Trainable parameters : {}".format(total, trainable))
        return total

    def forward(self, *x):
        raise NotImplementedError


class ResidualFixBlock(BaseModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        groups=1,
        activation=nn.SELU(),
        conv=nn.Conv2d,
    ):
        super(ResidualFixBlock, self).__init__()
        self.act_fn = activation
        # conv -> activation -> conv; the residual add in forward() requires
        # in_channels == out_channels and padding that preserves the spatial size
        self.m = nn.Sequential(
            conv(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ),
            activation,
            # conv(out_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2, dilation=1, groups=groups),
            conv(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ),
        )

    def forward(self, x):
        out = self.m(x)
        return self.act_fn(out + x)


class ConvBlock(BaseModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        groups=1,
        activation=nn.SELU(),
        conv=nn.Conv2d,
    ):
        super(ConvBlock, self).__init__()
        self.m = nn.Sequential(
            conv(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ),
            activation,
        )

    def forward(self, x):
        return self.m(x)


class UpSampleBlock(BaseModule):
    def __init__(self, channels, scale, activation, atrous_rate=1, conv=nn.Conv2d):
        assert scale in [2, 4, 8], "Currently UpSampleBlock supports 2, 4, 8 scaling"
        super(UpSampleBlock, self).__init__()
        # each stage expands to 4x channels, then PixelShuffle(2) doubles the resolution
        m = nn.Sequential(
            conv(
                channels,
                4 * channels,
                kernel_size=3,
                padding=atrous_rate,
                dilation=atrous_rate,
            ),
            activation,
            nn.PixelShuffle(2),
        )
        # note: the same stage module (and its weights) is reused log2(scale) times
        self.m = nn.Sequential(*[m for _ in range(int(log(scale, 2)))])

    def forward(self, x):
        return self.m(x)


class SpatialChannelSqueezeExcitation(BaseModule):
    # Concurrent spatial and channel squeeze-and-excitation (scSE)
    # https://arxiv.org/abs/1709.01507
    # https://arxiv.org/pdf/1803.02579v1.pdf
    def __init__(self, in_channel, reduction=16, activation=nn.ReLU()):
        super(SpatialChannelSqueezeExcitation, self).__init__()
        linear_nodes = max(in_channel // reduction, 4)  # avoid the single-node case
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.channel_excite = nn.Sequential(
            # the reduction factor of 16 follows the SE paper, where it was chosen by experiment
            nn.Linear(in_channel, linear_nodes),
            activation,
            nn.Linear(linear_nodes, in_channel),
            nn.Sigmoid(),
        )
        self.spatial_excite = nn.Sequential(
            nn.Conv2d(in_channel, 1, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, h, w = x.size()
        # channel squeeze-and-excitation
        channel = self.avg_pool(x).view(b, c)
        # channel = F.avg_pool2d(x, kernel_size=(h, w)).view(b, c)  # used for porting to other frameworks
        cSE = self.channel_excite(channel).view(b, c, 1, 1)
        x_cSE = torch.mul(x, cSE)
        # spatial squeeze-and-excitation
        sSE = self.spatial_excite(x)
        x_sSE = torch.mul(x, sSE)
        # return x_sSE
        return torch.add(x_cSE, x_sSE)


class PartialConv(nn.Module):
    # Partial-convolution based padding.
    # References:
    # Image Inpainting for Irregular Holes Using Partial Convolutions
    # http://masc.cs.gmu.edu/wiki/partialconv/show?time=2018-05-24+21%3A41%3A10
    # https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/net.py
    # https://github.com/SeitaroShinagawa/chainer-partial_convolution_image_inpainting/blob/master/common/net.py
    # https://github.com/NVIDIA/partialconv/blob/master/models/pd_resnet.py
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(PartialConv, self).__init__()
        self.feature_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        # fixed all-ones convolution that counts the valid (non-padded) pixels in each window
        self.mask_conv = nn.Conv2d(
            1, 1, kernel_size, stride, padding, dilation, groups, bias=False
        )
        self.window_size = self.mask_conv.kernel_size[0] * self.mask_conv.kernel_size[1]
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, x):
        output = self.feature_conv(x)
        if self.feature_conv.bias is not None:
            output_bias = self.feature_conv.bias.view(1, -1, 1, 1).expand_as(output)
        else:
            output_bias = torch.zeros_like(output)
        with torch.no_grad():
            # re-weight border outputs by window_size / (number of valid pixels in the window)
            ones = torch.ones(1, 1, x.size(2), x.size(3), device=x.device)
            output_mask = self.mask_conv(ones)
            output_mask = self.window_size / output_mask
        output = (output - output_bias) * output_mask + output_bias
        return output
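

if __name__ == "__main__":
    # Illustrative smoke test added for documentation; it is not part of the original
    # pipeline. It is a minimal sketch that assumes the blocks are chained as a small x2
    # super-resolution stack (the 32-channel width and scale=2 are arbitrary choices here)
    # and only checks that the shapes line up as the layer definitions above imply.
    x = torch.randn(1, 3, 32, 32)

    head = ConvBlock(3, 32)                                 # (1, 3, 32, 32) -> (1, 32, 32, 32)
    body = ResidualFixBlock(32, 32)                         # residual add keeps the shape
    scse = SpatialChannelSqueezeExcitation(32)              # channel/spatial recalibration, keeps the shape
    up = UpSampleBlock(32, scale=2, activation=nn.SELU())   # one PixelShuffle stage -> (1, 32, 64, 64)
    tail = PartialConv(32, 3, kernel_size=3, padding=1)     # back to 3 channels

    with torch.no_grad():
        out = tail(up(scse(body(head(x)))))
    print(out.shape)  # torch.Size([1, 3, 64, 64])

    # BaseModule helpers: parameter count and the in-place activation context manager
    body.total_parameters()
    with body.set_activation_inplace(), torch.no_grad():
        _ = body(head(x))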