import torch.nn as nn
import models.basicblock as B
import torch


"""
# --------------------------------------------
# SRMD (15 conv layers)
# --------------------------------------------
Reference:
@inproceedings{zhang2018learning,
  title={Learning a single convolutional super-resolution network for multiple degradations},
  author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
  pages={3262--3271},
  year={2018}
}
http://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Learning_a_Single_CVPR_2018_paper.pdf
"""


# --------------------------------------------
# SRMD (SRMD,   in_nc = 3+15+1 = 19)
# SRMD (SRMDNF, in_nc = 3+15   = 18)
# --------------------------------------------
class SRMD(nn.Module):
    """Super-resolution network: a conv head, a stack of conv layers, and an
    upsampling tail, all built from `models.basicblock` helpers."""

    def __init__(self, in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle'):
        """
        # ------------------------------------
        in_nc: channel number of input, default: 3+15+1 = 19
        out_nc: channel number of output
        nc: channel number
        nb: total number of conv layers
        upscale: scale factor
        act_mode: batch norm + activation function; 'BR' means BN+ReLU
        upsample_mode: 'upconv', 'pixelshuffle' (default, conv + pixelshuffle)
                       or 'convtranspose'
        # ------------------------------------
        Raises:
            AssertionError: if act_mode contains neither 'R' nor 'L'.
            NotImplementedError: if upsample_mode is not one of the three
                supported names.
        """
        super(SRMD, self).__init__()
        assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
        bias = True

        # Select the builder for the final upsampling stage.
        if upsample_mode == 'upconv':
            upsample_block = B.upsample_upconv
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.upsample_pixelshuffle
        elif upsample_mode == 'convtranspose':
            upsample_block = B.upsample_convtranspose
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        # The head only takes the last character of act_mode (the activation),
        # while the body layers receive the full act_mode string.
        m_head = B.conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias)
        # nb counts head and tail as well, hence nb-2 body layers.
        m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
        m_tail = upsample_block(nc, out_nc, mode=str(upscale), bias=bias)

        self.model = B.sequential(m_head, *m_body, m_tail)

    def forward(self, x, k_pca=None):
        """Run the network on `x`.

        Args:
            x: input tensor. When `k_pca` is None it is fed to the network
                unchanged, so it must already carry `in_nc` channels.
            k_pca: optional degradation tensor. When given, it is tiled along
                its last two dimensions by the height and width of `x` and
                concatenated to `x` along dim 1 before the network runs
                (e.g. a (N, 15, 1, 1) tensor becomes (N, 15, H, W)).

        Returns:
            The output of the sequential model.
        """
        if k_pca is not None:
            # Stretch the per-sample vector into spatial maps matching x.
            m = k_pca.repeat(1, 1, x.size()[-2], x.size()[-1])
            x = torch.cat((x, m), 1)
        return self.model(x)


if __name__ == '__main__':
    from utils import utils_model
    model = SRMD(in_nc=18, out_nc=3, nc=64, nb=15, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
    print(utils_model.describe_model(model))

    # 3 image channels + 15 k_pca channels = in_nc of 18.
    x = torch.randn((2, 3, 100, 100))
    k_pca = torch.randn(2, 15, 1, 1)
    x = model(x, k_pca)
    print(x.shape)

    #  run models/network_srmd.py