# LambdaSuperRes / KAIR / utils / utils_params.py
# cooperll
# LambdaSuperRes initial commit
# 2514fb4
# raw
# history blame
# 4.04 kB
import torch
import torchvision
from models import basicblock as B
def show_kv(net):
    """Print each key of ``net`` on its own line.

    Args:
        net: Any object exposing ``items()`` that yields ``(key, value)``
            pairs; only the keys are printed.
    """
    for name, _ in net.items():
        print(name)
# Run training in debug mode first to obtain an initial model checkpoint.
#crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
#
#for k, v in crt_net.items():
# print(k)
#for k, v in crt_net.items():
# if k in pretrained_net:
# crt_net[k] = pretrained_net[k]
# print('replace ... ', k)
# x2 -> x4
#crt_net['model.5.weight'] = pretrained_net['model.2.weight']
#crt_net['model.5.bias'] = pretrained_net['model.2.bias']
#crt_net['model.8.weight'] = pretrained_net['model.5.weight']
#crt_net['model.8.bias'] = pretrained_net['model.5.bias']
#crt_net['model.10.weight'] = pretrained_net['model.7.weight']
#crt_net['model.10.bias'] = pretrained_net['model.7.bias']
#torch.save(crt_net, '../pretrained_tmp.pth')
# x2 -> x3
'''
in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3
new_filter = torch.Tensor(576, 64, 3, 3)
new_filter[0:256, :, :, :] = in_filter
new_filter[256:512, :, :, :] = in_filter
new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
crt_net['model.2.weight'] = new_filter
in_bias = pretrained_net['model.2.bias'] # 256
new_bias = torch.Tensor(576)
new_bias[0:256] = in_bias
new_bias[256:512] = in_bias
new_bias[512:] = in_bias[0:576 - 512]
crt_net['model.2.bias'] = new_bias
torch.save(crt_net, '../pretrained_tmp.pth')
'''
# x2 -> x8
'''
crt_net['model.5.weight'] = pretrained_net['model.2.weight']
crt_net['model.5.bias'] = pretrained_net['model.2.bias']
crt_net['model.8.weight'] = pretrained_net['model.2.weight']
crt_net['model.8.bias'] = pretrained_net['model.2.bias']
crt_net['model.11.weight'] = pretrained_net['model.5.weight']
crt_net['model.11.bias'] = pretrained_net['model.5.bias']
crt_net['model.13.weight'] = pretrained_net['model.7.weight']
crt_net['model.13.bias'] = pretrained_net['model.7.bias']
torch.save(crt_net, '../pretrained_tmp.pth')
'''
# x3/4/8 RGB -> Y
def rgb2gray_net(net, only_input=True):
    """Replace the 3-channel filter under ``net['0.weight']`` with a 1-channel one.

    The new filter is the weighted sum of the slices at index 0, 1 and 2 along
    dim 1, using the RGB -> Y weights 0.2989, 0.587 and 0.114.

    Args:
        net: Mapping holding a tensor under the key ``'0.weight'``. The tensor
            is indexed with four indices, so it must be 4-D. The mapping is
            modified in place.
        only_input: When true (the default) the ``'0.weight'`` entry is
            converted; when false ``net`` is returned untouched.

    Returns:
        The same ``net`` object that was passed in.
    """
    if not only_input:
        return net

    rgb_filter = net['0.weight']
    red = rgb_filter[:, 0, :, :]
    green = rgb_filter[:, 1, :, :]
    blue = rgb_filter[:, 2, :, :]
    gray_filter = red * 0.2989 + green * 0.587 + blue * 0.114
    # Re-insert the channel axis that the slicing above removed.
    net['0.weight'] = gray_filter.unsqueeze(1)

    # Disabled recipe for converting the output layer as well:
    # out_filter = pretrained_net['model.13.weight']
    # out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
    #     out_filter[2, :, :, :] * 0.114
    # out_new_filter.unsqueeze_(0)
    # crt_net['model.13.weight'] = out_new_filter
    # out_bias = pretrained_net['model.13.bias']
    # out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
    # out_new_bias = torch.Tensor(1).fill_(out_new_bias)
    # crt_net['model.13.bias'] = out_new_bias
    # torch.save(crt_net, '../pretrained_tmp.pth')
    return net
if __name__ == '__main__':
    # Script entry point. It does two things:
    #   1. Loads torchvision's pretrained VGG-19 and replaces its first
    #      feature layer with a 1-input-channel conv built by ``B.conv``,
    #      initialised from the RGB -> Y weighted sum of the original filters.
    #   2. Optionally copies the tensors of a saved state dict, by position,
    #      into a model's state dict and writes the result to 'model_new.pth'.
    import sys  # local import: only this entry point needs it

    net = torchvision.models.vgg19(pretrained=True)
    for k, v in net.features.named_parameters():
        if k == '0.weight':
            # Same channel weighting as rgb2gray_net; the slicing drops dim 1,
            # so it is restored with unsqueeze_.
            in_new_filter = v[:, 0, :, :]*0.2989 + v[:, 1, :, :]*0.587 + v[:, 2, :, :]*0.114
            in_new_filter.unsqueeze_(1)
            print(in_new_filter.shape)
            print(in_new_filter[0, 0, 0, 0])
        if k == '0.bias':
            # The bias does not depend on the number of input channels.
            in_new_bias = v
            print(v[0])

    # Show the layer before and after swapping in the 1-channel conv.
    print(net.features[0])
    net.features[0] = B.conv(1, 64, mode='C')
    print(net.features[0])
    net.features[0].weight.data = in_new_filter
    net.features[0].bias.data = in_new_bias

    # Print the same two values again to confirm the copy took effect.
    for k, v in net.features.named_parameters():
        if k == '0.weight':
            print(v[0, 0, 0, 0])
        if k == '0.bias':
            print(v[0])

    # transfer parameters of old model to new one
    # NOTE(review): this step originally read `model_path` and `model`, which
    # are not defined anywhere in this file, so the script always ended with a
    # NameError. The checkpoint path now comes from an optional command-line
    # argument and the step is skipped when it is absent.
    if len(sys.argv) > 1:
        model_path = sys.argv[1]
        # NOTE(review): assumes the converted VGG above is the intended target
        # model -- confirm, or substitute the model you actually want to fill.
        model = net
        # torch.load unpickles the file: only pass checkpoints you trust.
        model_old = torch.load(model_path)
        state_dict = model.state_dict()
        # Tensors are matched purely by position, not by key name; zip stops
        # at the shorter of the two dicts.
        for (key, param), (key2, _) in zip(model_old.items(), state_dict.items()):
            state_dict[key2] = param
            print([key, key2])
            # print([param.size(), param2.size()])
        torch.save(state_dict, 'model_new.pth')
    else:
        print('no checkpoint path given; skipping parameter transfer')

    # rgb2gray_net(net)