Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
import numpy as np | |
import torch | |
from collections import OrderedDict | |
# import scipy.io as io | |
import hdf5storage | |
""" | |
# -------------------------------------------- | |
# Convert matconvnet SimpleNN model into pytorch model | |
# -------------------------------------------- | |
# Kai Zhang ([email protected]) | |
# https://github.com/cszn | |
# 28/Nov/2019 | |
# -------------------------------------------- | |
""" | |
def weights2tensor(x, squeeze=False, in_features=None, out_features=None): | |
"""Modified version of https://github.com/albanie/pytorch-mcn | |
Adjust memory layout and load weights as torch tensor | |
Args: | |
x (ndaray): a numpy array, corresponding to a set of network weights | |
stored in column major order | |
squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove | |
singletons from the trailing dimensions. So after converting to | |
pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1) | |
it will be reshaped to a matrix with shape (A,B). | |
in_features (int :: None): used to reshape weights for a linear block. | |
out_features (int :: None): used to reshape weights for a linear block. | |
Returns: | |
torch.tensor: a permuted sets of weights, matching the pytorch layout | |
convention | |
""" | |
if x.ndim == 4: | |
x = x.transpose((3, 2, 0, 1)) | |
# for FFDNet, pixel-shuffle layer | |
# if x.shape[1]==13: | |
# x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:] | |
# if x.shape[0]==12: | |
# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] | |
# if x.shape[1]==5: | |
# x=x[:,[0,2,1,3, 4],:,:] | |
# if x.shape[0]==4: | |
# x=x[[0,2,1,3],:,:,:] | |
## for SRMD, pixel-shuffle layer | |
# if x.shape[0]==12: | |
# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] | |
# if x.shape[0]==27: | |
# x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:] | |
# if x.shape[0]==48: | |
# x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:] | |
elif x.ndim == 3: # add by Kai | |
x = x[:,:,:,None] | |
x = x.transpose((3, 2, 0, 1)) | |
elif x.ndim == 2: | |
if x.shape[1] == 1: | |
x = x.flatten() | |
if squeeze: | |
if in_features and out_features: | |
x = x.reshape((out_features, in_features)) | |
x = np.squeeze(x) | |
return torch.from_numpy(np.ascontiguousarray(x)) | |
def save_model(network, save_path): | |
state_dict = network.state_dict() | |
for key, param in state_dict.items(): | |
state_dict[key] = param.cpu() | |
torch.save(state_dict, save_path) | |
if __name__ == '__main__': | |
# from utils import utils_logger | |
# import logging | |
# utils_logger.logger_info('a', 'a.log') | |
# logger = logging.getLogger('a') | |
# | |
# mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat') | |
mcn = hdf5storage.loadmat('models/modelcolor.mat') | |
#logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0]) | |
mat_net = OrderedDict() | |
for idx in range(25): | |
mat_net[str(idx)] = OrderedDict() | |
count = -1 | |
print(idx) | |
for i in range(13): | |
if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv': | |
count += 1 | |
w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0] | |
# print(w.shape) | |
w = weights2tensor(w) | |
# print(w.shape) | |
b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1] | |
b = weights2tensor(b) | |
print(b.shape) | |
mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w | |
mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b | |
torch.save(mat_net, 'model_zoo/modelcolor.pth') | |
# from models.network_dncnn import IRCNN as net | |
# network = net(in_nc=3, out_nc=3, nc=64) | |
# state_dict = network.state_dict() | |
# | |
# #show_kv(state_dict) | |
# | |
# for i in range(len(mcn['net'][0][0][0])): | |
# print(mcn['net'][0][0][0][i][0][0][0][0]) | |
# | |
# count = -1 | |
# mat_net = OrderedDict() | |
# for i in range(len(mcn['net'][0][0][0])): | |
# if mcn['net'][0][0][0][i][0][0][0][0] == 'conv': | |
# | |
# count += 1 | |
# w = mcn['net'][0][0][0][i][0][1][0][0] | |
# print(w.shape) | |
# w = weights2tensor(w) | |
# print(w.shape) | |
# | |
# b = mcn['net'][0][0][0][i][0][1][0][1] | |
# b = weights2tensor(b) | |
# print(b.shape) | |
# | |
# mat_net['model.{:d}.weight'.format(count*2)] = w | |
# mat_net['model.{:d}.bias'.format(count*2)] = b | |
# | |
# torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth') | |
# | |
# | |
# | |
# crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth') | |
# def show_kv(net): | |
# for k, v in net.items(): | |
# print(k) | |
# | |
# show_kv(crt_net) | |
# from models.network_dncnn import DnCNN as net | |
# network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R') | |
# from models.network_srmd import SRMD as net | |
# #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R') | |
# network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle') | |
# | |
# from models.network_rrdb import RRDB as net | |
# network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv') | |
# | |
# state_dict = network.state_dict() | |
# for key, param in state_dict.items(): | |
# print(key) | |
# from models.network_imdn import IMDN as net | |
# network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle') | |
# state_dict = network.state_dict() | |
# mat_net = OrderedDict() | |
# for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()): | |
# mat_net[key] = param2 | |
# torch.save(mat_net, 'model_zoo/imdn_x4_1.pth') | |
# | |
# net_old = torch.load('net_old.pth') | |
# def show_kv(net): | |
# for k, v in net.items(): | |
# print(k) | |
# | |
# show_kv(net_old) | |
# from models.network_dpsr import MSRResNet_prior as net | |
# model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle') | |
# state_dict = network.state_dict() | |
# net_new = OrderedDict() | |
# for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()): | |
# net_new[key] = param_old | |
# torch.save(net_new, 'net_new.pth') | |
# print(key) | |
# print(param.size()) | |
# run utils/utils_matconvnet.py | |