import functools

import torch
from torch.nn import init


# --------------------------------------------
# select the network of G, D and F
# --------------------------------------------


# --------------------------------------------
# Generator, netG, G
# --------------------------------------------
def define_G(opt):
    """Build the generator selected by ``opt['netG']['net_type']``.

    The implementing module is imported lazily inside the matching branch,
    so only the chosen architecture's module is loaded.

    Args:
        opt: options mapping. ``opt['netG']`` holds ``net_type`` plus the
            constructor arguments read by the selected branch; when
            ``opt['is_train']`` is truthy, ``init_type``, ``init_bn_type``
            and ``init_gain`` are read from ``opt['netG']`` as well.

    Returns:
        The constructed network, passed through ``init_weights`` when
        ``opt['is_train']`` is truthy.

    Raises:
        NotImplementedError: if ``net_type`` matches none of the branches.
        KeyError: if a key required by the selected branch is missing.
    """
    opt_net = opt['netG']
    net_type = opt_net['net_type']

    # ----------------------------------------
    # denoising task
    # ----------------------------------------

    # ----------------------------------------
    # DnCNN
    # ----------------------------------------
    if net_type == 'dncnn':
        from models.network_dncnn import DnCNN as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],  # total number of conv layers
                   act_mode=opt_net['act_mode'])

    # ----------------------------------------
    # Flexible DnCNN
    # ----------------------------------------
    elif net_type == 'fdncnn':
        from models.network_dncnn import FDnCNN as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],  # total number of conv layers
                   act_mode=opt_net['act_mode'])

    # ----------------------------------------
    # FFDNet
    # ----------------------------------------
    elif net_type == 'ffdnet':
        from models.network_ffdnet import FFDNet as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   act_mode=opt_net['act_mode'])

    # ----------------------------------------
    # others
    # ----------------------------------------

    # ----------------------------------------
    # super-resolution task
    # ----------------------------------------

    # ----------------------------------------
    # SRMD
    # ----------------------------------------
    elif net_type == 'srmd':
        from models.network_srmd import SRMD as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # super-resolver prior of DPSR
    # ----------------------------------------
    elif net_type == 'dpsr':
        from models.network_dpsr import MSRResNet_prior as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # modified SRResNet v0.0
    # ----------------------------------------
    elif net_type == 'msrresnet0':
        from models.network_msrresnet import MSRResNet0 as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # modified SRResNet v0.1
    # ----------------------------------------
    elif net_type == 'msrresnet1':
        from models.network_msrresnet import MSRResNet1 as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # RRDB
    # ----------------------------------------
    elif net_type == 'rrdb':
        from models.network_rrdb import RRDB as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   gc=opt_net['gc'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # RRDBNet
    # ----------------------------------------
    elif net_type == 'rrdbnet':
        from models.network_rrdbnet import RRDBNet as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nf=opt_net['nf'],
                   nb=opt_net['nb'],
                   gc=opt_net['gc'],
                   sf=opt_net['scale'])

    # ----------------------------------------
    # IMDN
    # ----------------------------------------
    elif net_type == 'imdn':
        from models.network_imdn import IMDN as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   upscale=opt_net['scale'],
                   act_mode=opt_net['act_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # USRNet
    # ----------------------------------------
    elif net_type == 'usrnet':
        from models.network_usrnet import USRNet as net
        netG = net(n_iter=opt_net['n_iter'],
                   h_nc=opt_net['h_nc'],
                   in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   act_mode=opt_net['act_mode'],
                   downsample_mode=opt_net['downsample_mode'],
                   upsample_mode=opt_net['upsample_mode'])

    # ----------------------------------------
    # Deep Residual U-Net (drunet)
    # ----------------------------------------
    elif net_type == 'drunet':
        from models.network_unet import UNetRes as net
        netG = net(in_nc=opt_net['in_nc'],
                   out_nc=opt_net['out_nc'],
                   nc=opt_net['nc'],
                   nb=opt_net['nb'],
                   act_mode=opt_net['act_mode'],
                   downsample_mode=opt_net['downsample_mode'],
                   upsample_mode=opt_net['upsample_mode'],
                   bias=opt_net['bias'])

    # ----------------------------------------
    # SwinIR
    # ----------------------------------------
    elif net_type == 'swinir':
        from models.network_swinir import SwinIR as net
        netG = net(upscale=opt_net['upscale'],
                   in_chans=opt_net['in_chans'],
                   img_size=opt_net['img_size'],
                   window_size=opt_net['window_size'],
                   img_range=opt_net['img_range'],
                   depths=opt_net['depths'],
                   embed_dim=opt_net['embed_dim'],
                   num_heads=opt_net['num_heads'],
                   mlp_ratio=opt_net['mlp_ratio'],
                   upsampler=opt_net['upsampler'],
                   resi_connection=opt_net['resi_connection'])

    # ----------------------------------------
    # VRT
    # ----------------------------------------
    elif net_type == 'vrt':
        from models.network_vrt import VRT as net
        netG = net(upscale=opt_net['upscale'],
                   img_size=opt_net['img_size'],
                   window_size=opt_net['window_size'],
                   depths=opt_net['depths'],
                   indep_reconsts=opt_net['indep_reconsts'],
                   embed_dims=opt_net['embed_dims'],
                   num_heads=opt_net['num_heads'],
                   spynet_path=opt_net['spynet_path'],
                   pa_frames=opt_net['pa_frames'],
                   deformable_groups=opt_net['deformable_groups'],
                   nonblind_denoising=opt_net['nonblind_denoising'],
                   use_checkpoint_attn=opt_net['use_checkpoint_attn'],
                   use_checkpoint_ffn=opt_net['use_checkpoint_ffn'],
                   no_checkpoint_attn_blocks=opt_net['no_checkpoint_attn_blocks'],
                   no_checkpoint_ffn_blocks=opt_net['no_checkpoint_ffn_blocks'])

    # ----------------------------------------
    # others
    # ----------------------------------------
    # TODO

    else:
        # '{}' instead of '{:s}': a non-string net_type (e.g. None) must
        # still surface as NotImplementedError, not as a formatting error.
        raise NotImplementedError('netG [{}] is not found.'.format(net_type))

    # ----------------------------------------
    # initialize weights
    # ----------------------------------------
    if opt['is_train']:
        init_weights(netG,
                     init_type=opt_net['init_type'],
                     init_bn_type=opt_net['init_bn_type'],
                     gain=opt_net['init_gain'])

    return netG


# --------------------------------------------
# Discriminator, netD, D
# --------------------------------------------
def define_D(opt):
    """Build the discriminator selected by ``opt['netD']['net_type']``.

    Args:
        opt: options mapping. ``opt['netD']`` holds ``net_type``, the
            constructor arguments read by the selected branch, and
            ``init_type`` / ``init_bn_type`` / ``init_gain``.

    Returns:
        The constructed discriminator after ``init_weights`` has been
        applied (unconditionally, unlike ``define_G``).

    Raises:
        NotImplementedError: if ``net_type`` matches none of the branches.
        KeyError: if a key required by the selected branch is missing.
    """
    opt_net = opt['netD']
    net_type = opt_net['net_type']

    # ----------------------------------------
    # discriminator_vgg_96
    # ----------------------------------------
    if net_type == 'discriminator_vgg_96':
        from models.network_discriminator import Discriminator_VGG_96 as discriminator
        netD = discriminator(in_nc=opt_net['in_nc'],
                             base_nc=opt_net['base_nc'],
                             ac_type=opt_net['act_mode'])

    # ----------------------------------------
    # discriminator_vgg_128
    # ----------------------------------------
    elif net_type == 'discriminator_vgg_128':
        from models.network_discriminator import Discriminator_VGG_128 as discriminator
        netD = discriminator(in_nc=opt_net['in_nc'],
                             base_nc=opt_net['base_nc'],
                             ac_type=opt_net['act_mode'])

    # ----------------------------------------
    # discriminator_vgg_192
    # ----------------------------------------
    elif net_type == 'discriminator_vgg_192':
        from models.network_discriminator import Discriminator_VGG_192 as discriminator
        netD = discriminator(in_nc=opt_net['in_nc'],
                             base_nc=opt_net['base_nc'],
                             ac_type=opt_net['act_mode'])

    # ----------------------------------------
    # discriminator_vgg_128_SN
    # ----------------------------------------
    elif net_type == 'discriminator_vgg_128_SN':
        from models.network_discriminator import Discriminator_VGG_128_SN as discriminator
        netD = discriminator()

    # ----------------------------------------
    # discriminator_patchgan
    # ----------------------------------------
    elif net_type == 'discriminator_patchgan':
        from models.network_discriminator import Discriminator_PatchGAN as discriminator
        netD = discriminator(input_nc=opt_net['in_nc'],
                             ndf=opt_net['base_nc'],
                             n_layers=opt_net['n_layers'],
                             norm_type=opt_net['norm_type'])

    # ----------------------------------------
    # discriminator_unet
    # ----------------------------------------
    elif net_type == 'discriminator_unet':
        from models.network_discriminator import Discriminator_UNet as discriminator
        netD = discriminator(input_nc=opt_net['in_nc'],
                             ndf=opt_net['base_nc'])

    else:
        # '{}' instead of '{:s}' so a non-string net_type cannot mask this
        # error with a formatting exception.
        raise NotImplementedError('netD [{}] is not found.'.format(net_type))

    # ----------------------------------------
    # initialize weights
    # ----------------------------------------
    init_weights(netD,
                 init_type=opt_net['init_type'],
                 init_bn_type=opt_net['init_bn_type'],
                 gain=opt_net['init_gain'])

    return netD


# --------------------------------------------
# VGGfeature, netF, F
# --------------------------------------------
def define_F(opt, use_bn=False):
    """Build the VGG feature extractor and put it in eval mode.

    Args:
        opt: options mapping; a truthy ``opt['gpu_ids']`` selects the
            ``'cuda'`` device, otherwise ``'cpu'`` is used.
        use_bn: forwarded to ``VGGFeatureExtractor``; also selects the
            feature layer index (49 when true, 34 otherwise).

    Returns:
        The ``VGGFeatureExtractor`` instance after ``eval()``.
    """
    device = torch.device('cuda' if opt['gpu_ids'] else 'cpu')
    from models.network_feature import VGGFeatureExtractor

    # pytorch pretrained VGG19-54, before ReLU.
    feature_layer = 49 if use_bn else 34
    netF = VGGFeatureExtractor(feature_layer=feature_layer,
                               use_bn=use_bn,
                               use_input_norm=True,
                               device=device)
    netF.eval()  # No need to train, but need BP to input
    return netF


# --------------------------------------------
# weights initialization
# --------------------------------------------
def init_weights(net, init_type='xavier_uniform', init_bn_type='uniform', gain=1):
    """Initialize conv / linear / BatchNorm2d parameters of ``net`` in place.

    Kai Zhang, https://github.com/cszn/KAIR

    Modules are selected by class name: names containing ``'Conv'`` or
    ``'Linear'`` get ``init_type``; names containing ``'BatchNorm2d'`` get
    ``init_bn_type``. A matching module that lacks the expected
    ``weight`` / ``affine`` attribute (for example a container that is
    merely named like a layer) is skipped; its children are still visited
    by ``net.apply``.

    Args:
        net: module to initialize; must provide ``apply``.
        init_type:
            default, none: skip initialization entirely;
            normal; uniform; xavier_normal; xavier_uniform;
            kaiming_normal; kaiming_uniform; orthogonal
        init_bn_type:
            uniform; constant
        gain:
            scaling factor, e.g. 0.2

    Raises:
        NotImplementedError: when a matching module is visited and
            ``init_type`` / ``init_bn_type`` is not one of the values above.
    """

    def init_fn(m, init_type='xavier_uniform', init_bn_type='uniform', gain=1):
        classname = m.__class__.__name__

        if 'Conv' in classname or 'Linear' in classname:
            weight = getattr(m, 'weight', None)
            if weight is None:
                # Name matched but there is nothing to initialize here.
                return

            if init_type == 'normal':
                init.normal_(weight.data, 0, 0.1)
                weight.data.clamp_(-1, 1).mul_(gain)
            elif init_type == 'uniform':
                init.uniform_(weight.data, -0.2, 0.2)
                weight.data.mul_(gain)
            elif init_type == 'xavier_normal':
                init.xavier_normal_(weight.data, gain=gain)
                weight.data.clamp_(-1, 1)
            elif init_type == 'xavier_uniform':
                init.xavier_uniform_(weight.data, gain=gain)
            elif init_type == 'kaiming_normal':
                init.kaiming_normal_(weight.data, a=0, mode='fan_in', nonlinearity='relu')
                weight.data.clamp_(-1, 1).mul_(gain)
            elif init_type == 'kaiming_uniform':
                init.kaiming_uniform_(weight.data, a=0, mode='fan_in', nonlinearity='relu')
                weight.data.mul_(gain)
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    'Initialization method [{}] is not implemented'.format(init_type))

            bias = getattr(m, 'bias', None)
            if bias is not None:
                bias.data.zero_()

        elif 'BatchNorm2d' in classname:
            # Only affine batch-norm layers carry weight/bias to initialize.
            affine = getattr(m, 'affine', False)

            if init_bn_type == 'uniform':  # preferred
                if affine:
                    init.uniform_(m.weight.data, 0.1, 1.0)
                    init.constant_(m.bias.data, 0.0)
            elif init_bn_type == 'constant':
                if affine:
                    init.constant_(m.weight.data, 1.0)
                    init.constant_(m.bias.data, 0.0)
            else:
                raise NotImplementedError(
                    'Initialization method [{}] is not implemented'.format(init_bn_type))

    if init_type not in ['default', 'none']:
        print('Initialization method [{} + {}], gain is [{:.2f}]'.format(init_type, init_bn_type, gain))
        fn = functools.partial(init_fn, init_type=init_type, init_bn_type=init_bn_type, gain=gain)
        net.apply(fn)
    else:
        print('Pass this initialization! Initialization was done during network definition!')