import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from lib.models.networks.encoder import VGGEncoder
# from util import util
from lib.models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm


def copy_state_dict(state_dict, model, strip=None, replace=None):
    tgt_state = model.state_dict()
    copied_names = set()
    for name, param in state_dict.items():
        if strip is not None and replace is None and name.startswith(strip):
            name = name[len(strip):]
        if strip is not None and replace is not None:
            name = name.replace(strip, replace)
        if name not in tgt_state:
            continue
        if isinstance(param, torch.nn.Parameter):
            param = param.data
        if param.size() != tgt_state[name].size():
            print('mismatch:', name, param.size(), tgt_state[name].size())
            continue
        tgt_state[name].copy_(param)
        copied_names.add(name)

    missing = set(tgt_state.keys()) - copied_names
    if len(missing) > 0:
        print("missing keys in state_dict:", missing)


# VGG architecture, used for the perceptual loss with a pretrained VGG network
class VGG19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(VGG19, self).__init__()
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class VGGFace19(torch.nn.Module):
    def __init__(self, opt, load_path="", requires_grad=False):
        super(VGGFace19, self).__init__()
        self.model = VGGEncoder(opt)
        self.opt = opt
        ckpt = torch.load(load_path)
        print("=> loading checkpoint '{}'".format(load_path))
        copy_state_dict(ckpt, self.model.model)
        vgg_pretrained_features = self.model.model.features
        len_features = len(self.model.model.features)
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.slice6 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        for x in range(30, len_features):
            self.slice6.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = X.view(-1, self.opt.model.output_nc, self.opt.data.img_size, self.opt.data.img_size)
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        h_relu6 = self.slice6(h_relu5)
        # h_relu6 appears twice so the list has five entries, matching the
        # interface (and per-layer loss weighting) of VGG19.forward above
        out = [h_relu3, h_relu4, h_relu5, h_relu6, h_relu6]
        return out
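
# Example (not part of the original file): a minimal sketch of how the five
# feature maps returned by VGG19 are typically combined into a perceptual
# loss. The per-layer weights below are the common pix2pixHD/SPADE choice and
# are an assumption here, not a value this repository prescribes.
class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_feats, y_feats = self.vgg(x), self.vgg(y)
        loss = 0.0
        for w, xf, yf in zip(self.weights, x_feats, y_feats):
            # detach the target features so gradients only flow through x
            loss = loss + w * self.criterion(xf, yf.detach())
        return loss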

# Returns a function that creates a normalization layer
# that does not condition on the semantic map
def get_nonspade_norm_layer(opt, norm_type='instance'):
    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'syncbatch':
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
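
# Example (not part of the original file): how the returned closure is
# typically applied when assembling a conv block. Prefixing 'spectral' to the
# norm type (e.g. 'spectralinstance') wraps the conv in spectral norm before
# attaching the normalization layer; opt is unused by the closure, so None is
# passed here. The channel sizes and input shape are illustrative assumptions.
if __name__ == '__main__':
    norm_layer = get_nonspade_norm_layer(None, norm_type='spectralinstance')
    block = nn.Sequential(
        norm_layer(nn.Conv2d(3, 64, kernel_size=3, padding=1)),
        nn.LeakyReLU(0.2),
    )
    out = block(torch.randn(1, 3, 64, 64))
    print(out.shape)  # expected: torch.Size([1, 64, 64, 64])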