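"""Residual building blocks with a multi-scale "pyramid" branch (PRM-style):
the main path runs a pre-activated bottleneck branch and a fractionally
pooled pyramid branch in parallel, sums them, and adds a skip connection.
"""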
import math

import torch.nn as nn

from opt import opt
from .util_models import ConcatTable, CaddTable, Identity

class Residual(nn.Module):
    def __init__(self, numIn, numOut, inputResH, inputResW, stride=1,
                 net_type='preact', useConv=False, baseWidth=9, cardinality=4):
        super(Residual, self).__init__()
        # Two parallel paths: the pyramid conv block and a (possibly identity) skip.
        self.con = ConcatTable([convBlock(numIn, numOut, inputResH, inputResW,
                                          net_type, baseWidth, cardinality, stride),
                                skipLayer(numIn, numOut, stride, useConv)])
        self.cadd = CaddTable(True)

    def forward(self, x):
        out = self.con(x)
        out = self.cadd(out)
        return out
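
# convBlock builds the residual (non-skip) path as two parallel branches: a
# plain pre-activated bottleneck and a narrower pyramid branch. The branches
# are summed element-wise, then projected back up to numOut channels.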
def convBlock(numIn, numOut, inputResH, inputResW, net_type, baseWidth, cardinality, stride):
    numIn = int(numIn)
    numOut = int(numOut)

    addTable = ConcatTable()

    # Branch 1: pre-activated bottleneck (1x1 reduce, then 3x3 conv).
    s_list = []
    if net_type != 'no_preact':
        s_list.append(nn.BatchNorm2d(numIn))
        s_list.append(nn.ReLU(True))
    conv1 = nn.Conv2d(numIn, numOut // 2, kernel_size=1)
    if opt.init:
        nn.init.xavier_normal_(conv1.weight, gain=math.sqrt(1 / 2))
    s_list.append(conv1)
    s_list.append(nn.BatchNorm2d(numOut // 2))
    s_list.append(nn.ReLU(True))
    conv2 = nn.Conv2d(numOut // 2, numOut // 2,
                      kernel_size=3, stride=stride, padding=1)
    if opt.init:
        nn.init.xavier_normal_(conv2.weight)
    s_list.append(conv2)
    s = nn.Sequential(*s_list)
    addTable.add(s)
    # Branch 2: pyramid branch, reduced to D = numOut // baseWidth channels
    # and split over C = cardinality scales.
    D = numOut // baseWidth
    C = cardinality
    s_list = []
    if net_type != 'no_preact':
        s_list.append(nn.BatchNorm2d(numIn))
        s_list.append(nn.ReLU(True))
    conv1 = nn.Conv2d(numIn, D, kernel_size=1, stride=stride)
    if opt.init:
        nn.init.xavier_normal_(conv1.weight, gain=math.sqrt(1 / C))
    s_list.append(conv1)
    s_list.append(nn.BatchNorm2d(D))
    s_list.append(nn.ReLU(True))
    s_list.append(pyramid(D, C, inputResH, inputResW))
    s_list.append(nn.BatchNorm2d(D))
    s_list.append(nn.ReLU(True))
    a = nn.Conv2d(D, numOut // 2, kernel_size=1)
    a.nBranchIn = C
    if opt.init:
        nn.init.xavier_normal_(a.weight, gain=math.sqrt(1 / C))
    s_list.append(a)
    s = nn.Sequential(*s_list)
    addTable.add(s)
    # Sum the two branches element-wise, then expand back to numOut channels.
    elewiseAdd = nn.Sequential(
        addTable,
        CaddTable(False)
    )
    conv2 = nn.Conv2d(numOut // 2, numOut, kernel_size=1)
    if opt.init:
        nn.init.xavier_normal_(conv2.weight, gain=math.sqrt(1 / 2))
    model = nn.Sequential(
        elewiseAdd,
        nn.BatchNorm2d(numOut // 2),
        nn.ReLU(True),
        conv2
    )
    return model
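
# pyramid builds C parallel scale branches: branch i shrinks the feature map
# by a ratio of 2^(-(i + 1) / C) with fractional max pooling, applies a 3x3
# conv, and bilinearly upsamples back to (inputResH, inputResW); the branch
# outputs are summed element-wise.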
def pyramid(D, C, inputResH, inputResW):
    pyraTable = ConcatTable()
    sc = math.pow(2, 1 / C)
    for i in range(C):
        scaled = 1 / math.pow(sc, i + 1)
        conv1 = nn.Conv2d(D, D, kernel_size=3, stride=1, padding=1)
        if opt.init:
            nn.init.xavier_normal_(conv1.weight)
        s = nn.Sequential(
            nn.FractionalMaxPool2d(2, output_ratio=(scaled, scaled)),
            conv1,
            nn.UpsamplingBilinear2d(size=(int(inputResH), int(inputResW))))
        pyraTable.add(s)
    pyra = nn.Sequential(
        pyraTable,
        CaddTable(False)
    )
    return pyra
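
# skipLayer is the shortcut path: a pure identity when the shape is already
# preserved, otherwise BN-ReLU followed by a strided 1x1 conv so the output
# matches numOut channels and the main branch's stride.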
class skipLayer(nn.Module):
    def __init__(self, numIn, numOut, stride, useConv):
        super(skipLayer, self).__init__()
        self.identity = False
        if numIn == numOut and stride == 1 and not useConv:
            self.identity = True
        else:
            conv1 = nn.Conv2d(numIn, numOut, kernel_size=1, stride=stride)
            if opt.init:
                nn.init.xavier_normal_(conv1.weight, gain=math.sqrt(1 / 2))
            self.m = nn.Sequential(
                nn.BatchNorm2d(numIn),
                nn.ReLU(True),
                conv1
            )

    def forward(self, x):
        if self.identity:
            return x
        else:
            return self.m(x)
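
# Minimal smoke test (a sketch, not part of the original module). Assumes the
# project's `opt` has already been parsed so `opt.init` exists, and that
# ConcatTable/CaddTable in util_models follow the usual table semantics
# (ConcatTable fans x out to a list of branch outputs, CaddTable sums them).
if __name__ == '__main__':
    import torch

    # 64 -> 128 channels at a 64x64 feature resolution, stride 1: the skip
    # path uses a 1x1 conv because the channel count changes.
    block = Residual(64, 128, inputResH=64, inputResW=64)
    x = torch.randn(2, 64, 64, 64)
    y = block(x)
    print(y.shape)  # expected: torch.Size([2, 128, 64, 64])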