import torch.nn as nn
import models.basicblock as B

"""
# --------------------------------------------
# DnCNN (20 conv layers)
# FDnCNN (20 conv layers)
# IRCNN (7 conv layers)
# --------------------------------------------
# References:
@article{zhang2017beyond,
  title={Beyond a Gaussian denoiser: Residual learning of deep CNN for image denoising},
  author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
  journal={IEEE Transactions on Image Processing},
  volume={26},
  number={7},
  pages={3142--3155},
  year={2017},
  publisher={IEEE}
}
@article{zhang2018ffdnet,
  title={FFDNet: Toward a fast and flexible solution for CNN-based image denoising},
  author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
  journal={IEEE Transactions on Image Processing},
  volume={27},
  number={9},
  pages={4608--4622},
  year={2018},
  publisher={IEEE}
}
# --------------------------------------------
"""


# --------------------------------------------
# DnCNN
# --------------------------------------------
class DnCNN(nn.Module):
    def __init__(self, in_nc=1, out_nc=1, nc=64, nb=17, act_mode='BR'):
        """
        # ------------------------------------
        in_nc: channel number of input
        out_nc: channel number of output
        nc: channel number
        nb: total number of conv layers
        act_mode: batch norm + activation function; 'BR' means BN+ReLU.
        # ------------------------------------
        Batch normalization and residual learning are beneficial to Gaussian
        denoising (especially for a single noise level). The residual of a
        noisy image corrupted by additive white Gaussian noise (AWGN) follows
        a constant Gaussian distribution, which stabilizes batch normalization
        during training.
        # ------------------------------------
        """
        super(DnCNN, self).__init__()
        assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
        bias = True

        m_head = B.conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias)
        m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
        m_tail = B.conv(nc, out_nc, mode='C', bias=bias)

        self.model = B.sequential(m_head, *m_body, m_tail)

    def forward(self, x):
        n = self.model(x)
        return x - n
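
# --------------------------------------------
# Usage sketch (illustrative addition, not part of the original file):
# DnCNN predicts the noise map n, and forward() returns x - n, so denoising a
# noisy grayscale batch is a single call. sigma = 25/255 is an assumed example
# noise level for the synthetic AWGN below, and `clean` is any (N, 1, H, W) tensor.
#
#     dncnn = DnCNN(in_nc=1, out_nc=1, nc=64, nb=17, act_mode='BR')
#     noisy = clean + (25.0 / 255.0) * torch.randn_like(clean)   # add AWGN
#     denoised = dncnn(noisy)                                    # x - predicted noise
# --------------------------------------------
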

# --------------------------------------------
# IRCNN denoiser
# --------------------------------------------
class IRCNN(nn.Module):
    def __init__(self, in_nc=1, out_nc=1, nc=64):
        """
        # ------------------------------------
        denoiser of IRCNN
        in_nc: channel number of input
        out_nc: channel number of output
        nc: channel number
        # ------------------------------------
        A 7-layer denoiser with dilated convolutions (dilation factors
        1-2-3-4-3-2-1) to enlarge the receptive field; residual learning
        is used, i.e. the network predicts the noise.
        # ------------------------------------
        """
        super(IRCNN, self).__init__()
        L = []
        L.append(nn.Conv2d(in_channels=in_nc, out_channels=nc, kernel_size=3, stride=1, padding=1, dilation=1, bias=True))
        L.append(nn.ReLU(inplace=True))
        L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=2, dilation=2, bias=True))
        L.append(nn.ReLU(inplace=True))
        L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=3, dilation=3, bias=True))
        L.append(nn.ReLU(inplace=True))
        L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=4, dilation=4, bias=True))
        L.append(nn.ReLU(inplace=True))
        L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=3, dilation=3, bias=True))
        L.append(nn.ReLU(inplace=True))
        L.append(nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=2, dilation=2, bias=True))
        L.append(nn.ReLU(inplace=True))
        L.append(nn.Conv2d(in_channels=nc, out_channels=out_nc, kernel_size=3, stride=1, padding=1, dilation=1, bias=True))
        self.model = B.sequential(*L)

    def forward(self, x):
        n = self.model(x)
        return x - n


# --------------------------------------------
# FDnCNN
# --------------------------------------------
# Compared with DnCNN, FDnCNN has three modifications:
# 1) add noise level map as input
# 2) remove residual learning and BN
# 3) train with L1 loss
# It may need more training time, but this will not reduce the final PSNR much.
# --------------------------------------------
class FDnCNN(nn.Module):
    def __init__(self, in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R'):
        """
        in_nc: channel number of input
        out_nc: channel number of output
        nc: channel number
        nb: total number of conv layers
        act_mode: batch norm + activation function; 'BR' means BN+ReLU.
        """
        super(FDnCNN, self).__init__()
        assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
        bias = True

        m_head = B.conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias)
        m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
        m_tail = B.conv(nc, out_nc, mode='C', bias=bias)

        self.model = B.sequential(m_head, *m_body, m_tail)

    def forward(self, x):
        x = self.model(x)
        return x


if __name__ == '__main__':
    from utils import utils_model
    import torch

    model1 = DnCNN(in_nc=1, out_nc=1, nc=64, nb=20, act_mode='BR')
    print(utils_model.describe_model(model1))

    model2 = FDnCNN(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
    print(utils_model.describe_model(model2))

    x = torch.randn((1, 1, 240, 240))
    x1 = model1(x)
    print(x1.shape)

    x = torch.randn((1, 2, 240, 240))
    x2 = model2(x)
    print(x2.shape)

    # run models/network_dncnn.py
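

# --------------------------------------------
# Additional sketch (not part of the original demo above): how one might run
# the IRCNN denoiser and feed FDnCNN a noisy image concatenated with its noise
# level map. sigma = 25/255 and the random inputs are assumed example values.
# --------------------------------------------
if __name__ == '__main__':
    import torch

    # IRCNN: 7-layer dilated denoiser, same residual convention as DnCNN (returns x - n)
    model3 = IRCNN(in_nc=1, out_nc=1, nc=64)
    x3 = model3(torch.randn((1, 1, 240, 240)))
    print(x3.shape)

    # FDnCNN takes the noise level map as an extra input channel (in_nc=2)
    sigma = 25.0 / 255.0
    noisy = torch.randn((1, 1, 240, 240))
    noise_level_map = torch.full((1, 1, 240, 240), sigma)
    model4 = FDnCNN(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R')
    x4 = model4(torch.cat((noisy, noise_level_map), dim=1))
    print(x4.shape)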