import numpy as np
import torch
import torch.nn as nn

import models.basicblock as B


'''
# ====================
# Residual U-Net
# ====================
citation:
@article{zhang2020plug,
title={Plug-and-Play Image Restoration with Deep Denoiser Prior},
author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu},
journal={arXiv preprint},
year={2020}
}
# ====================
'''


class UNetRes(nn.Module):
    """Residual U-Net: head conv, three down stages, body, three up stages, tail conv.

    Every stage is a stack of ``nb`` ``B.ResBlock`` modules; the encoder
    stages end with a downsampling block and the decoder stages start with
    an upsampling block. Encoder features are added (not concatenated) to
    the decoder input at the matching level.

    Args:
        in_nc: Number of channels passed to the head convolution.
        out_nc: Number of channels produced by the tail convolution.
        nc: Four channel widths, one per resolution level (``nc[0]`` is
            the full-resolution level). The default list is only read,
            never mutated.
        nb: Number of residual blocks in each stage.
        act_mode: Activation code inserted between the two convolutions of
            every residual block (the block mode is ``'C' + act_mode + 'C'``).
        downsample_mode: ``'avgpool'``, ``'maxpool'`` or ``'strideconv'``.
        upsample_mode: ``'upconv'``, ``'pixelshuffle'`` or ``'convtranspose'``.
        bias: Forwarded to every block built from ``models.basicblock``.

    Raises:
        NotImplementedError: If ``downsample_mode`` or ``upsample_mode`` is
            not one of the names listed above.
    """

    # Three down/up stages, each built with mode='2'. This presumes every
    # such block changes the spatial resolution by a factor of 2 (as the
    # original padding arithmetic, written against a multiple of 8, did);
    # confirm against models.basicblock if other modes are ever used.
    _SIZE_MULTIPLE = 8

    def __init__(self, in_nc=3, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode='R',
                 downsample_mode='strideconv', upsample_mode='convtranspose', bias=True):
        super(UNetRes, self).__init__()

        res_mode = 'C' + act_mode + 'C'

        def res_stack(channels):
            # A fresh list of `nb` residual blocks that keep `channels` fixed.
            return [B.ResBlock(channels, channels, bias=bias, mode=res_mode)
                    for _ in range(nb)]

        # NOTE: attribute names and creation order below are kept exactly as
        # in the original so state_dict keys and RNG-dependent parameter
        # initialisation are unchanged.
        self.m_head = B.conv(in_nc, nc[0], bias=bias, mode='C')

        # downsample
        if downsample_mode == 'avgpool':
            downsample_block = B.downsample_avgpool
        elif downsample_mode == 'maxpool':
            downsample_block = B.downsample_maxpool
        elif downsample_mode == 'strideconv':
            downsample_block = B.downsample_strideconv
        else:
            raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))

        self.m_down1 = B.sequential(*res_stack(nc[0]),
                                    downsample_block(nc[0], nc[1], bias=bias, mode='2'))
        self.m_down2 = B.sequential(*res_stack(nc[1]),
                                    downsample_block(nc[1], nc[2], bias=bias, mode='2'))
        self.m_down3 = B.sequential(*res_stack(nc[2]),
                                    downsample_block(nc[2], nc[3], bias=bias, mode='2'))

        self.m_body = B.sequential(*res_stack(nc[3]))

        # upsample
        if upsample_mode == 'upconv':
            upsample_block = B.upsample_upconv
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.upsample_pixelshuffle
        elif upsample_mode == 'convtranspose':
            upsample_block = B.upsample_convtranspose
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], bias=bias, mode='2'),
                                  *res_stack(nc[2]))
        self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], bias=bias, mode='2'),
                                  *res_stack(nc[1]))
        self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], bias=bias, mode='2'),
                                  *res_stack(nc[0]))

        self.m_tail = B.conv(nc[0], out_nc, bias=bias, mode='C')

    def forward(self, x0):
        """Run the network on ``x0`` and return a tensor of the same height/width.

        The last two dimensions of ``x0`` are treated as height and width
        (the demo below feeds a ``(1, 3, 256, 256)`` tensor). If either is
        not a multiple of 8, the input is replicate-padded on the bottom and
        right up to the next multiple so the additive skip connections line
        up, and the result is cropped back to the original size. Inputs that
        are already a multiple of 8 take exactly the original code path.

        Args:
            x0: Input tensor whose trailing dimensions are ``(H, W)``.

        Returns:
            The tail convolution's output, cropped to ``(H, W)`` when padding
            was applied.
        """
        h, w = x0.size()[-2:]
        pad_bottom = -h % self._SIZE_MULTIPLE
        pad_right = -w % self._SIZE_MULTIPLE
        needs_padding = pad_bottom > 0 or pad_right > 0

        if needs_padding:
            # Padding tuple order is (left, right, top, bottom).
            x0 = nn.ReplicationPad2d((0, pad_right, 0, pad_bottom))(x0)

        # Encoder: keep each level's output for the additive skips below.
        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)

        x = self.m_body(x4)

        # Decoder: add the matching encoder feature before each stage.
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        if needs_padding:
            # Drop the rows/columns that only exist because of the padding.
            x = x[..., :h, :w]

        return x


if __name__ == '__main__':
    x = torch.rand(1, 3, 256, 256)
    net = UNetRes()
    net.eval()
    with torch.no_grad():
        y = net(x)
    print(y.size())

# run models/network_unet.py