# -*- coding: utf-8 -*- import numpy as np import torch from collections import OrderedDict # import scipy.io as io import hdf5storage """ # -------------------------------------------- # Convert matconvnet SimpleNN model into pytorch model # -------------------------------------------- # Kai Zhang (cskaizhang@gmail.com) # https://github.com/cszn # 28/Nov/2019 # -------------------------------------------- """ def weights2tensor(x, squeeze=False, in_features=None, out_features=None): """Modified version of https://github.com/albanie/pytorch-mcn Adjust memory layout and load weights as torch tensor Args: x (ndaray): a numpy array, corresponding to a set of network weights stored in column major order squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove singletons from the trailing dimensions. So after converting to pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1) it will be reshaped to a matrix with shape (A,B). in_features (int :: None): used to reshape weights for a linear block. out_features (int :: None): used to reshape weights for a linear block. Returns: torch.tensor: a permuted sets of weights, matching the pytorch layout convention """ if x.ndim == 4: x = x.transpose((3, 2, 0, 1)) # for FFDNet, pixel-shuffle layer # if x.shape[1]==13: # x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:] # if x.shape[0]==12: # x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] # if x.shape[1]==5: # x=x[:,[0,2,1,3, 4],:,:] # if x.shape[0]==4: # x=x[[0,2,1,3],:,:,:] ## for SRMD, pixel-shuffle layer # if x.shape[0]==12: # x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] # if x.shape[0]==27: # x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:] # if x.shape[0]==48: # x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:] elif x.ndim == 3: # add by Kai x = x[:,:,:,None] x = x.transpose((3, 2, 0, 1)) elif x.ndim == 2: if x.shape[1] == 1: x = x.flatten() if squeeze: if in_features and out_features: x = x.reshape((out_features, in_features)) x = np.squeeze(x) return torch.from_numpy(np.ascontiguousarray(x)) def save_model(network, save_path): state_dict = network.state_dict() for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) if __name__ == '__main__': # from utils import utils_logger # import logging # utils_logger.logger_info('a', 'a.log') # logger = logging.getLogger('a') # # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat') mcn = hdf5storage.loadmat('models/modelcolor.mat') #logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0]) mat_net = OrderedDict() for idx in range(25): mat_net[str(idx)] = OrderedDict() count = -1 print(idx) for i in range(13): if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv': count += 1 w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0] # print(w.shape) w = weights2tensor(w) # print(w.shape) b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1] b = weights2tensor(b) print(b.shape) mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b torch.save(mat_net, 'model_zoo/modelcolor.pth') # from models.network_dncnn import IRCNN as net # network = net(in_nc=3, out_nc=3, nc=64) # state_dict = network.state_dict() # # #show_kv(state_dict) # # for i in range(len(mcn['net'][0][0][0])): # print(mcn['net'][0][0][0][i][0][0][0][0]) # # count = -1 # mat_net = OrderedDict() # for i in range(len(mcn['net'][0][0][0])): # if mcn['net'][0][0][0][i][0][0][0][0] == 'conv': # # count += 1 # w = mcn['net'][0][0][0][i][0][1][0][0] # print(w.shape) # w = weights2tensor(w) # print(w.shape) # # b = mcn['net'][0][0][0][i][0][1][0][1] # b = weights2tensor(b) # print(b.shape) # # mat_net['model.{:d}.weight'.format(count*2)] = w # mat_net['model.{:d}.bias'.format(count*2)] = b # # torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth') # # # # crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth') # def show_kv(net): # for k, v in net.items(): # print(k) # # show_kv(crt_net) # from models.network_dncnn import DnCNN as net # network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R') # from models.network_srmd import SRMD as net # #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R') # network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle') # # from models.network_rrdb import RRDB as net # network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv') # # state_dict = network.state_dict() # for key, param in state_dict.items(): # print(key) # from models.network_imdn import IMDN as net # network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle') # state_dict = network.state_dict() # mat_net = OrderedDict() # for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()): # mat_net[key] = param2 # torch.save(mat_net, 'model_zoo/imdn_x4_1.pth') # # net_old = torch.load('net_old.pth') # def show_kv(net): # for k, v in net.items(): # print(k) # # show_kv(net_old) # from models.network_dpsr import MSRResNet_prior as net # model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle') # state_dict = network.state_dict() # net_new = OrderedDict() # for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()): # net_new[key] = param_old # torch.save(net_new, 'net_new.pth') # print(key) # print(param.size()) # run utils/utils_matconvnet.py