"""Helpers for adapting pretrained network parameters.

The module offers utilities to list the keys of a state dict, to fold an RGB
input convolution into a single-channel (Y) one, and to copy the parameters
of an old checkpoint into a new model by position.  Running the file as a
script converts the first convolution of torchvision's pretrained VGG-19 to
single-channel input.
"""
import torch
import torchvision

from models import basicblock as B

# Weights applied to input channels 0, 1 and 2 (R, G, B) to obtain one Y channel.
_RGB_TO_Y = (0.2989, 0.587, 0.114)


def _rgb_filter_to_gray(weight):
    """Fold channels 0-2 on axis 1 of a 4-D filter bank into a single channel.

    Args:
        weight: 4-D tensor indexed as (out, in, h, w) with at least three
            entries on axis 1.

    Returns:
        A new tensor with a single entry on axis 1; ``weight`` is not modified.
    """
    w_r, w_g, w_b = _RGB_TO_Y
    gray = weight[:, 0, :, :] * w_r + weight[:, 1, :, :] * w_g + weight[:, 2, :, :] * w_b
    # Re-insert the channel axis that the integer indexing above removed.
    return gray.unsqueeze(1)


def show_kv(net):
    """Print every key of ``net``, one per line.

    Args:
        net: a mapping exposing ``items()`` (e.g. a loaded state dict).
    """
    for key, _ in net.items():
        print(key)


# ----------------------------------------------------------------------------
# Reference recipes (kept disabled) for re-using pretrained weights at a
# different scale factor.
# ----------------------------------------------------------------------------

# should run train debug mode first to get an initial model
#crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
#
#for k, v in crt_net.items():
#    print(k)
#for k, v in crt_net.items():
#    if k in pretrained_net:
#        crt_net[k] = pretrained_net[k]
#        print('replace ... ', k)

# x2 -> x4
#crt_net['model.5.weight'] = pretrained_net['model.2.weight']
#crt_net['model.5.bias'] = pretrained_net['model.2.bias']
#crt_net['model.8.weight'] = pretrained_net['model.5.weight']
#crt_net['model.8.bias'] = pretrained_net['model.5.bias']
#crt_net['model.10.weight'] = pretrained_net['model.7.weight']
#crt_net['model.10.bias'] = pretrained_net['model.7.bias']
#torch.save(crt_net, '../pretrained_tmp.pth')

# x2 -> x3
'''
in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3
new_filter = torch.Tensor(576, 64, 3, 3)
new_filter[0:256, :, :, :] = in_filter
new_filter[256:512, :, :, :] = in_filter
new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
crt_net['model.2.weight'] = new_filter

in_bias = pretrained_net['model.2.bias']  # 256, 64, 3, 3
new_bias = torch.Tensor(576)
new_bias[0:256] = in_bias
new_bias[256:512] = in_bias
new_bias[512:] = in_bias[0:576 - 512]
crt_net['model.2.bias'] = new_bias

torch.save(crt_net, '../pretrained_tmp.pth')
'''

# x2 -> x8
'''
crt_net['model.5.weight'] = pretrained_net['model.2.weight']
crt_net['model.5.bias'] = pretrained_net['model.2.bias']
crt_net['model.8.weight'] = pretrained_net['model.2.weight']
crt_net['model.8.bias'] = pretrained_net['model.2.bias']
crt_net['model.11.weight'] = pretrained_net['model.5.weight']
crt_net['model.11.bias'] = pretrained_net['model.5.bias']
crt_net['model.13.weight'] = pretrained_net['model.7.weight']
crt_net['model.13.bias'] = pretrained_net['model.7.bias']
torch.save(crt_net, '../pretrained_tmp.pth')
'''


# x3/4/8 RGB -> Y
def rgb2gray_net(net, only_input=True):
    """Convert the input filter of a state dict from RGB to single-channel (Y).

    Args:
        net: mapping holding the 4-D input filter under the key ``'0.weight'``.
            It is updated in place.
        only_input: when true, ``net['0.weight']`` is replaced by its
            channel-weighted sum; when false, ``net`` is left untouched.

    Returns:
        The same ``net`` object that was passed in.
    """
    if only_input:
        net['0.weight'] = _rgb_filter_to_gray(net['0.weight'])

    # Disabled recipe for converting the output layer as well:
    #    out_filter = pretrained_net['model.13.weight']
    #    out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
    #        out_filter[2, :, :, :] * 0.114
    #    out_new_filter.unsqueeze_(0)
    #    crt_net['model.13.weight'] = out_new_filter
    #    out_bias = pretrained_net['model.13.bias']
    #    out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
    #    out_new_bias = torch.Tensor(1).fill_(out_new_bias)
    #    crt_net['model.13.bias'] = out_new_bias

    #    torch.save(crt_net, '../pretrained_tmp.pth')

    return net


def transfer_params(model_path, model, save_path='model_new.pth'):
    """Copy the parameters of an old checkpoint into a new model's state dict.

    Entries are paired purely by position: the i-th tensor of the loaded
    checkpoint is stored under the i-th key of ``model.state_dict()``.  Each
    pair of keys is printed so that the pairing can be inspected.

    Args:
        model_path: path of the checkpoint to read; it must load to a mapping.
        model: object exposing ``state_dict()`` that supplies the new key names.
        save_path: file the re-keyed state dict is written to.

    Returns:
        The re-keyed state dict that was saved.
    """
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    model_old = torch.load(model_path)
    state_dict = model.state_dict()
    # Only existing keys are reassigned, so the dict keeps its size while iterated.
    for (old_key, old_param), (new_key, _) in zip(model_old.items(), state_dict.items()):
        state_dict[new_key] = old_param
        print([old_key, new_key])
        # print([param.size(), param2.size()])
    torch.save(state_dict, save_path)
    return state_dict


def main():
    """Turn the first convolution of pretrained VGG-19 into a 1-channel layer."""
    net = torchvision.models.vgg19(pretrained=True)

    old_conv = net.features[0]
    # detach(): the values are copied into a fresh layer as plain data, so they
    # must not carry the autograd history of the old parameters.
    in_new_filter = _rgb_filter_to_gray(old_conv.weight.detach())
    in_new_bias = old_conv.bias.detach()
    print(in_new_filter.shape)
    print(in_new_filter[0, 0, 0, 0])
    print(in_new_bias[0])

    print(old_conv)

    net.features[0] = B.conv(1, 64, mode='C')
    new_conv = net.features[0]

    print(new_conv)
    new_conv.weight.data = in_new_filter
    new_conv.bias.data = in_new_bias

    # Show the first weight and bias of the replaced layer to confirm the copy.
    print(new_conv.weight[0, 0, 0, 0])
    print(new_conv.bias[0])

    # To transfer parameters of an old model to a new one, call
    # transfer_params(model_path, model) with a real checkpoint path and model.
    # The script used to run that step with undefined names and always failed
    # with a NameError.

    # rgb2gray_net(net)


if __name__ == '__main__':
    main()