File size: 10,912 Bytes
09e181f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 |
from collections import namedtuple
from torch.nn import Dropout
from torch.nn import MaxPool2d
from torch.nn import Sequential
import torch
import torch.nn as nn
from torch.nn import Conv2d, Linear
from torch.nn import BatchNorm1d, BatchNorm2d
from torch.nn import ReLU, Sigmoid
from torch.nn import Module
from torch.nn import PReLU
from fvcore.nn import flop_count
import numpy as np
def initialize_weights(modules):
    """Re-initialize convolution, linear and 2-D batch-norm layers in place.

    ``Conv2d`` and ``Linear`` weights are redrawn with Kaiming-normal
    initialization (``mode='fan_out'``, ``nonlinearity='relu'``) and their
    biases, when present, are zeroed.  ``BatchNorm2d`` layers get weight 1
    and bias 0.  Modules of any other type are left untouched.

    Args:
        modules: Iterable of ``nn.Module`` objects, e.g. ``model.modules()``.
    """
    for m in modules:
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            # A BatchNorm2d built with affine=False has weight/bias set to
            # None; skip those instead of crashing on ``None.data``.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
class Flatten(Module):
    """Collapse every dimension after the first (batch) dimension into one."""

    def forward(self, input):
        """Return ``input`` reshaped to ``(input.size(0), -1)``.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (for example tensors in ``channels_last`` memory format) are copied
        rather than raising; contiguous inputs still come back as a view.
        """
        return input.reshape(input.size(0), -1)
class LinearBlock(Module):
    """Bias-free 2-D convolution followed by batch normalization.

    No activation is applied after the batch norm.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        groups: Number of convolution groups.
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super().__init__()
        self.conv = Conv2d(in_c, out_c,
                           kernel_size=kernel,
                           stride=stride,
                           padding=padding,
                           groups=groups,
                           bias=False)
        self.bn = BatchNorm2d(num_features=out_c)

    def forward(self, x):
        """Apply the convolution and then the batch norm to ``x``."""
        return self.bn(self.conv(x))
class SEModule(Module):
    """Channel-wise gating block.

    The input is average-pooled to 1x1, passed through a bias-free 1x1
    convolution that divides the channel count by ``reduction``, a ReLU, a
    second bias-free 1x1 convolution that restores the channel count, and a
    sigmoid.  The input is then multiplied by the resulting per-channel gate.

    Args:
        channels: Number of input (and output) channels.
        reduction: Integer divisor applied to ``channels`` for the inner width.
    """

    def __init__(self, channels, reduction):
        super().__init__()
        squeezed = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, squeezed, kernel_size=1, padding=0, bias=False)
        # Only the first projection receives Xavier-uniform initialization here.
        nn.init.xavier_uniform_(self.fc1.weight.data)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(squeezed, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by its learned per-channel gate."""
        gate = self.avg_pool(x)
        gate = self.relu(self.fc1(gate))
        gate = self.sigmoid(self.fc2(gate))
        return x * gate
class BasicBlockIR(Module):
    """Residual unit with two 3x3 convolutions on the residual branch.

    The residual branch is BN -> 3x3 conv (stride 1) -> BN -> PReLU ->
    3x3 conv (``stride``) -> BN.  The shortcut is ``MaxPool2d(1, stride)``
    when ``in_channel == depth`` and a strided 1x1 convolution followed by
    batch norm otherwise.  The two branches are summed.

    Args:
        in_channel: Number of input channels.
        depth: Number of output channels.
        stride: Stride used by the shortcut and by the second convolution.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        if in_channel != depth:
            # Channel counts differ, so the shortcut needs a projection.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        else:
            # A kernel-1 max pool only subsamples spatially.
            self.shortcut_layer = MaxPool2d(1, stride)
        residual_layers = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(depth),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        ]
        self.res_layer = Sequential(*residual_layers)

    def forward(self, x):
        """Return the sum of the residual branch and the shortcut branch."""
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
class BottleneckIR(Module):
    """Residual unit with a 1x1 -> 3x3 -> 1x1 bottleneck on the residual branch.

    The inner width is ``depth // 4``.  The residual branch is BN -> 1x1 conv
    -> BN -> PReLU -> 3x3 conv -> BN -> PReLU -> 1x1 conv (``stride``) -> BN.
    The shortcut is ``MaxPool2d(1, stride)`` when ``in_channel == depth`` and
    a strided 1x1 convolution followed by batch norm otherwise.  The two
    branches are summed.

    Args:
        in_channel: Number of input channels.
        depth: Number of output channels.
        stride: Stride used by the shortcut and by the last convolution.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__()
        mid = depth // 4
        if in_channel != depth:
            # Channel counts differ, so the shortcut needs a projection.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        else:
            # A kernel-1 max pool only subsamples spatially.
            self.shortcut_layer = MaxPool2d(1, stride)
        residual_layers = [
            BatchNorm2d(in_channel),
            Conv2d(in_channel, mid, (1, 1), (1, 1), 0, bias=False),
            BatchNorm2d(mid),
            PReLU(mid),
            Conv2d(mid, mid, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(mid),
            PReLU(mid),
            Conv2d(mid, depth, (1, 1), stride, 0, bias=False),
            BatchNorm2d(depth),
        ]
        self.res_layer = Sequential(*residual_layers)

    def forward(self, x):
        """Return the sum of the residual branch and the shortcut branch."""
        identity = self.shortcut_layer(x)
        return self.res_layer(x) + identity
class BasicBlockIRSE(BasicBlockIR):
    """``BasicBlockIR`` whose residual branch ends with an ``SEModule``.

    The ``SEModule`` is built with ``depth`` channels and a reduction of 16
    and appended to ``res_layer`` under the name ``"se_block"``.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        se_block = SEModule(depth, 16)
        self.res_layer.add_module("se_block", se_block)
class BottleneckIRSE(BottleneckIR):
    """``BottleneckIR`` whose residual branch ends with an ``SEModule``.

    The ``SEModule`` is built with ``depth`` channels and a reduction of 16
    and appended to ``res_layer`` under the name ``"se_block"``.
    """

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        se_block = SEModule(depth, 16)
        self.res_layer.add_module("se_block", se_block)
class Bottleneck(namedtuple('Block', ('in_channel', 'depth', 'stride'))):
    """Immutable ``(in_channel, depth, stride)`` description of one residual unit."""
def get_block(in_channel, depth, num_units, stride=2):
    """Return the list of ``Bottleneck`` specs that make up one stage.

    The first unit maps ``in_channel`` to ``depth`` using ``stride``; each of
    the remaining ``num_units - 1`` units maps ``depth`` to ``depth`` with
    stride 1.
    """
    units = [Bottleneck(in_channel, depth, stride)]
    units.extend(Bottleneck(depth, depth, 1) for _ in range(num_units - 1))
    return units
# Per-stage (in_channel, depth, num_units) settings, keyed by network depth.
_STAGE_SPECS = {
    18: ((64, 64, 2), (64, 128, 2), (128, 256, 2), (256, 512, 2)),
    34: ((64, 64, 3), (64, 128, 4), (128, 256, 6), (256, 512, 3)),
    50: ((64, 64, 3), (64, 128, 4), (128, 256, 14), (256, 512, 3)),
    100: ((64, 64, 3), (64, 128, 13), (128, 256, 30), (256, 512, 3)),
    152: ((64, 256, 3), (256, 512, 8), (512, 1024, 36), (1024, 2048, 3)),
    200: ((64, 256, 3), (256, 512, 24), (512, 1024, 36), (1024, 2048, 3)),
}


def get_blocks(num_layers):
    """Return the four-stage block layout for a network of ``num_layers``.

    Args:
        num_layers: One of 18, 34, 50, 100, 152 or 200.

    Returns:
        A list with one entry per stage; each entry is the list produced by
        ``get_block`` for that stage's ``(in_channel, depth, num_units)``.

    Raises:
        ValueError: If ``num_layers`` is not a supported depth.  (Previously
            this case surfaced as an ``UnboundLocalError`` on ``blocks``.)
    """
    try:
        stages = _STAGE_SPECS[num_layers]
    except (KeyError, TypeError):
        supported = ", ".join(str(n) for n in sorted(_STAGE_SPECS))
        raise ValueError(
            f"unsupported num_layers {num_layers!r}; expected one of {supported}"
        ) from None
    return [
        get_block(in_channel=in_channel, depth=depth, num_units=num_units)
        for in_channel, depth, num_units in stages
    ]
class Backbone(Module):
    """IR / IR-SE residual network mapping an image batch to an embedding.

    The network is an input stem (3x3 conv -> BN -> PReLU), a body of residual
    units laid out by ``get_blocks(num_layers)``, and an output head
    (BN -> Dropout(0.4) -> Flatten -> Linear -> BatchNorm1d without affine
    parameters).

    Args:
        input_size: Sequence whose first element is 112 or 224.  Only the
            first element is inspected.
        num_layers: 18, 34, 50, 100, 152 or 200.  Depths up to 100 use the
            basic units with 512 output channels; larger depths use the
            bottleneck units with 2048 output channels.
        mode: ``'ir'`` for plain units, ``'ir_se'`` for units with an
            ``SEModule`` appended.
        flip: When true, ``forward`` reverses dimension 1 of its input first.
        output_dim: Size of the produced embedding.

    Raises:
        AssertionError: If ``input_size[0]``, ``num_layers`` or ``mode`` is
            not one of the supported values.
    """

    def __init__(self, input_size, num_layers, mode='ir', flip=False, output_dim=512):
        super(Backbone, self).__init__()
        assert input_size[0] in [112, 224], \
            "input_size should be [112, 112] or [224, 224]"
        assert num_layers in [18, 34, 50, 100, 152, 200], \
            "num_layers should be 18, 34, 50, 100, 152 or 200"
        assert mode in ['ir', 'ir_se'], \
            "mode should be ir or ir_se"
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64), PReLU(64))
        blocks = get_blocks(num_layers)
        if num_layers <= 100:
            unit_module = {'ir': BasicBlockIR, 'ir_se': BasicBlockIRSE}[mode]
            output_channel = 512
        else:
            unit_module = {'ir': BottleneckIR, 'ir_se': BottleneckIRSE}[mode]
            output_channel = 2048
        # The stem keeps resolution and each of the four stages starts with a
        # stride-2 unit, so the final feature map is input/16: 7 or 14.
        feature_size = 7 if input_size[0] == 112 else 14
        self.output_layer = Sequential(
            BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
            Linear(output_channel * feature_size * feature_size, output_dim),
            BatchNorm1d(output_dim, affine=False))
        modules = [
            unit_module(unit.in_channel, unit.depth, unit.stride)
            for block in blocks
            for unit in block
        ]
        self.body = Sequential(*modules)
        # Runs after every submodule is registered so all of them are covered.
        initialize_weights(self.modules())
        self.flip = flip

    def forward(self, x):
        """Return the embedding computed from ``x``."""
        if self.flip:
            x = x.flip(1)  # color channel flip
        x = self.input_layer(x)
        x = self.body(x)
        return self.output_layer(x)
def IR_18(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=18`` in ``'ir'`` mode."""
    return Backbone(input_size, num_layers=18, mode='ir', output_dim=output_dim)
def IR_34(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=34`` in ``'ir'`` mode."""
    return Backbone(input_size, num_layers=34, mode='ir', output_dim=output_dim)
def IR_50(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=50`` in ``'ir'`` mode."""
    return Backbone(input_size, num_layers=50, mode='ir', output_dim=output_dim)
def IR_101(input_size, output_dim=512):
    """Build a ``Backbone`` in ``'ir'`` mode.

    Despite the name, the backbone is requested with ``num_layers=100``.
    """
    return Backbone(input_size, num_layers=100, mode='ir', output_dim=output_dim)
def IR_101_FLIP(input_size, output_dim=512):
    """Build a ``Backbone`` in ``'ir'`` mode with ``flip=True``.

    Despite the name, the backbone is requested with ``num_layers=100``.
    """
    return Backbone(input_size, num_layers=100, mode='ir', flip=True,
                    output_dim=output_dim)
def IR_152(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=152`` in ``'ir'`` mode."""
    return Backbone(input_size, num_layers=152, mode='ir', output_dim=output_dim)
def IR_200(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=200`` in ``'ir'`` mode."""
    return Backbone(input_size, num_layers=200, mode='ir', output_dim=output_dim)
def IR_SE_50(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=50`` in ``'ir_se'`` mode."""
    return Backbone(input_size, num_layers=50, mode='ir_se', output_dim=output_dim)
def IR_SE_101(input_size, output_dim=512):
    """Build a ``Backbone`` in ``'ir_se'`` mode.

    Despite the name, the backbone is requested with ``num_layers=100``.
    """
    return Backbone(input_size, num_layers=100, mode='ir_se', output_dim=output_dim)
def IR_SE_152(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=152`` in ``'ir_se'`` mode."""
    return Backbone(input_size, num_layers=152, mode='ir_se', output_dim=output_dim)
def IR_SE_200(input_size, output_dim=512):
    """Build a ``Backbone`` with ``num_layers=200`` in ``'ir_se'`` mode."""
    return Backbone(input_size, num_layers=200, mode='ir_se', output_dim=output_dim)
if __name__ == '__main__':
    # Report the FLOP count and trainable parameter count of IR_50 for a
    # single 112x112 RGB input.
    inputs_shape = (1, 3, 112, 112)
    model = IR_50(input_size=(112, 112))
    model.eval()
    # fvcore's flop_count documents its first return value as per-operator
    # counts already expressed in GFLOPs, so the total must not be divided by
    # 1e9 a second time.  `inputs` is documented as a tuple of model inputs.
    gflops_by_op, _unsupported_ops = flop_count(
        model, inputs=(torch.randn(inputs_shape),), supported_ops={})
    fvcore_flop = sum(gflops_by_op.values())
    print('FLOPs: ', fvcore_flop, 'G')
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Num Params: ', num_params / 1e6, 'M')
|