import torch.nn as nn
import models.basicblock as B
import torch

"""
# --------------------------------------------
# SRMD (15 conv layers)
# --------------------------------------------
Reference:
@inproceedings{zhang2018learning,
  title={Learning a single convolutional super-resolution network for multiple degradations},
  author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
  pages={3262--3271},
  year={2018}
}
http://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Learning_a_Single_CVPR_2018_paper.pdf
"""


# --------------------------------------------
# SRMD   (SRMD,   in_nc = 3+15+1 = 19)
# SRMD   (SRMDNF, in_nc = 3+15   = 18)
# --------------------------------------------
class SRMD(nn.Module):
    def __init__(self, in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle'):
        """
        # ------------------------------------
        in_nc: channel number of input = image channels + degradation-map
               channels; default 19 = 3+15+1 (SRMD), use 18 = 3+15 for
               SRMDNF (the extra 1 is presumably a noise-level map)
        out_nc: channel number of output
        nc: channel number of the intermediate feature maps
        nb: conv layer budget: one head conv, nb-2 body convs, then the
            upsampling tail built by `upsample_mode`
        upscale: scale factor
        act_mode: batch norm + activation function; 'BR' means BN+ReLU.
                  The head conv only uses the last character of act_mode.
        upsample_mode: 'upconv', 'pixelshuffle' (default, conv +
                       pixelshuffle) or 'convtranspose'
        # ------------------------------------
        """
        super(SRMD, self).__init__()
        assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
        bias = True

        if upsample_mode == 'upconv':
            upsample_block = B.upsample_upconv
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.upsample_pixelshuffle
        elif upsample_mode == 'convtranspose':
            upsample_block = B.upsample_convtranspose
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        # Head: conv + activation only (no normalization on the raw input).
        m_head = B.conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias)
        m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
        m_tail = upsample_block(nc, out_nc, mode=str(upscale), bias=bias)

        self.model = B.sequential(m_head, *m_body, m_tail)

    def forward(self, x, k_pca=None):
        """
        x: input tensor, assumed (B, C, H, W).
        k_pca: optional degradation vector per image, shaped (B, K),
               (B, K, 1, 1) or (B, K, H, W). When given, it is broadcast to
               the spatial size of x and concatenated after x along the
               channel axis, so C + K must equal in_nc.
               When None (default), x must already contain all in_nc
               channels (image + degradation maps).
        return: output of self.model
        """
        if k_pca is not None:
            if k_pca.dim() == 2:
                # (B, K) -> (B, K, 1, 1) so it can be broadcast spatially.
                k_pca = k_pca[:, :, None, None]
            # expand() stretches singleton spatial dims without copying;
            # torch.cat materializes the concatenated tensor.
            m = k_pca.expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat((x, m), 1)

        return self.model(x)


if __name__ == '__main__':
    from utils import utils_model
    # SRMDNF configuration: 3 image channels + 15 degradation channels = 18.
    model = SRMD(in_nc=18, out_nc=3, nc=64, nb=15, upscale=4, act_mode='R', upsample_mode='pixelshuffle')
    print(utils_model.describe_model(model))

    x = torch.randn((2, 3, 100, 100))
    k_pca = torch.randn(2, 15, 1, 1)
    # Stretch the per-image degradation vector into spatial maps and stack
    # them onto the image so the network receives all in_nc=18 channels.
    m = k_pca.repeat(1, 1, x.size(-2), x.size(-1))
    x = model(torch.cat((x, m), 1))
    print(x.shape)

    #  run models/network_srmd.py