# Residue from the web file viewer this listing was copied from
# (the viewer's line-number gutter has been dropped):
# Spaces: Running | File size: 4,899 Bytes | 95f8bbc
import torch.nn as nn
from .layers.PRM import Residual as ResidualPyramid
from .layers.Residual import Residual as Residual
from torch.autograd import Variable
from opt import opt
from collections import defaultdict
class Hourglass(nn.Module):
    """Recursive hourglass stage.

    The input is sent through two branches whose outputs are added:

    * ``upperBranch`` - ``nModules`` residual blocks applied to the input as is.
    * ``lowerBranch`` - ``MaxPool2d(2)`` + residual blocks, then either a nested
      ``Hourglass`` of depth ``n - 1`` (when ``n > 1``) or more residual
      blocks, then residual blocks built with ``useConv=True``, then
      nearest-neighbour upsampling by a factor of 2.

    Args:
        n: Recursion depth of this stage.
        nFeats: Channel count given to every residual block (in and out).
        nModules: Number of residual blocks chained in each group.
        inputResH: Height passed to the residual blocks of the upper branch.
        inputResW: Width passed to the residual blocks of the upper branch.
        net_type: Forwarded to every residual block as ``net_type``.
        B: Forwarded to every residual block as ``baseWidth``.
        C: Forwarded to every residual block as ``cardinality``.
    """

    def __init__(self, n, nFeats, nModules, inputResH, inputResW, net_type, B, C):
        super(Hourglass, self).__init__()

        # Block type per branch: the pyramid variant is used for the upper
        # branch from depth 2 upwards and for the lower branch from depth 3.
        if n >= 2:
            self.ResidualUp = ResidualPyramid
        else:
            self.ResidualUp = Residual
        if n >= 3:
            self.ResidualDown = ResidualPyramid
        else:
            self.ResidualDown = Residual

        # Settings read by _make_residual; they must exist before any branch
        # is built below.
        self.depth = n
        self.nModules = nModules
        self.nFeats = nFeats
        self.net_type = net_type
        self.B = B
        self.C = C
        self.inputResH = inputResH
        self.inputResW = inputResW

        # Resolution handed to everything on the pooled (lower) branch.
        # NOTE(review): `/` is true division on Python 3, so these are floats
        # even for even inputs - confirm the residual blocks accept that.
        halfH = inputResH / 2
        halfW = inputResW / 2

        # Sub-modules are created and registered in a fixed order
        # (up1, low1, low2, low3, up2, upperBranch, lowerBranch).
        self.up1 = self._make_residual(self.ResidualUp, False, inputResH, inputResW)

        pool = nn.MaxPool2d(2)
        pooled_blocks = self._make_residual(self.ResidualDown, False, halfH, halfW)
        self.low1 = nn.Sequential(pool, pooled_blocks)

        if n > 1:
            # Recurse: the middle of the lower branch is a shallower hourglass.
            self.low2 = Hourglass(n - 1, nFeats, nModules, halfH, halfW, net_type, B, C)
        else:
            self.low2 = self._make_residual(self.ResidualDown, False, halfH, halfW)

        self.low3 = self._make_residual(self.ResidualDown, True, halfH, halfW)
        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)

        # The branches reuse the modules registered above under extra names.
        self.upperBranch = self.up1
        self.lowerBranch = nn.Sequential(self.low1, self.low2, self.low3, self.up2)

    def _make_residual(self, resBlock, useConv, inputResH, inputResW):
        """Chain ``self.nModules`` freshly built ``resBlock`` instances.

        Args:
            resBlock: Callable that builds one residual block.
            useConv: Forwarded to every block as ``useConv``.
            inputResH: Height forwarded to every block.
            inputResW: Width forwarded to every block.

        Returns:
            An ``nn.Sequential`` holding the blocks in creation order.
        """
        blocks = [
            resBlock(
                self.nFeats,
                self.nFeats,
                inputResH,
                inputResW,
                stride=1,
                net_type=self.net_type,
                useConv=useConv,
                baseWidth=self.B,
                cardinality=self.C,
            )
            for _ in range(self.nModules)
        ]
        return nn.Sequential(*blocks)

    def forward(self, x: Variable):
        """Return ``upperBranch(x) + lowerBranch(x)``."""
        return self.upperBranch(x) + self.lowerBranch(x)
class PyraNet(nn.Module):
    """Stacked hourglass network built from pyramid residual blocks.

    All hyper-parameters are read from the module-level ``opt`` object:
    ``baseWidth``, ``cardinality``, ``inputResH``, ``inputResW``, ``nStack``,
    ``nFeats``, ``nResidual`` and ``nClasses``.

    Layout:
        * ``preact`` - 7x7 stride-2 convolution on a 3-channel input, followed
          by a pyramid residual block with 2x max-pooling and two more pyramid
          residual blocks ending at ``opt.nFeats`` channels.
        * ``stack_layers`` - ``opt.nStack`` stacks. Each stack has a depth-4
          ``Hourglass`` followed by BN/ReLU/1x1 conv/BN/ReLU (``'lin'``) and a
          1x1 convolution to ``opt.nClasses`` channels (``'out'``). Every stack
          except the last also owns two 1x1 convolutions (``'lin_'`` and
          ``'out_'``) that map its features and its prediction back to
          ``opt.nFeats`` channels for the next stack.

    The stack layers are kept in an ``nn.ModuleDict`` of ``nn.ModuleList`` so
    that they are registered as sub-modules. With a plain dict of lists their
    parameters would be missing from ``parameters()`` and ``state_dict()`` and
    would not follow ``.to()`` / ``.cuda()`` / ``.train()`` / ``.eval()``.
    Indexing is unchanged: ``self.stack_layers['lin'][i]``.
    """

    def __init__(self):
        super(PyraNet, self).__init__()
        B, C = opt.baseWidth, opt.cardinality

        # Resolution passed to the blocks that run after the stride-2
        # convolution and the 2x max-pool.
        # NOTE(review): `/` is true division on Python 3, so these are floats -
        # confirm the residual blocks accept that.
        self.inputResH = opt.inputResH / 4
        self.inputResW = opt.inputResW / 4
        self.nStack = opt.nStack

        self.cnv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        self.r1 = nn.Sequential(
            ResidualPyramid(64, 128, opt.inputResH / 2, opt.inputResW / 2,
                            stride=1, net_type='no_preact', useConv=False,
                            baseWidth=B, cardinality=C),
            nn.MaxPool2d(2)
        )
        self.r4 = ResidualPyramid(128, 128, self.inputResH, self.inputResW,
                                  stride=1, net_type='preact', useConv=False,
                                  baseWidth=B, cardinality=C)
        self.r5 = ResidualPyramid(128, opt.nFeats, self.inputResH, self.inputResW,
                                  stride=1, net_type='preact', useConv=False,
                                  baseWidth=B, cardinality=C)
        self.preact = nn.Sequential(
            self.cnv1,
            self.r1,
            self.r4,
            self.r5
        )

        # Registered containers (instead of defaultdict(list)) so the stack
        # layers are real sub-modules of this network.
        # NOTE: state_dict() therefore gains `stack_layers.*` entries; a
        # checkpoint saved while these layers were unregistered does not
        # contain them and needs load_state_dict(..., strict=False).
        self.stack_layers = nn.ModuleDict({
            'lin': nn.ModuleList(),
            'out': nn.ModuleList(),
            'lin_': nn.ModuleList(),
            'out_': nn.ModuleList(),
        })
        for i in range(self.nStack):
            hg = Hourglass(4, opt.nFeats, opt.nResidual,
                           self.inputResH, self.inputResW, 'preact', B, C)
            lin = nn.Sequential(
                hg,
                nn.BatchNorm2d(opt.nFeats),
                nn.ReLU(True),
                nn.Conv2d(opt.nFeats, opt.nFeats, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(opt.nFeats),
                nn.ReLU(True)
            )
            tmpOut = nn.Conv2d(opt.nFeats, opt.nClasses, kernel_size=1, stride=1, padding=0)
            self.stack_layers['lin'].append(lin)
            self.stack_layers['out'].append(tmpOut)
            # The last stack feeds nothing further, so it gets no remap layers.
            if i < self.nStack - 1:
                lin_ = nn.Conv2d(opt.nFeats, opt.nFeats, kernel_size=1, stride=1, padding=0)
                tmpOut_ = nn.Conv2d(opt.nClasses, opt.nFeats, kernel_size=1, stride=1, padding=0)
                self.stack_layers['lin_'].append(lin_)
                self.stack_layers['out_'].append(tmpOut_)

    def forward(self, x: Variable):
        """Run every stack and collect the per-stack predictions.

        Args:
            x: Input batch; the first convolution expects 3 channels.

        Returns:
            A list with ``self.nStack`` entries, the output of each stack's
            ``'out'`` layer in stack order.
        """
        out = []
        inter = self.preact(x)
        for i in range(self.nStack):
            lin = self.stack_layers['lin'][i](inter)
            tmpOut = self.stack_layers['out'][i](lin)
            out.append(tmpOut)
            if i < self.nStack - 1:
                # Input of the next stack: current input plus the remapped
                # features and the remapped prediction of this stack.
                lin_ = self.stack_layers['lin_'][i](lin)
                tmpOut_ = self.stack_layers['out_'][i](tmpOut)
                inter = inter + lin_ + tmpOut_
        return out
def createModel(**kw):
    """Build and return a new ``PyraNet``.

    Keyword arguments are accepted but not used; the network takes its
    configuration from the module-level ``opt`` object.
    """
    return PyraNet()
# |  (stray table delimiter left over from the page extraction)