from collections import OrderedDict
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import swapae.util as util
from swapae.models.networks import BaseNetwork
from swapae.models.networks.stylegan2_layers import ConvLayer, ResBlock, EqualLinear
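

# Patch discriminators for the Swapping Autoencoder (swapae). The base class
# crops square patches out of full images; the StyleGAN2-based subclasses
# score them (singly, or as concatenated pairs for co-occurrence
# discrimination) with a StyleGAN2-style conv stack and an EqualLinear head.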
class BasePatchDiscriminator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument("--netPatchD_scale_capacity", default=4.0, type=float)
        parser.add_argument("--netPatchD_max_nc", default=256 + 128, type=int)
        parser.add_argument("--patch_size", default=128, type=int)
        parser.add_argument("--max_num_tiles", default=8, type=int)
        parser.add_argument("--patch_random_transformation",
                            type=util.str2bool, nargs='?', const=True, default=False)
        return parser

    def __init__(self, opt):
        super().__init__(opt)
        # self.visdom = util.Visualizer(opt)

    def needs_regularization(self):
        return False

    def extract_features(self, patches):
        raise NotImplementedError()

    def discriminate_features(self, feature1, feature2):
        raise NotImplementedError()

    def apply_random_transformation(self, patches):
        B, ntiles, C, H, W = patches.size()
        patches = patches.view(B * ntiles, C, H, W)
        # before = patches  # only needed by the visdom debug call below
        transformer = util.RandomSpatialTransformer(self.opt, B * ntiles)
        patches = transformer.forward_transform(patches, (self.opt.patch_size, self.opt.patch_size))
        # self.visdom.display_current_results({'before': before,
        #                                      'after': patches}, 0, save_result=False)
        return patches.view(B, ntiles, C, H, W)

    def sample_patches(self, img, indices):
        B, C, H, W = img.size()
        s = self.opt.patch_size
        if H % s > 0 or W % s > 0:
            # Crop at a random offset so the image divides evenly into
            # s x s tiles; guard each randint against an empty range.
            y_offset = torch.randint(H % s, (), device=img.device) if H % s > 0 else 0
            x_offset = torch.randint(W % s, (), device=img.device) if W % s > 0 else 0
            img = img[:, :,
                      y_offset:y_offset + s * (H // s),
                      x_offset:x_offset + s * (W // s)]
        # reshape (not view): the crop above may leave img non-contiguous
        img = img.reshape(B, C, H // s, s, W // s, s)
        ntiles = (H // s) * (W // s)
        tiles = img.permute(0, 2, 4, 1, 3, 5).reshape(B, ntiles, C, s, s)
        if indices is None:
            indices = torch.randperm(ntiles, device=img.device)[:self.opt.max_num_tiles]
            return self.apply_random_transformation(tiles[:, indices]), indices
        else:
            return self.apply_random_transformation(tiles[:, indices])
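
    # Worked example (illustrative): a 256x256 input with patch_size=128
    # tiles into a 2x2 grid (ntiles=4); a 300x300 input is first cropped to
    # 256x256 at a random offset in [0, 300 % 128) = [0, 44). max_num_tiles
    # then caps how many of the shuffled tiles are kept per image.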

    def forward(self, real, fake, fake_only=False):
        assert real is not None
        real_patches, patch_ids = self.sample_patches(real, None)
        if fake is None:
            real_patches.requires_grad_()
        real_feat = self.extract_features(real_patches)

        bs = real.size(0)
        if fake is None or not fake_only:
            pred_real = self.discriminate_features(
                real_feat,
                torch.roll(real_feat, 1, 1))
            pred_real = pred_real.view(bs, -1)

        if fake is not None:
            # Sample the fake patches at the same tile indices as the real ones.
            fake_patches = self.sample_patches(fake, patch_ids)
            # self.visualizer.display_current_results({'real_A': real_patches[0],
            #     'real_B': torch.roll(fake_patches, 1, 1)[0]}, 0, False, max_num_images=16)
            fake_feat = self.extract_features(fake_patches)
            pred_fake = self.discriminate_features(
                real_feat,
                torch.roll(fake_feat, 1, 1))
            pred_fake = pred_fake.view(bs, -1)

        if fake is None:
            return pred_real, real_patches
        elif fake_only:
            return pred_fake
        else:
            return pred_real, pred_fake
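
    # Note on the pairing above: discriminate_features always receives the
    # real features plus a copy shifted by torch.roll along dim 1, so its two
    # inputs are never identical tensors. In the pair-scoring subclass below
    # this realizes the co-occurrence idea from Swapping Autoencoder: the
    # head judges whether two patch features plausibly co-occur, rather than
    # scoring one patch in isolation.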


class StyleGAN2PatchDiscriminator(BasePatchDiscriminator):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        BasePatchDiscriminator.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        super().__init__(opt)
        channel_multiplier = self.opt.netPatchD_scale_capacity
        size = self.opt.patch_size
        channels = {
            4: min(self.opt.netPatchD_max_nc, int(256 * channel_multiplier)),
            8: min(self.opt.netPatchD_max_nc, int(128 * channel_multiplier)),
            16: min(self.opt.netPatchD_max_nc, int(64 * channel_multiplier)),
            32: int(32 * channel_multiplier),
            64: int(16 * channel_multiplier),
            128: int(8 * channel_multiplier),
            256: int(4 * channel_multiplier),
        }

        log_size = int(math.ceil(math.log(size, 2)))
        in_channel = channels[2 ** log_size]

        blur_kernel = [1, 3, 3, 1] if self.opt.use_antialias else [1]

        convs = [('0', ConvLayer(3, in_channel, 3))]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            layer_name = str(7 - i) if i <= 6 else "%dx%d" % (2 ** i, 2 ** i)
            convs.append((layer_name, ResBlock(in_channel, out_channel, blur_kernel)))
            in_channel = out_channel

        convs.append(('5', ResBlock(in_channel, self.opt.netPatchD_max_nc * 2, downsample=False)))
        convs.append(('6', ConvLayer(self.opt.netPatchD_max_nc * 2, self.opt.netPatchD_max_nc, 3, pad=0)))
        self.convs = nn.Sequential(OrderedDict(convs))

        out_dim = 1
        # In this variant the head scores a single flattened 2x2 feature map.
        pairlinear1 = EqualLinear(channels[4] * 2 * 2, 2048, activation='fused_lrelu')
        pairlinear2 = EqualLinear(2048, 2048, activation='fused_lrelu')
        pairlinear3 = EqualLinear(2048, 1024, activation='fused_lrelu')
        pairlinear4 = EqualLinear(1024, out_dim)
        self.pairlinear = nn.Sequential(pairlinear1, pairlinear2, pairlinear3, pairlinear4)
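
    # Shape walk-through (illustrative, for patch_size=128 and the defaults
    # above): log_size = 7, so the ResBlock loop runs for i = 7..3 and halves
    # the resolution five times, 128 -> 4. Block '5' keeps 4x4 and block '6'
    # (3x3 conv, pad=0) reduces it to 2x2, so the flattened feature fed to
    # pairlinear1 has channels[4] * 2 * 2 elements.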

    def extract_features(self, patches, aggregate=False):
        if patches.ndim == 5:
            B, T, C, H, W = patches.size()
            flattened_patches = patches.flatten(0, 1)
        else:
            B, C, H, W = patches.size()
            T = 1  # a 4D input is a plain batch of patches, one tile each
            flattened_patches = patches
        features = self.convs(flattened_patches)
        features = features.view(B, T, features.size(1), features.size(2), features.size(3))
        if aggregate:
            # Average over each image's tiles, then broadcast the mean
            # back to every tile position.
            features = features.mean(1, keepdim=True).expand(-1, T, -1, -1, -1)
        return features.flatten(0, 1)

    def extract_layerwise_features(self, image):
        feats = [image]
        for m in self.convs:
            feats.append(m(feats[-1]))
        return feats

    def discriminate_features(self, feature1, feature2=None):
        # This variant scores a single patch feature; `feature2` is accepted
        # (the base forward() passes a rolled copy) but ignored.
        feature1 = feature1.flatten(1)
        out = self.pairlinear(feature1)
        return out

    # Original pair-wise variant, kept for reference:
    """
    def discriminate_features(self, feature1, feature2):
        feature1 = feature1.flatten(1)
        feature2 = feature2.flatten(1)
        out = self.pairlinear(torch.cat([feature1, feature2], dim=1))
        return out
    """


class StyleGAN2COGANPatchDiscriminator(BasePatchDiscriminator):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        BasePatchDiscriminator.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        super().__init__(opt)
        channel_multiplier = self.opt.netPatchD_scale_capacity
        size = self.opt.patch_size
        channels = {
            4: min(self.opt.netPatchD_max_nc, int(256 * channel_multiplier)),
            8: min(self.opt.netPatchD_max_nc, int(128 * channel_multiplier)),
            16: min(self.opt.netPatchD_max_nc, int(64 * channel_multiplier)),
            32: int(32 * channel_multiplier),
            64: int(16 * channel_multiplier),
            128: int(8 * channel_multiplier),
            256: int(4 * channel_multiplier),
        }

        log_size = int(math.ceil(math.log(size, 2)))
        in_channel = channels[2 ** log_size]

        blur_kernel = [1, 3, 3, 1] if self.opt.use_antialias else [1]

        convs = [('0', ConvLayer(3, in_channel, 3))]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            layer_name = str(7 - i) if i <= 6 else "%dx%d" % (2 ** i, 2 ** i)
            convs.append((layer_name, ResBlock(in_channel, out_channel, blur_kernel)))
            in_channel = out_channel

        convs.append(('5', ResBlock(in_channel, self.opt.netPatchD_max_nc * 2, downsample=False)))
        convs.append(('6', ConvLayer(self.opt.netPatchD_max_nc * 2, self.opt.netPatchD_max_nc, 3, pad=0)))
        self.convs = nn.Sequential(OrderedDict(convs))

        out_dim = 1
        # Twice the single-feature width: this head scores a concatenated
        # pair of flattened 2x2 feature maps.
        pairlinear1 = EqualLinear(channels[4] * 2 * 2 * 2, 2048, activation='fused_lrelu')
        pairlinear2 = EqualLinear(2048, 2048, activation='fused_lrelu')
        pairlinear3 = EqualLinear(2048, 1024, activation='fused_lrelu')
        pairlinear4 = EqualLinear(1024, out_dim)
        self.pairlinear = nn.Sequential(pairlinear1, pairlinear2, pairlinear3, pairlinear4)

    def extract_features(self, patches, aggregate=False):
        if patches.ndim == 5:
            B, T, C, H, W = patches.size()
            flattened_patches = patches.flatten(0, 1)
        else:
            B, C, H, W = patches.size()
            T = 1  # a 4D input is a plain batch of patches, one tile each
            flattened_patches = patches
        features = self.convs(flattened_patches)
        features = features.view(B, T, features.size(1), features.size(2), features.size(3))
        if aggregate:
            # Average over each image's tiles, then broadcast the mean
            # back to every tile position.
            features = features.mean(1, keepdim=True).expand(-1, T, -1, -1, -1)
        return features.flatten(0, 1)

    def extract_layerwise_features(self, image):
        feats = [image]
        for m in self.convs:
            feats.append(m(feats[-1]))
        return feats

    def discriminate_features(self, feature1, feature2):
        # Score the pair jointly: both features are flattened and
        # concatenated before the linear head.
        feature1 = feature1.flatten(1)
        feature2 = feature2.flatten(1)
        out = self.pairlinear(torch.cat([feature1, feature2], dim=1))
        return out
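

# --- Usage sketch (illustrative, not part of the original module) ----------
# A minimal smoke test, assuming BaseNetwork.__init__ only stores `opt`, that
# the stylegan2_layers ops are usable on the current device, and that the
# util helpers need no options beyond those listed. `use_antialias` is set
# here by assumption; it is registered elsewhere in swapae.
if __name__ == "__main__":
    from argparse import Namespace

    opt = Namespace(
        netPatchD_scale_capacity=4.0,
        netPatchD_max_nc=256 + 128,
        patch_size=128,
        max_num_tiles=8,
        patch_random_transformation=False,
        use_antialias=True,  # assumption: flag defined elsewhere in swapae
    )
    netD = StyleGAN2COGANPatchDiscriminator(opt)
    real = torch.randn(2, 3, 256, 256)  # each image tiles into a 2x2 grid
    fake = torch.randn(2, 3, 256, 256)
    pred_real, pred_fake = netD(real, fake)
    print(pred_real.shape, pred_fake.shape)  # expected: torch.Size([2, 4]) each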