import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from lib.models.networks.encoder import VGGEncoder
# from util import util
from lib.models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm
def copy_state_dict(state_dict, model, strip=None, replace=None):
    """Copy matching parameters from state_dict into model, optionally
    stripping or rewriting a key prefix (e.g. the "module." prefix added
    by nn.DataParallel). Size mismatches and missing keys are reported
    rather than raised."""
    tgt_state = model.state_dict()
    copied_names = set()
    for name, param in state_dict.items():
        if strip is not None and replace is None and name.startswith(strip):
            name = name[len(strip):]
        if strip is not None and replace is not None:
            name = name.replace(strip, replace)
        if name not in tgt_state:
            continue
        if isinstance(param, torch.nn.Parameter):
            # Parameters are wrappers; copy the underlying tensor data.
            param = param.data
        if param.size() != tgt_state[name].size():
            print('mismatch:', name, param.size(), tgt_state[name].size())
            continue
        tgt_state[name].copy_(param)
        copied_names.add(name)
    missing = set(tgt_state.keys()) - copied_names
    if len(missing) > 0:
        print("missing keys in state_dict:", missing)
# VGG architecture, used for the perceptual loss with a pretrained VGG network
class VGG19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(VGG19, self).__init__()
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        # Split the feature extractor into five slices, ending at relu1_1,
        # relu2_1, relu3_1, relu4_1 and relu5_1 respectively.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            # Freeze all weights; the network is a fixed feature extractor.
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
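
# A minimal sketch of the usual perceptual-loss pattern built on VGG19:
# compare the five ReLU activations of a generated image against those of
# the target. The per-layer weights below are the conventional choice in
# pix2pixHD-style losses, not something defined in this file.
def _example_vgg19_perceptual_loss(fake, real):
    vgg = VGG19().eval()
    weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]  # assumed weighting
    loss = 0.0
    for w, f, r in zip(weights, vgg(fake), vgg(real)):
        loss = loss + w * F.l1_loss(f, r.detach())
    return loss
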
class VGGFace19(torch.nn.Module):
    def __init__(self, opt, load_path="", requires_grad=False):
        super(VGGFace19, self).__init__()
        self.model = VGGEncoder(opt)
        self.opt = opt
        print("=> loading checkpoint '{}'".format(load_path))
        ckpt = torch.load(load_path)
        copy_state_dict(ckpt, self.model.model)
        vgg_pretrained_features = self.model.model.features
        len_features = len(self.model.model.features)
        # Same five slices as VGG19, plus a sixth slice covering the
        # remaining feature layers of the face backbone.
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.slice6 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        for x in range(30, len_features):
            self.slice6.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = X.view(-1, self.opt.model.output_nc, self.opt.data.img_size, self.opt.data.img_size)
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        h_relu6 = self.slice6(h_relu5)
        # The two shallowest features are dropped and h_relu6 is repeated,
        # keeping the five-element interface of VGG19.
        out = [h_relu3, h_relu4, h_relu5, h_relu6, h_relu6]
        return out
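
# Usage sketch for VGGFace19. The checkpoint path is hypothetical, and opt
# is assumed to provide the model.output_nc and data.img_size fields the
# forward pass reads; real values come from this repo's option files.
def _example_vggface19_loss(opt, fake, real):
    vgg_face = VGGFace19(opt, load_path="checkpoints/vggface.pth").eval()
    return sum(F.l1_loss(f, r.detach())
               for f, r in zip(vgg_face(fake), vgg_face(real)))
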
# Returns a function that creates a normalization layer
# that does not condition on the semantic map
def get_nonspade_norm_layer(opt, norm_type='instance'):
    # helper function to get the number of output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            subnorm_type = norm_type
        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer
        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)
        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'syncbatch':
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)
        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
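
# Usage sketch: with norm_type='spectralinstance', the returned closure wraps
# a plain Conv2d in spectral norm, drops its (now redundant) bias, and
# appends an InstanceNorm2d. opt is unused by get_nonspade_norm_layer itself,
# so None is passed here purely for illustration.
def _example_norm_layer():
    norm_layer = get_nonspade_norm_layer(None, norm_type='spectralinstance')
    conv = norm_layer(nn.Conv2d(3, 64, kernel_size=3, padding=1))
    x = torch.randn(1, 3, 64, 64)
    return conv(x)  # shape: (1, 64, 64, 64)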