|
import torch |
|
import torch.nn as nn |
|
import torch.nn.init as init |
|
import torch.nn.functional as F |
|
|
|
|
|
def initialize_weights(net_l, scale=1):
    """Initialize Conv2d, Linear and BatchNorm2d layers of one or more networks.

    Conv2d / Linear weights are drawn with Kaiming-normal initialization
    (``a=0``, ``mode='fan_in'``) and then multiplied by ``scale``; their
    biases, when present, are set to zero.  BatchNorm2d affine parameters,
    when present, are reset to weight 1 and bias 0.  Modules of any other
    type are left untouched.

    Args:
        net_l: A single module, or a list/tuple of modules.  Every
            sub-module yielded by ``.modules()`` is visited.
        scale: Factor applied to the freshly initialized Conv2d / Linear
            weights.
    """
    if not isinstance(net_l, (list, tuple)):
        net_l = [net_l]

    # Parameters are rewritten in place; keep autograd from recording it.
    with torch.no_grad():
        for net in net_l:
            for m in net.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                    m.weight.mul_(scale)
                    if m.bias is not None:
                        m.bias.zero_()
                elif isinstance(m, nn.BatchNorm2d):
                    # With affine=False both weight and bias are None.
                    if m.weight is not None:
                        init.constant_(m.weight, 1)
                    if m.bias is not None:
                        init.constant_(m.bias, 0.0)
|
|
|
|
|
def make_layer(block, n_layers):
    """Stack ``n_layers`` freshly built blocks into an ``nn.Sequential``.

    Args:
        block: Zero-argument callable; it is called once per layer and each
            result becomes one entry of the sequence.
        n_layers: Number of times ``block`` is called.

    Returns:
        An ``nn.Sequential`` holding the created modules in creation order.
    """
    stacked = [block() for _ in range(n_layers)]
    return nn.Sequential(*stacked)
|
|
|
|
|
class ResidualBlock_noBN(nn.Module):
    '''Residual block without batch normalization.

    ---Conv-ReLU-Conv-+-
     |________________|

    Both convolutions use a 3x3 kernel with stride 1 and padding 1 and map
    ``nf`` channels to ``nf`` channels; the block input is added to the
    output of the second convolution.
    '''

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True)
        self.conv1 = nn.Conv2d(nf, nf, **conv_kwargs)
        self.conv2 = nn.Conv2d(nf, nf, **conv_kwargs)

        # Re-initialize both convolutions with their weights scaled by 0.1.
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        residual = self.conv2(F.relu(self.conv1(x), inplace=True))
        return x + residual
|
|
|
class ResidualBlock(nn.Module):
    '''Residual block with batch normalization after the first convolution.

    ---Conv-BN-ReLU-Conv-+-
     |___________________|

    Both convolutions use a 3x3 kernel with stride 1 and padding 1 and map
    ``nf`` channels to ``nf`` channels; the block input is added to the
    output of the second convolution.
    '''

    def __init__(self, nf=64):
        super(ResidualBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True)
        self.conv1 = nn.Conv2d(nf, nf, **conv_kwargs)
        self.bn = nn.BatchNorm2d(nf)
        self.conv2 = nn.Conv2d(nf, nf, **conv_kwargs)

        # Only the convolutions are re-initialized (weights scaled by 0.1);
        # the BatchNorm layer keeps its constructor defaults.
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.bn(feat)
        feat = F.relu(feat, inplace=True)
        return x + self.conv2(feat)
|
|
|
class LayerNormFunction(torch.autograd.Function):
    '''Custom autograd function normalizing a 4-D tensor along dim 1.

    For every position of dims 0, 2 and 3 the values along dim 1 are shifted
    to zero mean and divided by ``sqrt(var + eps)`` (``var`` being the mean
    squared deviation along dim 1), then scaled by ``weight`` and offset by
    ``bias``, both broadcast along dim 1.
    '''

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        '''Compute ``weight * (x - mean) / sqrt(var + eps) + bias``.

        Args:
            ctx: Autograd context; stores ``eps`` and the tensors needed by
                ``backward``.
            x: 4-D input tensor; statistics are taken along dim 1.
            weight: 1-D tensor with ``x.size(1)`` elements.
            bias: 1-D tensor with ``x.size(1)`` elements.
            eps: Constant added to the variance before the square root.

        Returns:
            Tensor with the same shape as ``x``.
        '''
        ctx.eps = eps
        # Unpacking four sizes also rejects inputs that are not 4-D.
        _, C, _, _ = x.size()
        mu = x.mean(1, keepdim=True)
        centered = x - mu
        var = centered.pow(2).mean(1, keepdim=True)
        y = centered / (var + eps).sqrt()
        # Save the normalized tensor (before weight/bias) for backward.
        ctx.save_for_backward(y, var, weight)
        return weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)

    @staticmethod
    def backward(ctx, grad_output):
        '''Return gradients w.r.t. ``x``, ``weight``, ``bias`` (``eps``: None).'''
        eps = ctx.eps

        _, C, _, _ = grad_output.size()
        # ``saved_tensors`` is the supported accessor; ``saved_variables``
        # is deprecated.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)

        # weight / bias are shared across dims 0, 2, 3: sum those out.
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
        return gx, grad_weight, grad_bias, None
|
|
|
class LayerNorm2d(nn.Module):
    '''Module wrapper that applies ``LayerNormFunction`` with learnable
    ``weight`` and ``bias`` parameters.

    Args:
        channels: Length of the ``weight`` (initialized to ones) and
            ``bias`` (initialized to zeros) parameter vectors.
        eps: Value passed to ``LayerNormFunction`` on every forward call.
    '''

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        # Assigning an nn.Parameter as an attribute registers it on the module.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        weight, bias, eps = self.weight, self.bias, self.eps
        return LayerNormFunction.apply(x, weight, bias, eps)
|
|