Spaces:
Running
Running
| import numpy as np | |
| import torch.nn as nn | |
| import models.basicblock as B | |
| import torch | |
| """ | |
| # -------------------------------------------- | |
| # FFDNet (15 or 12 conv layers) | |
| # -------------------------------------------- | |
| Reference: | |
| @article{zhang2018ffdnet, | |
| title={FFDNet: Toward a fast and flexible solution for CNN-based image denoising}, | |
| author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, | |
| journal={IEEE Transactions on Image Processing}, | |
| volume={27}, | |
| number={9}, | |
| pages={4608--4622}, | |
| year={2018}, | |
| publisher={IEEE} | |
| } | |
| """ | |
| # -------------------------------------------- | |
| # FFDNet | |
| # -------------------------------------------- | |
class FFDNet(nn.Module):
    """FFDNet denoiser.

    The input is pixel-unshuffled by a factor ``sf``. A constant noise-level
    map is concatenated as one extra channel, and the result goes through a
    plain conv stack (head, body, tail). The output is pixel-shuffled back
    and cropped to the input size.
    """

    def __init__(self, in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R'):
        """
        # ------------------------------------
        in_nc: channel number of input
        out_nc: channel number of output
        nc: channel number of the intermediate feature maps
        nb: total number of conv layers (1 head + nb-2 body + 1 tail)
        act_mode: normalization + activation mode passed to the body convs;
            'BR' means BN+ReLU. The head conv only uses the last character
            (the activation); the tail conv has no activation.
        # ------------------------------------
        """
        super().__init__()
        assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL'
        bias = True
        sf = 2  # pixel-(un)shuffle factor
        self.m_down = B.PixelUnShuffle(upscale_factor=sf)
        # +1 input channel for the noise-level map concatenated in forward()
        m_head = B.conv(in_nc*sf*sf+1, nc, mode='C'+act_mode[-1], bias=bias)
        m_body = [B.conv(nc, nc, mode='C'+act_mode, bias=bias) for _ in range(nb-2)]
        m_tail = B.conv(nc, out_nc*sf*sf, mode='C', bias=bias)
        self.model = B.sequential(m_head, *m_body, m_tail)
        self.m_up = nn.PixelShuffle(upscale_factor=sf)

    def forward(self, x, sigma):
        """Denoise ``x`` given the noise level ``sigma``.

        x: image batch, presumably (N, in_nc, H, W); H and W may be
            arbitrary (they are padded internally and cropped back).
        sigma: noise level. It can be a Python number, or a tensor holding
            either one value per batch item (shape (N,), (N, 1, 1, 1), ...)
            or a single value shared by the whole batch. It is converted to
            x's dtype and device.

        Returns a tensor with the same batch size and spatial size as ``x``.
        """
        h, w = x.shape[-2:]
        # Keep the padding in sync with the actual shuffle factor.
        sf = self.m_up.upscale_factor
        # Replicate-pad bottom/right so H and W become multiples of sf; the
        # padding is cropped off again below.
        pad_bottom = (-h) % sf
        pad_right = (-w) % sf
        if pad_bottom or pad_right:
            x = torch.nn.functional.pad(x, (0, pad_right, 0, pad_bottom), mode='replicate')
        x = self.m_down(x)
        # One-channel constant noise map at the downsampled resolution. The
        # cast avoids dtype promotion in torch.cat, which would break the
        # convs, and the reshape/expand accepts scalar or per-sample sigma.
        sigma = torch.as_tensor(sigma, dtype=x.dtype, device=x.device)
        m = sigma.reshape(-1, 1, 1, 1).expand(x.shape[0], 1, x.shape[-2], x.shape[-1])
        x = torch.cat((x, m), 1)
        x = self.model(x)
        x = self.m_up(x)
        return x[..., :h, :w]
if __name__ == '__main__':
    # Smoke test: build the default grayscale FFDNet, print its summary, and
    # check that an odd-sized batch comes back at its original spatial size.
    # run models/network_ffdnet.py
    from utils import utils_model
    model = FFDNet(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R')
    print(utils_model.describe_model(model))
    # Odd H/W exercise the internal pad-and-crop path.
    x = torch.randn((2, 1, 241, 239))
    # sigma is a noise level and should be non-negative (torch.randn would
    # give negative values); 15/255 is just an illustrative level.
    sigma = torch.full((2, 1, 1, 1), 15 / 255.)
    with torch.no_grad():
        y = model(x, sigma)
    print(y.shape)