Spaces:
Runtime error
Runtime error
| from torch import nn | |
| import torch.nn.functional as F | |
| from facerender.modules.util import kp2gaussian | |
| import torch | |
class DownBlock2d(nn.Module):
    """
    Encoder block: convolution -> optional instance norm -> LeakyReLU(0.2)
    -> optional 2x2 average pooling.

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
        norm: when True, apply an affine ``nn.InstanceNorm2d`` after the conv.
        kernel_size: convolution kernel size. The conv uses the default
            stride 1 and padding 0, so each spatial side shrinks by
            ``kernel_size - 1``.
        pool: when True, halve the spatial size with 2x2 average pooling.
        sn: when True, wrap the convolution in spectral normalization.
    """

    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
        super(DownBlock2d, self).__init__()
        # Build the conv first and apply spectral norm before any other
        # submodule, keeping parameter init (and RNG use) in the same order.
        conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
        self.conv = nn.utils.spectral_norm(conv) if sn else conv
        self.norm = nn.InstanceNorm2d(out_features, affine=True) if norm else None
        self.pool = pool

    def forward(self, x):
        """Apply conv / norm / activation / pooling to the image batch ``x``."""
        y = self.conv(x)
        if self.norm is not None:
            y = self.norm(y)
        y = F.leaky_relu(y, 0.2)
        return F.avg_pool2d(y, (2, 2)) if self.pool else y
class Discriminator(nn.Module):
    """
    Patch discriminator in the style of Pix2Pix.

    The network has two parts:

    - A stack of ``num_blocks`` DownBlock2d layers. The channel width at
      depth ``i`` is ``block_expansion * 2**i``, capped at ``max_features``.
      The first block has no normalization. Every block except the last
      applies 2x2 pooling.
    - A 1x1 convolution producing a single-channel prediction map.

    Args:
        num_channels: channels of the input image.
        block_expansion: base channel width.
        num_blocks: number of DownBlock2d layers.
        max_features: upper bound on channel width.
        sn: apply spectral normalization to every convolution.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
                 sn=False, **kwargs):
        super(Discriminator, self).__init__()

        def width(level):
            # Channel count at a given depth, capped at max_features.
            return min(max_features, block_expansion * (2 ** level))

        last = num_blocks - 1
        self.down_blocks = nn.ModuleList([
            DownBlock2d(num_channels if i == 0 else width(i), width(i + 1),
                        norm=i != 0, kernel_size=4, pool=i != last, sn=sn)
            for i in range(num_blocks)
        ])
        head = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
        self.conv = nn.utils.spectral_norm(head) if sn else head

    def forward(self, x):
        """
        Run the discriminator on ``x``.

        Returns:
            (feature_maps, prediction_map): the output of every down block,
            in order, and the 1-channel map produced by the final 1x1 conv.
        """
        feature_maps = []
        for block in self.down_blocks:
            x = block(x)
            feature_maps.append(x)
        return feature_maps, self.conv(x)
class MultiScaleDiscriminator(nn.Module):
    """
    Collection of independent discriminators, one per scale.

    Each scale value becomes an ``nn.ModuleDict`` key by replacing '.' with
    '-' (e.g. scale 0.5 -> key '0-5'), since PyTorch module names cannot
    contain '.'. All discriminators are built from the same ``**kwargs``.
    """

    def __init__(self, scales=(), **kwargs):
        super(MultiScaleDiscriminator, self).__init__()
        self.scales = scales
        self.discs = nn.ModuleDict({
            str(scale).replace('.', '-'): Discriminator(**kwargs) for scale in scales
        })

    def forward(self, x):
        """
        Run every per-scale discriminator on its matching input.

        Args:
            x: mapping whose ``'prediction_<scale>'`` entries are fed to the
                discriminator registered for that scale.

        Returns:
            dict with ``'feature_maps_<scale>'`` and ``'prediction_map_<scale>'``
            for each registered discriminator.
        """
        out_dict = {}
        for key, disc in self.discs.items():
            scale = key.replace('-', '.')  # undo the ModuleDict-safe key encoding
            feature_maps, prediction_map = disc(x['prediction_' + scale])
            out_dict['feature_maps_' + scale] = feature_maps
            out_dict['prediction_map_' + scale] = prediction_map
        return out_dict