import torch.nn as nn
import models.basicblock as B
"""
# --------------------------------------------
# DnCNN (17 conv layers by default; the blind variant in the paper uses 20)
# FDnCNN (20 conv layers)
# IRCNN (7 conv layers)
# --------------------------------------------
# References:
@article{zhang2017beyond,
title={Beyond a {G}aussian denoiser: Residual learning of deep {CNN} for image denoising},
author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
journal={IEEE Transactions on Image Processing},
volume={26},
number={7},
pages={3142--3155},
year={2017},
publisher={IEEE}
}
@article{zhang2018ffdnet,
title={FFDNet: Toward a fast and flexible solution for CNN-based image denoising},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
journal={IEEE Transactions on Image Processing},
volume={27},
number={9},
pages={4608--4622},
year={2018},
publisher={IEEE}
}
# --------------------------------------------
"""
# --------------------------------------------
# DnCNN
# --------------------------------------------
class DnCNN(nn.Module):
def __init__(self, in_nc=1, out_nc=1, nc=64, nb=17, act_mode='BR'):
"""
# ------------------------------------
in_nc: channel number of input
out_nc: channel number of output
nc: channel number
nb: total number of conv layers
act_mode: batch norm + activation function; 'BR' means BN+ReLU.
# ------------------------------------
        Batch normalization and residual learning are
        beneficial to Gaussian denoising (especially
        for a single noise level). The residual of a
        noisy image corrupted by additive white
        Gaussian noise (AWGN) follows a Gaussian
        distribution with constant statistics, which
        stabilizes batch normalization during training.
# ------------------------------------
"""
super(DnCNN, self).__init__()
assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
bias = True
m_head = B.conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias)
m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
m_tail = B.conv(nc, out_nc, mode='C', bias=bias)
self.model = B.sequential(m_head, *m_body, m_tail)
    def forward(self, x):
        n = self.model(x)  # predicted noise (residual)
        return x - n
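
# --------------------------------------------
# Usage sketch (illustrative only; the checkpoint path,
# noise level, and the `clean` tensor below are
# assumptions, not part of this repository):
#
#   model = DnCNN(in_nc=1, out_nc=1, nc=64, nb=17, act_mode='BR')
#   model.load_state_dict(torch.load('dncnn_sigma25.pth'), strict=True)
#   model.eval()
#   with torch.no_grad():
#       noisy = clean + (25. / 255.) * torch.randn_like(clean)  # AWGN, sigma = 25
#       denoised = model(noisy)  # forward() subtracts the predicted noise
# --------------------------------------------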
# --------------------------------------------
# IRCNN denoiser
# --------------------------------------------
class IRCNN(nn.Module):
def __init__(self, in_nc=1, out_nc=1, nc=64):
"""
# ------------------------------------
denoiser of IRCNN
in_nc: channel number of input
out_nc: channel number of output
nc: channel number
nb: total number of conv layers
act_mode: batch norm + activation function; 'BR' means BN+ReLU.
# ------------------------------------
Batch normalization and residual learning are
beneficial to Gaussian denoising (especially
for a single noise level).
The residual of a noisy image corrupted by additive white
Gaussian noise (AWGN) follows a constant
Gaussian distribution which stablizes batch
normalization during training.
# ------------------------------------
"""
super(IRCNN, self).__init__()
        L = []
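        # 7 dilated 3x3 convs with dilation pattern 1-2-3-4-3-2-1;
        # padding matches dilation, so spatial size is preserved and the
        # effective receptive field is 33x33 (as in the IRCNN paper).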
L.append(nn.Conv2d(in_channels=in_nc, out_channels=nc, kernel_size=3, stride=1, padding=1, dilation=1, bias=True))
L.append(nn.ReLU(inplace=True))
L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=2, dilation=2, bias=True))
L.append(nn.ReLU(inplace=True))
L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=3, dilation=3, bias=True))
L.append(nn.ReLU(inplace=True))
L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=4, dilation=4, bias=True))
L.append(nn.ReLU(inplace=True))
L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=3, dilation=3, bias=True))
L.append(nn.ReLU(inplace=True))
L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=2, dilation=2, bias=True))
L.append(nn.ReLU(inplace=True))
L.append(nn.Conv2d(in_channels=nc, out_channels=out_nc, kernel_size=3, stride=1, padding=1, dilation=1, bias=True))
self.model = B.sequential(*L)
    def forward(self, x):
        n = self.model(x)  # predicted noise (residual)
        return x - n
# --------------------------------------------
# FDnCNN
# --------------------------------------------
# Compared with DnCNN, FDnCNN has three modifications:
# 1) add a noise level map as an extra input channel
# 2) remove residual learning and batch normalization
# 3) train with L1 loss
# It may need more training time, but this barely reduces the final
# PSNR (an input-construction sketch follows the class below).
# --------------------------------------------
class FDnCNN(nn.Module):
def __init__(self, in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R'):
"""
in_nc: channel number of input
out_nc: channel number of output
nc: channel number
nb: total number of conv layers
act_mode: batch norm + activation function; 'BR' means BN+ReLU.
"""
super(FDnCNN, self).__init__()
assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
bias = True
m_head = B.conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias)
m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
m_tail = B.conv(nc, out_nc, mode='C', bias=bias)
self.model = B.sequential(m_head, *m_body, m_tail)
def forward(self, x):
x = self.model(x)
return x
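
# --------------------------------------------
# Input-construction sketch for FDnCNN (illustrative;
# `noisy` is an assumed (N, 1, H, W) tensor, not defined
# in this file):
#
#   model = FDnCNN(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
#   sigma = 25. / 255.                          # noise level in [0, 1]
#   sigma_map = sigma * torch.ones_like(noisy)  # constant noise level map
#   x = torch.cat((noisy, sigma_map), dim=1)    # in_nc = 1 + 1 = 2 channels
#   denoised = model(x)
# --------------------------------------------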
if __name__ == '__main__':
from utils import utils_model
import torch
model1 = DnCNN(in_nc=1, out_nc=1, nc=64, nb=20, act_mode='BR')
print(utils_model.describe_model(model1))
model2 = FDnCNN(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
print(utils_model.describe_model(model2))
x = torch.randn((1, 1, 240, 240))
x1 = model1(x)
print(x1.shape)
x = torch.randn((1, 2, 240, 240))
x2 = model2(x)
print(x2.shape)
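    # IRCNN: same shape-preserving check as above
    model3 = IRCNN(in_nc=1, out_nc=1, nc=64)
    print(utils_model.describe_model(model3))
    x = torch.randn((1, 1, 240, 240))
    x3 = model3(x)
    print(x3.shape)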
# run models/network_dncnn.py