Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| import numpy as np | |
| import torch | |
| from collections import OrderedDict | |
| # import scipy.io as io | |
| import hdf5storage | |
| """ | |
| # -------------------------------------------- | |
| # Convert matconvnet SimpleNN model into pytorch model | |
| # -------------------------------------------- | |
| # Kai Zhang ([email protected]) | |
| # https://github.com/cszn | |
| # 28/Nov/2019 | |
| # -------------------------------------------- | |
| """ | |
def weights2tensor(x, squeeze=False, in_features=None, out_features=None):
    """Convert a MatConvNet weight array into a PyTorch tensor.

    Adapted from https://github.com/albanie/pytorch-mcn. The array is moved
    from the MatConvNet memory layout into the PyTorch layout and returned
    as a C-contiguous ``torch.Tensor``.

    Layout handling:
        * 4-D ``(H, W, C_in, C_out)`` becomes ``(C_out, C_in, H, W)``.
        * 3-D ``(H, W, C_in)`` is treated as ``C_out == 1`` and becomes
          ``(1, C_in, H, W)``.
        * 2-D column vectors ``(N, 1)`` are flattened to ``(N,)``; any other
          2-D shape is left as is.
        * Arrays of any other rank are passed through unchanged.

    Args:
        x (np.ndarray): network weights stored in MatConvNet (column-major)
            order.
        squeeze (bool): if True, apply ``np.squeeze`` after the permutation.
            Note that this removes *every* singleton dimension, not only
            trailing ones, e.g. ``(A, B, 1, 1)`` -> ``(A, B)``.
        in_features (int, optional): with ``out_features``, reshape the array
            to ``(out_features, in_features)`` before squeezing. Only used
            when ``squeeze`` is True and both values are truthy.
        out_features (int, optional): see ``in_features``.

    Returns:
        torch.Tensor: the permuted weights in PyTorch layout.
    """
    arr = x
    if arr.ndim in (3, 4):
        if arr.ndim == 3:
            # Missing last axis is taken as C_out == 1 (presumably MATLAB
            # dropped a trailing singleton dimension when saving).
            arr = arr[..., np.newaxis]
        arr = np.transpose(arr, (3, 2, 0, 1))
        # NOTE: FFDNet/SRMD pixel-shuffle layers once needed an extra channel
        # reordering at this point (disabled in the original code). For an
        # upscale factor s, each group of s*s channels was reordered by
        # transposing its s x s sub-pixel index:
        #     perm = [g*s*s + j*s + i for g in range(groups)
        #             for i in range(s) for j in range(s)]
        # e.g. [0, 2, 1, 3] for s=2. It was applied to axis 0 (SRMD, FFDNet)
        # and to axis 1 (FFDNet), where one extra last input channel (the
        # 13th of 13, or the 5th of 5) was kept in place.
    elif arr.ndim == 2 and arr.shape[1] == 1:
        arr = arr.flatten()

    if squeeze:
        if in_features and out_features:
            arr = arr.reshape((out_features, in_features))
        arr = np.squeeze(arr)

    return torch.from_numpy(np.ascontiguousarray(arr))
def save_model(network, save_path):
    """Save ``network``'s state dict to ``save_path`` with every tensor on CPU.

    The dict returned by ``network.state_dict()`` is updated in place rather
    than rebuilt, so its concrete type and any attributes attached to it are
    preserved in the saved file.

    Args:
        network: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        save_path: destination handed directly to ``torch.save``.
    """
    weights = network.state_dict()
    for name in list(weights):
        weights[name] = weights[name].cpu()
    torch.save(weights, save_path)
def _convert_denoiser(layers):
    """Convert one MatConvNet SimpleNN layer list into a PyTorch state dict.

    Only layers whose type is ``'conv'`` contribute parameters. The k-th conv
    layer (0-based) is stored under ``model.{2k}.weight`` / ``model.{2k}.bias``;
    the stride of 2 presumably matches an ``nn.Sequential`` that interleaves
    each conv with one activation module -- confirm against the target net.

    Args:
        layers: sequence of MatConvNet layer records as returned by
            ``hdf5storage.loadmat``. For each ``layer``, ``layer[0][0][0][0]``
            is the layer type and ``layer[0][1][0]`` holds ``(weights, biases)``.

    Returns:
        OrderedDict mapping parameter names to tensors.
    """
    state_dict = OrderedDict()
    n_conv = 0
    for layer in layers:
        record = layer[0]
        if record[0][0][0] == 'conv':
            params = record[1][0]
            key = 'model.{:d}'.format(n_conv * 2)
            state_dict[key + '.weight'] = weights2tensor(params[0])
            state_dict[key + '.bias'] = weights2tensor(params[1])
            n_conv += 1
    return state_dict


if __name__ == '__main__':
    import os

    mat_path = 'models/modelcolor.mat'
    save_path = 'model_zoo/modelcolor.pth'

    mcn = hdf5storage.loadmat(mat_path)
    # Walk the denoisers and layers actually stored in the file. The original
    # script hard-coded 25 denoisers x 13 layers, which raises IndexError on
    # smaller files and silently drops data from larger ones.
    denoisers = mcn['CNNdenoiser'][0]

    mat_net = OrderedDict()
    for idx, denoiser in enumerate(denoisers):
        print(idx)
        mat_net[str(idx)] = _convert_denoiser(denoiser[0])

    # torch.save does not create missing parent directories; create it up
    # front so the conversion work is not lost at the very last step.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(mat_net, save_path)
| # from models.network_dncnn import IRCNN as net | |
| # network = net(in_nc=3, out_nc=3, nc=64) | |
| # state_dict = network.state_dict() | |
| # | |
| # #show_kv(state_dict) | |
| # | |
| # for i in range(len(mcn['net'][0][0][0])): | |
| # print(mcn['net'][0][0][0][i][0][0][0][0]) | |
| # | |
| # count = -1 | |
| # mat_net = OrderedDict() | |
| # for i in range(len(mcn['net'][0][0][0])): | |
| # if mcn['net'][0][0][0][i][0][0][0][0] == 'conv': | |
| # | |
| # count += 1 | |
| # w = mcn['net'][0][0][0][i][0][1][0][0] | |
| # print(w.shape) | |
| # w = weights2tensor(w) | |
| # print(w.shape) | |
| # | |
| # b = mcn['net'][0][0][0][i][0][1][0][1] | |
| # b = weights2tensor(b) | |
| # print(b.shape) | |
| # | |
| # mat_net['model.{:d}.weight'.format(count*2)] = w | |
| # mat_net['model.{:d}.bias'.format(count*2)] = b | |
| # | |
| # torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth') | |
| # | |
| # | |
| # | |
| # crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth') | |
| # def show_kv(net): | |
| # for k, v in net.items(): | |
| # print(k) | |
| # | |
| # show_kv(crt_net) | |
| # from models.network_dncnn import DnCNN as net | |
| # network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R') | |
| # from models.network_srmd import SRMD as net | |
| # #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R') | |
| # network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle') | |
| # | |
| # from models.network_rrdb import RRDB as net | |
| # network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv') | |
| # | |
| # state_dict = network.state_dict() | |
| # for key, param in state_dict.items(): | |
| # print(key) | |
| # from models.network_imdn import IMDN as net | |
| # network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle') | |
| # state_dict = network.state_dict() | |
| # mat_net = OrderedDict() | |
| # for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()): | |
| # mat_net[key] = param2 | |
| # torch.save(mat_net, 'model_zoo/imdn_x4_1.pth') | |
| # | |
| # net_old = torch.load('net_old.pth') | |
| # def show_kv(net): | |
| # for k, v in net.items(): | |
| # print(k) | |
| # | |
| # show_kv(net_old) | |
| # from models.network_dpsr import MSRResNet_prior as net | |
| # model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle') | |
| # state_dict = network.state_dict() | |
| # net_new = OrderedDict() | |
| # for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()): | |
| # net_new[key] = param_old | |
| # torch.save(net_new, 'net_new.pth') | |
| # print(key) | |
| # print(param.size()) | |
| # run utils/utils_matconvnet.py | |