Spaces:
Running
Running
import torch.nn as nn | |
from .layers.PRM import Residual as ResidualPyramid | |
from .layers.Residual import Residual as Residual | |
from torch.autograd import Variable | |
import torch | |
from opt import opt | |
import math | |
class Hourglass(nn.Module):
    """Recursive hourglass block built from (pyramid) residual units.

    The output is the element-wise sum of two branches computed from the
    same input:

    * upper branch: ``nModules`` residual units at the input resolution;
    * lower branch: ``MaxPool2d(2)`` -> ``nModules`` residual units ->
      a nested ``Hourglass(n - 1)`` (or ``nModules`` plain units when
      ``n == 1``) -> ``nModules`` units built with ``useConv=True`` ->
      nearest-neighbour 2x upsampling.

    ``ResidualPyramid`` units are used for the upper branch when ``n >= 2``
    and for the lower branch when ``n >= 3``; plain ``Residual`` units are
    used otherwise.

    Args:
        n: recursion depth, i.e. number of 2x downsampling levels.
        nFeats: channel count; every unit is built as nFeats -> nFeats.
        nModules: number of residual units per stage.
        inputResH, inputResW: spatial size at this level. The lower branch
            receives them halved with true division, so they may be floats.
        net_type: forwarded unchanged to every residual unit.
        B, C: forwarded to every residual unit as ``baseWidth`` /
            ``cardinality``.

    NOTE(review): because the branches are summed after a floor-mode
    ``MaxPool2d(2)`` and a 2x upsample, H and W presumably must be divisible
    by ``2 ** n`` (assuming the residual units, built with stride=1, keep the
    spatial size) — TODO confirm against the residual implementations.
    """

    def __init__(self, n, nFeats, nModules, inputResH, inputResW, net_type, B, C):
        super(Hourglass, self).__init__()
        self.ResidualUp = ResidualPyramid if n >= 2 else Residual
        self.ResidualDown = ResidualPyramid if n >= 3 else Residual
        self.depth = n
        self.nModules = nModules
        self.nFeats = nFeats
        self.net_type = net_type
        self.B = B
        self.C = C
        self.inputResH = inputResH
        self.inputResW = inputResW

        up1 = self._make_residual(self.ResidualUp, False, inputResH, inputResW)
        low1 = nn.Sequential(
            nn.MaxPool2d(2),
            self._make_residual(self.ResidualDown, False, inputResH / 2, inputResW / 2)
        )
        if n > 1:
            # Recurse one level deeper at half the resolution.
            low2 = Hourglass(n - 1, nFeats, nModules, inputResH / 2, inputResW / 2, net_type, B, C)
        else:
            low2 = self._make_residual(self.ResidualDown, False, inputResH / 2, inputResW / 2)
        low3 = self._make_residual(self.ResidualDown, True, inputResH / 2, inputResW / 2)
        # nn.UpsamplingNearest2d is deprecated; nn.Upsample with mode='nearest'
        # is its documented equivalent (no parameters, identical output).
        up2 = nn.Upsample(scale_factor=2, mode='nearest')

        self.upperBranch = up1
        self.lowerBranch = nn.Sequential(
            low1,
            low2,
            low3,
            up2
        )

    def _make_residual(self, resBlock, useConv, inputResH, inputResW):
        """Stack ``self.nModules`` instances of ``resBlock`` (nFeats -> nFeats, stride 1).

        Args:
            resBlock: residual unit class (``Residual`` or ``ResidualPyramid``).
            useConv: forwarded as the unit's ``useConv`` flag.
            inputResH, inputResW: spatial size forwarded to each unit.

        Returns:
            ``nn.Sequential`` of freshly constructed units (no weight sharing).
        """
        return nn.Sequential(*[
            resBlock(self.nFeats, self.nFeats, inputResH, inputResW,
                     stride=1, net_type=self.net_type, useConv=useConv,
                     baseWidth=self.B, cardinality=self.C)
            for _ in range(self.nModules)
        ])

    def forward(self, x: torch.Tensor):
        """Return ``upperBranch(x) + lowerBranch(x)``."""
        return self.upperBranch(x) + self.lowerBranch(x)
class PyraNet(nn.Module):
    """Stacked pyramid-residual hourglass network; ``forward`` returns one map per stack.

    All hyper-parameters are read from the global ``opt``:
    ``baseWidth`` / ``cardinality`` (forwarded to the pyramid residual units),
    ``inputResH`` / ``inputResW``, ``nStack``, ``nFeats``, ``nResidual``
    (residual units per hourglass stage), ``nClasses`` (output channels) and
    ``init`` (enables Xavier-normal initialisation of the convolutions
    created directly in this class).

    Layout:
      * stem (``self.preact``): 7x7 stride-2 conv (3 -> 64) + BN + ReLU,
        a pyramid residual unit (64 -> 128) + ``MaxPool2d(2)``, then two more
        pyramid residual units (128 -> 128 -> nFeats);
      * ``nStack`` stacks, each an ``Hourglass(4)`` followed by
        BN-ReLU-1x1 conv-BN-ReLU (``stack_lin``) and a 1x1 conv to
        ``nClasses`` channels (``stack_out``);
      * between consecutive stacks, the stack's features and its output map
        are projected back to ``nFeats`` channels by 1x1 convs
        (``stack_lin_`` / ``stack_out_``) and added to the running input.
    """

    def __init__(self):
        super(PyraNet, self).__init__()
        B, C = opt.baseWidth, opt.cardinality
        # The stem downsamples by 4 overall (stride-2 conv, then MaxPool2d(2)),
        # so the hourglasses are built for a quarter of the input resolution.
        self.inputResH = opt.inputResH / 4
        self.inputResW = opt.inputResW / 4
        self.nStack = opt.nStack

        conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        if opt.init:
            # xavier_normal_ replaces the deprecated nn.init.xavier_normal
            # (the old name only warns and forwards to this function).
            nn.init.xavier_normal_(conv1.weight, gain=math.sqrt(1 / 3))
        cnv1 = nn.Sequential(
            conv1,
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        r1 = nn.Sequential(
            ResidualPyramid(64, 128, opt.inputResH / 2, opt.inputResW / 2,
                            stride=1, net_type='no_preact', useConv=False, baseWidth=B, cardinality=C),
            nn.MaxPool2d(2)
        )
        r4 = ResidualPyramid(128, 128, self.inputResH, self.inputResW,
                             stride=1, net_type='preact', useConv=False, baseWidth=B, cardinality=C)
        r5 = ResidualPyramid(128, opt.nFeats, self.inputResH, self.inputResW,
                             stride=1, net_type='preact', useConv=False, baseWidth=B, cardinality=C)
        self.preact = nn.Sequential(
            cnv1,
            r1,
            r4,
            r5
        )

        self.stack_lin = nn.ModuleList()
        self.stack_out = nn.ModuleList()
        # The two lists below hold nStack - 1 entries: the last stack has no successor.
        self.stack_lin_ = nn.ModuleList()
        self.stack_out_ = nn.ModuleList()
        for i in range(self.nStack):
            hg = Hourglass(4, opt.nFeats, opt.nResidual, self.inputResH, self.inputResW, 'preact', B, C)
            conv1 = nn.Conv2d(opt.nFeats, opt.nFeats, kernel_size=1, stride=1, padding=0)
            if opt.init:
                nn.init.xavier_normal_(conv1.weight, gain=math.sqrt(1 / 2))
            lin = nn.Sequential(
                hg,
                nn.BatchNorm2d(opt.nFeats),
                nn.ReLU(True),
                conv1,
                nn.BatchNorm2d(opt.nFeats),
                nn.ReLU(True)
            )
            tmpOut = nn.Conv2d(opt.nFeats, opt.nClasses, kernel_size=1, stride=1, padding=0)
            if opt.init:
                nn.init.xavier_normal_(tmpOut.weight)
            self.stack_lin.append(lin)
            self.stack_out.append(tmpOut)
            if i < self.nStack - 1:
                lin_ = nn.Conv2d(opt.nFeats, opt.nFeats, kernel_size=1, stride=1, padding=0)
                tmpOut_ = nn.Conv2d(opt.nClasses, opt.nFeats, kernel_size=1, stride=1, padding=0)
                if opt.init:
                    nn.init.xavier_normal_(lin_.weight)
                    nn.init.xavier_normal_(tmpOut_.weight)
                self.stack_lin_.append(lin_)
                self.stack_out_.append(tmpOut_)

    def forward(self, x: torch.Tensor):
        """Run every stack and return the list of per-stack output maps.

        Args:
            x: input batch; the stem's first conv takes 3 channels (presumably
               an (N, 3, inputResH, inputResW) image tensor — TODO confirm).

        Returns:
            List of ``nStack`` tensors with ``opt.nClasses`` channels, in
            stack order.
        """
        out = []
        inter = self.preact(x)
        for i in range(self.nStack):
            lin = self.stack_lin[i](inter)
            tmpOut = self.stack_out[i](lin)
            out.append(tmpOut)
            if i < self.nStack - 1:
                # Re-inject this stack's features and prediction into the
                # next stack's input.
                lin_ = self.stack_lin_[i](lin)
                tmpOut_ = self.stack_out_[i](tmpOut)
                inter = inter + lin_ + tmpOut_
        return out


class PyraNet_Inference(PyraNet):
    """PyraNet variant whose ``forward`` returns only the last stack's output.

    The constructor is inherited unchanged from ``PyraNet``, so both classes
    have the same module tree and state_dict keys. Weights trained with
    ``PyraNet`` therefore load directly into this class. Previously the whole
    ``__init__`` was duplicated here.
    """

    def forward(self, x: torch.Tensor):
        """Return the output map of the final stack.

        Every stack must still be evaluated, because each later stack's input
        depends on the previous stack's features and output map. Requires
        ``nStack >= 1``.
        """
        return super(PyraNet_Inference, self).forward(x)[-1]
def createModel(**kw):
    """Construct and return a new ``PyraNet`` (per-stack outputs).

    Any keyword arguments are accepted but ignored; the network reads its
    configuration from the global ``opt``.
    """
    return PyraNet()
def createModel_Inference(**kw):
    """Construct and return a new ``PyraNet_Inference`` (last-stack output only).

    Any keyword arguments are accepted but ignored; the network reads its
    configuration from the global ``opt``.
    """
    return PyraNet_Inference()