# Source: LambdaSuperRes / KAIR / models/network_ffdnet.py
# Commit 2514fb4 by cooperll: "LambdaSuperRes initial commit"
import numpy as np
import torch.nn as nn
import models.basicblock as B
import torch
"""
# --------------------------------------------
# FFDNet (15 or 12 conv layers)
# --------------------------------------------
Reference:
@article{zhang2018ffdnet,
title={FFDNet: Toward a fast and flexible solution for CNN-based image denoising},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
journal={IEEE Transactions on Image Processing},
volume={27},
number={9},
pages={4608--4622},
year={2018},
publisher={IEEE}
}
"""
# --------------------------------------------
# FFDNet
# --------------------------------------------
class FFDNet(nn.Module):
    """FFDNet denoiser (Zhang et al., IEEE TIP 2018).

    Pipeline: the input is pixel-unshuffled by a factor of 2, concatenated
    with a one-channel noise-level map built from ``sigma``, passed through
    ``nb`` conv layers, and pixel-shuffled back to the input resolution.

    Args:
        in_nc: channel number of the input image.
        out_nc: channel number of the output image.
        nc: channel number of the intermediate feature maps.
        nb: total number of conv layers (head + ``nb - 2`` body + tail).
        act_mode: normalization + activation code passed to ``B.conv``;
            'BR' means BN+ReLU. Must contain 'R' or 'L'. Only its last
            character (the activation) is used for the head layer.
    """

    def __init__(self, in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R'):
        super(FFDNet, self).__init__()
        assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
        bias = True
        sf = 2
        # Remember the unshuffle factor so forward() pads to a multiple of the
        # same value instead of a separately hard-coded constant.
        self.sf = sf
        self.m_down = B.PixelUnShuffle(upscale_factor=sf)
        # The extra input channel (+1) is the noise-level map concatenated in forward().
        m_head = B.conv(in_nc*sf*sf+1, nc, mode='C'+act_mode[-1], bias=bias)
        m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
        m_tail = B.conv(nc, out_nc*sf*sf, mode='C', bias=bias)
        self.model = B.sequential(m_head, *m_body, m_tail)
        self.m_up = nn.PixelShuffle(upscale_factor=sf)

    @staticmethod
    def _noise_map(sigma, x):
        """Broadcast ``sigma`` to a (N, 1, H', W') map matching ``x``'s batch, size, dtype and device.

        ``sigma`` may be a Python number, a 0-d tensor, an (N,) or (N, 1)
        tensor, or an (N, 1, 1, 1) tensor. A batch dimension of 1 is
        broadcast to ``x``'s batch size.
        """
        n, _, h, w = x.shape
        m = torch.as_tensor(sigma, dtype=x.dtype, device=x.device)
        if m.dim() == 0:
            m = m.reshape(1, 1, 1, 1)
        elif m.dim() == 1:
            m = m.reshape(-1, 1, 1, 1)
        elif m.dim() == 2:
            m = m[:, :, None, None]
        # expand() is a view; torch.cat in forward() materializes it.
        return m.expand(n, -1, h, w)

    def forward(self, x, sigma):
        """Denoise ``x`` conditioned on the noise level(s) ``sigma``.

        Args:
            x: image batch of shape (N, in_nc, H, W). H and W need not be
                multiples of the unshuffle factor. The input is
                replicate-padded on the bottom/right, and the output is
                cropped back to (H, W).
            sigma: noise level(s); see ``_noise_map`` for accepted forms.
                The original (N, 1, 1, 1) tensor form is still supported.

        Returns:
            Tensor of shape (N, out_nc, H, W).
        """
        h, w = x.shape[-2:]
        sf = self.sf
        # Smallest bottom/right padding that makes H and W divisible by sf.
        pad_bottom = (-h) % sf
        pad_right = (-w) % sf
        if pad_bottom or pad_right:
            x = torch.nn.functional.pad(x, (0, pad_right, 0, pad_bottom), mode='replicate')
        x = self.m_down(x)
        x = torch.cat((x, self._noise_map(sigma, x)), 1)
        x = self.model(x)
        x = self.m_up(x)
        return x[..., :h, :w]
if __name__ == '__main__':
    # Smoke test: build a grayscale FFDNet, print its summary, and run one
    # forward pass on random data.
    # NOTE(review): presumably run from the repository root, e.g.
    #     python -m models.network_ffdnet
    # so that the `models` and `utils` packages are importable — confirm.
    from utils import utils_model
    model = FFDNet(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R')
    print(utils_model.describe_model(model))
    x = torch.randn((2, 1, 240, 240))
    # sigma is a noise level, so draw non-negative values (torch.randn
    # could produce negative ones).
    sigma = torch.rand(2, 1, 1, 1)
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        x = model(x, sigma)
    print(x.shape)