import torch
import torch.nn as nn
"""
# --------------------------------------------
# Batch Normalization
# --------------------------------------------
# Kai Zhang ([email protected])
# https://github.com/cszn
# 01/Jan/2019
# --------------------------------------------
"""
# --------------------------------------------
# remove/delete specified layer
# --------------------------------------------
def deleteLayer(model, layer_type=nn.BatchNorm2d):
    ''' Kai Zhang, 11/Jan/2019.
    recursively delete all layers of the given type (default: nn.BatchNorm2d)
    '''
    for k, m in list(model.named_children()):
        if isinstance(m, layer_type):
            del model._modules[k]
        deleteLayer(m, layer_type)

# --------------------------------------------
# merge bn, "conv+bn" --> "conv"
# --------------------------------------------
def merge_bn(model):
    ''' Kai Zhang, 11/Jan/2019.
    merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv')
    based on https://github.com/pytorch/pytorch/pull/901
    '''
    prev_m = None
    for k, m in list(model.named_children()):
        if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)):
            w = prev_m.weight.data
            if prev_m.bias is None:
                zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type())
                prev_m.bias = nn.Parameter(zeros)
            b = prev_m.bias.data
            # fold the BN statistics into the preceding layer:
            # w <- w / sqrt(running_var + eps),  b <- (b - running_mean) / sqrt(running_var + eps)
            invstd = m.running_var.clone().add_(m.eps).pow_(-0.5)
            if isinstance(prev_m, nn.ConvTranspose2d):
                w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w))
            else:
                w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w))
            b.add_(-m.running_mean).mul_(invstd)
            # fold the affine parameters: w <- w * gamma,  b <- b * gamma + beta
            if m.affine:
                if isinstance(prev_m, nn.ConvTranspose2d):
                    w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w))
                else:
                    w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w))
                b.mul_(m.weight.data).add_(m.bias.data)
            del model._modules[k]
        prev_m = m
        merge_bn(m)

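# --------------------------------------------
# illustrative check (editor's sketch, not part of the original file):
# verify that merge_bn leaves the eval-mode output of a toy Conv+BN pair
# numerically unchanged; the toy model and tolerance are assumptions.
# --------------------------------------------
def _check_merge_bn(tol=1e-5):
    ''' hypothetical helper for demonstration only.
    '''
    toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
    # give the BN non-trivial statistics so the check is meaningful
    toy[1].running_mean.uniform_(-1.0, 1.0)
    toy[1].running_var.uniform_(0.5, 2.0)
    toy[1].weight.data.uniform_(0.5, 1.5)
    toy[1].bias.data.uniform_(-1.0, 1.0)
    toy.eval()
    x = torch.randn(2, 3, 16, 16)
    with torch.no_grad():
        y_before = toy(x)
        merge_bn(toy)   # folds the BN into the conv and deletes the BN module
        y_after = toy(x)
    return (y_before - y_after).abs().max().item() < tol
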
# --------------------------------------------
# add bn, "conv" --> "conv+bn"
# --------------------------------------------
def add_bn(model):
    ''' Kai Zhang, 11/Jan/2019.
    wrap each Conv/Linear/TConv layer into 'layer + BatchNorm2d'
    '''
    for k, m in list(model.named_children()):
        if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)):
            b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True)
            b.weight.data.fill_(1)
            new_m = nn.Sequential(model._modules[k], b)
            model._modules[k] = new_m
        add_bn(m)

# --------------------------------------------
# tidy model after removing bn
# --------------------------------------------
def tidy_sequential(model):
    ''' Kai Zhang, 11/Jan/2019.
    unwrap single-element nn.Sequential containers left over after removing bn
    '''
    for k, m in list(model.named_children()):
        if isinstance(m, nn.Sequential):
            if len(m) == 1:
                model._modules[k] = m[0]
        tidy_sequential(m)
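
# --------------------------------------------
# usage sketch (editor's example, not part of the original file):
# a small hypothetical plain-conv model; add_bn inserts a BatchNorm2d after
# each conv, merge_bn folds it back into the conv, and tidy_sequential
# unwraps the single-element nn.Sequential wrappers created by add_bn.
# In practice, merging would be done after the BN statistics are trained.
# --------------------------------------------
if __name__ == '__main__':
    model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
                          nn.ReLU(inplace=True),
                          nn.Conv2d(16, 3, 3, padding=1))
    print(model)

    add_bn(model)           # "conv" --> "conv+bn"
    print(model)

    merge_bn(model)         # "conv+bn" --> "conv"
    tidy_sequential(model)  # remove the leftover one-element Sequential wrappers
    print(model)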