Spaces:
Running
Running
import torch | |
import torchvision | |
from models import basicblock as B | |
def show_kv(net):
    """Print every key of a state-dict-like mapping, one key per line.

    Args:
        net: Mapping whose keys are printed in iteration order
            (for example a checkpoint dictionary).
    """
    # Only the keys are displayed, so iterate the mapping directly rather
    # than unpacking (key, value) pairs and discarding the value.
    for key in net:
        print(key)
# Run training in debug mode first to produce an initial model checkpoint.
#crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth') | |
# | |
#for k, v in crt_net.items(): | |
# print(k) | |
#for k, v in crt_net.items(): | |
# if k in pretrained_net: | |
# crt_net[k] = pretrained_net[k] | |
# print('replace ... ', k) | |
# x2 -> x4 | |
#crt_net['model.5.weight'] = pretrained_net['model.2.weight'] | |
#crt_net['model.5.bias'] = pretrained_net['model.2.bias'] | |
#crt_net['model.8.weight'] = pretrained_net['model.5.weight'] | |
#crt_net['model.8.bias'] = pretrained_net['model.5.bias'] | |
#crt_net['model.10.weight'] = pretrained_net['model.7.weight'] | |
#crt_net['model.10.bias'] = pretrained_net['model.7.bias'] | |
#torch.save(crt_net, '../pretrained_tmp.pth') | |
# x2 -> x3 | |
''' | |
in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3 | |
new_filter = torch.Tensor(576, 64, 3, 3) | |
new_filter[0:256, :, :, :] = in_filter | |
new_filter[256:512, :, :, :] = in_filter | |
new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :] | |
crt_net['model.2.weight'] = new_filter | |
in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3 | |
new_bias = torch.Tensor(576) | |
new_bias[0:256] = in_bias | |
new_bias[256:512] = in_bias | |
new_bias[512:] = in_bias[0:576 - 512] | |
crt_net['model.2.bias'] = new_bias | |
torch.save(crt_net, '../pretrained_tmp.pth') | |
''' | |
# x2 -> x8 | |
''' | |
crt_net['model.5.weight'] = pretrained_net['model.2.weight'] | |
crt_net['model.5.bias'] = pretrained_net['model.2.bias'] | |
crt_net['model.8.weight'] = pretrained_net['model.2.weight'] | |
crt_net['model.8.bias'] = pretrained_net['model.2.bias'] | |
crt_net['model.11.weight'] = pretrained_net['model.5.weight'] | |
crt_net['model.11.bias'] = pretrained_net['model.5.bias'] | |
crt_net['model.13.weight'] = pretrained_net['model.7.weight'] | |
crt_net['model.13.bias'] = pretrained_net['model.7.bias'] | |
torch.save(crt_net, '../pretrained_tmp.pth') | |
''' | |
# x3/4/8 RGB -> Y | |
def rgb2gray_net(net, only_input=True):
    """Convert the RGB input filters stored under ``'0.weight'`` to one channel.

    The first three entries along axis 1 of ``net['0.weight']`` are combined
    with the weights 0.2989, 0.587 and 0.114 (RGB -> Y), and the result is
    written back under the same key with a singleton axis 1.

    Args:
        net: Mapping that holds a 4-D tensor under the key ``'0.weight'``.
            It is modified in place.
        only_input: When true, the input filter is converted. When false,
            nothing is changed; the output-layer conversion only exists as
            the commented-out draft below.

    Returns:
        The same ``net`` object that was passed in.
    """
    if only_input:
        rgb_filter = net['0.weight']
        red = rgb_filter[:, 0, :, :]
        green = rgb_filter[:, 1, :, :]
        blue = rgb_filter[:, 2, :, :]
        gray_filter = red * 0.2989 + green * 0.587 + blue * 0.114
        # Re-insert the channel axis that the per-channel indexing removed.
        net['0.weight'] = gray_filter.unsqueeze(1)

    # Disabled draft of the output-layer conversion, kept for reference:
    # out_filter = pretrained_net['model.13.weight']
    # out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
    #     out_filter[2, :, :, :] * 0.114
    # out_new_filter.unsqueeze_(0)
    # crt_net['model.13.weight'] = out_new_filter
    # out_bias = pretrained_net['model.13.bias']
    # out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
    # out_new_bias = torch.Tensor(1).fill_(out_new_bias)
    # crt_net['model.13.bias'] = out_new_bias
    # torch.save(crt_net, '../pretrained_tmp.pth')
    return net
def transfer_params(model_path, model, save_path='model_new.pth'):
    """Copy parameters from a saved checkpoint into a model's state dict by position.

    Entries of the checkpoint and of ``model.state_dict()`` are paired in
    iteration order (not by key name); each paired checkpoint value replaces
    the corresponding state-dict value. Pairing stops at the shorter of the
    two, so any remaining entries keep the model's own values. The resulting
    state dict is saved to ``save_path``.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``; the
            loaded object must provide ``.items()``.
        model: Object providing ``state_dict()`` whose keys receive the
            checkpoint values.
        save_path: Destination file for the updated state dict.

    Returns:
        The updated state dict that was written to ``save_path``.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    model_old = torch.load(model_path)
    state_dict = model.state_dict()
    # Snapshot the target keys so the mapping is not iterated while it is
    # being assigned to.
    for (old_key, old_param), new_key in zip(model_old.items(), list(state_dict)):
        state_dict[new_key] = old_param
        print([old_key, new_key])
        # print([old_param.size(), state_dict[new_key].size()])
    torch.save(state_dict, save_path)
    return state_dict


if __name__ == '__main__':
    # Demo: turn the first VGG19 convolution from 3 input channels into 1.
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; kept as-is because the installed version is
    # not known here.
    net = torchvision.models.vgg19(pretrained=True)

    for k, v in net.features.named_parameters():
        if k == '0.weight':
            # Weighted sum over the three input channels (RGB -> Y), then
            # restore the channel axis removed by the indexing.
            in_new_filter = v[:, 0, :, :] * 0.2989 + v[:, 1, :, :] * 0.587 + v[:, 2, :, :] * 0.114
            in_new_filter.unsqueeze_(1)
            print(in_new_filter.shape)
            print(in_new_filter[0, 0, 0, 0])
        if k == '0.bias':
            in_new_bias = v
            print(v[0])

    print(net.features[0])
    # Swap in a 1-input-channel convolution and load the collapsed weights.
    net.features[0] = B.conv(1, 64, mode='C')
    print(net.features[0])
    net.features[0].weight.data = in_new_filter
    net.features[0].bias.data = in_new_bias

    # Show the same two values again to confirm the copy took effect.
    for k, v in net.features.named_parameters():
        if k == '0.weight':
            print(v[0, 0, 0, 0])
        if k == '0.bias':
            print(v[0])

    # The old-to-new parameter transfer that used to run inline here read
    # `model_path` and `model`, neither of which is defined in this script,
    # so it always raised NameError. It is now available as
    # transfer_params(model_path, model); call it with real values when
    # needed.

    # rgb2gray_net(net)