Sapphire-356's picture
add: Video2MC
95f8bbc
import torch.nn as nn
from .util_models import ConcatTable, CaddTable, Identity
import math
from opt import opt
class Residual(nn.Module):
def __init__(self, numIn, numOut, inputResH, inputResW, stride=1,
net_type='preact', useConv=False, baseWidth=9, cardinality=4):
super(Residual, self).__init__()
self.con = ConcatTable([convBlock(numIn, numOut, inputResH,
inputResW, net_type, baseWidth, cardinality, stride),
skipLayer(numIn, numOut, stride, useConv)])
self.cadd = CaddTable(True)
def forward(self, x):
out = self.con(x)
out = self.cadd(out)
return out
def convBlock(numIn, numOut, inputResH, inputResW, net_type, baseWidth, cardinality, stride):
numIn = int(numIn)
numOut = int(numOut)
addTable = ConcatTable()
s_list = []
if net_type != 'no_preact':
s_list.append(nn.BatchNorm2d(numIn))
s_list.append(nn.ReLU(True))
conv1 = nn.Conv2d(numIn, numOut // 2, kernel_size=1)
if opt.init:
nn.init.xavier_normal(conv1.weight, gain=math.sqrt(1 / 2))
s_list.append(conv1)
s_list.append(nn.BatchNorm2d(numOut // 2))
s_list.append(nn.ReLU(True))
conv2 = nn.Conv2d(numOut // 2, numOut // 2,
kernel_size=3, stride=stride, padding=1)
if opt.init:
nn.init.xavier_normal(conv2.weight)
s_list.append(conv2)
s = nn.Sequential(*s_list)
addTable.add(s)
D = math.floor(numOut // baseWidth)
C = cardinality
s_list = []
if net_type != 'no_preact':
s_list.append(nn.BatchNorm2d(numIn))
s_list.append(nn.ReLU(True))
conv1 = nn.Conv2d(numIn, D, kernel_size=1, stride=stride)
if opt.init:
nn.init.xavier_normal(conv1.weight, gain=math.sqrt(1 / C))
s_list.append(conv1)
s_list.append(nn.BatchNorm2d(D))
s_list.append(nn.ReLU(True))
s_list.append(pyramid(D, C, inputResH, inputResW))
s_list.append(nn.BatchNorm2d(D))
s_list.append(nn.ReLU(True))
a = nn.Conv2d(D, numOut // 2, kernel_size=1)
a.nBranchIn = C
if opt.init:
nn.init.xavier_normal(a.weight, gain=math.sqrt(1 / C))
s_list.append(a)
s = nn.Sequential(*s_list)
addTable.add(s)
elewiswAdd = nn.Sequential(
addTable,
CaddTable(False)
)
conv2 = nn.Conv2d(numOut // 2, numOut, kernel_size=1)
if opt.init:
nn.init.xavier_normal(conv2.weight, gain=math.sqrt(1 / 2))
model = nn.Sequential(
elewiswAdd,
nn.BatchNorm2d(numOut // 2),
nn.ReLU(True),
conv2
)
return model
def pyramid(D, C, inputResH, inputResW):
pyraTable = ConcatTable()
sc = math.pow(2, 1 / C)
for i in range(C):
scaled = 1 / math.pow(sc, i + 1)
conv1 = nn.Conv2d(D, D, kernel_size=3, stride=1, padding=1)
if opt.init:
nn.init.xavier_normal(conv1.weight)
s = nn.Sequential(
nn.FractionalMaxPool2d(2, output_ratio=(scaled, scaled)),
conv1,
nn.UpsamplingBilinear2d(size=(int(inputResH), int(inputResW))))
pyraTable.add(s)
pyra = nn.Sequential(
pyraTable,
CaddTable(False)
)
return pyra
class skipLayer(nn.Module):
def __init__(self, numIn, numOut, stride, useConv):
super(skipLayer, self).__init__()
self.identity = False
if numIn == numOut and stride == 1 and not useConv:
self.identity = True
else:
conv1 = nn.Conv2d(numIn, numOut, kernel_size=1, stride=stride)
if opt.init:
nn.init.xavier_normal(conv1.weight, gain=math.sqrt(1 / 2))
self.m = nn.Sequential(
nn.BatchNorm2d(numIn),
nn.ReLU(True),
conv1
)
def forward(self, x):
if self.identity:
return x
else:
return self.m(x)