Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import models.basicblock as B | |
import numpy as np | |
''' | |
# ==================== | |
# Residual U-Net | |
# ==================== | |
citation: | |
@article{zhang2020plug, | |
title={Plug-and-Play Image Restoration with Deep Denoiser Prior}, | |
author={Zhang, Kai and Li, Yawei and Zuo, Wangmeng and Zhang, Lei and Van Gool, Luc and Timofte, Radu}, | |
journal={arXiv preprint}, | |
year={2020} | |
} | |
# ==================== | |
''' | |
class UNetRes(nn.Module):
    """Residual U-Net from Zhang et al., "Plug-and-Play Image Restoration
    with Deep Denoiser Prior".

    Layout: head conv -> three downsampling stages -> residual body ->
    three upsampling stages -> tail conv. Skip connections are *additive*
    (``x + skip``), so each decoder input must match the corresponding
    encoder output in channels and spatial size.

    Args:
        in_nc: number of input channels.
        out_nc: number of output channels.
        nc: channel widths of the four levels; only ``nc[0]``..``nc[3]``
            are used. Any indexable sequence (list or tuple) works.
        nb: number of residual blocks per stage.
        act_mode: activation code spliced into every ResBlock mode string
            as ``'C' + act_mode + 'C'`` (interpreted by models.basicblock).
        downsample_mode: one of 'avgpool', 'maxpool', 'strideconv'.
        upsample_mode: one of 'upconv', 'pixelshuffle', 'convtranspose'.
        bias: whether the convolutions use a bias term.

    Raises:
        NotImplementedError: if ``downsample_mode`` or ``upsample_mode`` is
            not a supported name (downsample is checked first).
    """

    def __init__(self, in_nc=3, out_nc=3, nc=(64, 128, 256, 512), nb=4, act_mode='R',
                 downsample_mode='strideconv', upsample_mode='convtranspose', bias=True):
        super().__init__()

        # Validate both modes up front so a bad argument fails before any
        # layer is constructed.
        downsample_blocks = {
            'avgpool': B.downsample_avgpool,
            'maxpool': B.downsample_maxpool,
            'strideconv': B.downsample_strideconv,
        }
        upsample_blocks = {
            'upconv': B.upsample_upconv,
            'pixelshuffle': B.upsample_pixelshuffle,
            'convtranspose': B.upsample_convtranspose,
        }
        if downsample_mode not in downsample_blocks:
            raise NotImplementedError('downsample mode [{}] is not found'.format(downsample_mode))
        if upsample_mode not in upsample_blocks:
            raise NotImplementedError('upsample mode [{}] is not found'.format(upsample_mode))
        downsample_block = downsample_blocks[downsample_mode]
        upsample_block = upsample_blocks[upsample_mode]

        res_mode = 'C' + act_mode + 'C'

        def res_blocks(channels):
            # nb channel-preserving residual blocks, built in order.
            return [B.ResBlock(channels, channels, bias=bias, mode=res_mode) for _ in range(nb)]

        # Construction order and attribute names match the reference
        # implementation so pretrained state_dicts load unchanged.
        self.m_head = B.conv(in_nc, nc[0], bias=bias, mode='C')

        # Encoder: residual blocks, then a resampling step to the next width.
        self.m_down1 = B.sequential(*res_blocks(nc[0]), downsample_block(nc[0], nc[1], bias=bias, mode='2'))
        self.m_down2 = B.sequential(*res_blocks(nc[1]), downsample_block(nc[1], nc[2], bias=bias, mode='2'))
        self.m_down3 = B.sequential(*res_blocks(nc[2]), downsample_block(nc[2], nc[3], bias=bias, mode='2'))

        self.m_body = B.sequential(*res_blocks(nc[3]))

        # Decoder: a resampling step back to the previous width, then residual blocks.
        self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], bias=bias, mode='2'), *res_blocks(nc[2]))
        self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], bias=bias, mode='2'), *res_blocks(nc[1]))
        self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], bias=bias, mode='2'), *res_blocks(nc[0]))

        self.m_tail = B.conv(nc[0], out_nc, bias=bias, mode='C')

    def forward(self, x0):
        """Run the residual U-Net on ``x0``.

        The three ``mode='2'`` resampling stages need H and W to be
        multiples of 8 for the additive skips to line up. Unaligned inputs
        are replicate-padded on the bottom/right before the network, and the
        output is cropped back to the input's H, W. Aligned inputs take
        exactly the original path, with no padding and no crop.

        Args:
            x0: input tensor; assumed (N, in_nc, H, W) — TODO confirm.

        Returns:
            Output of ``m_tail``, cropped to the input's H, W if padding was
            applied.
        """
        multiple = 8  # 2**3: one factor of 2 per down/upsampling stage
        h = x0.shape[-2]
        w = x0.shape[-1]
        pad_bottom = (multiple - h % multiple) % multiple
        pad_right = (multiple - w % multiple) % multiple
        padded = pad_bottom > 0 or pad_right > 0
        if padded:
            x0 = nn.functional.pad(x0, [0, pad_right, 0, pad_bottom], mode='replicate')

        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        # Additive skip connections from the matching encoder level.
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        if padded:
            # Remove the rows/columns introduced by the padding above.
            x = x[..., :h, :w]
        return x
if __name__ == '__main__':
    # Smoke test: push one random 1x3x256x256 tensor through a default
    # UNetRes in inference mode and print the size of the result.
    # Usage: run models/network_unet.py
    sample = torch.rand(1, 3, 256, 256)
    model = UNetRes().eval()
    with torch.no_grad():
        output = model(sample)
    print(output.shape)