# File size: 7,555 Bytes
# 2492d81
from torch import nn
import torch.nn.functional as F
import torch
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
from modules.util import to_homogeneous, from_homogeneous, UpBlock2d, TPS
import math
class DenseMotionNetwork(nn.Module):
    """
    Module that estimates an optical flow and multi-resolution occlusion masks
    from K TPS transformations and an affine transformation.
    """

    def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_channels,
                 scale_factor=0.25, bg=False, multi_mask=True, kp_variance=0.01):
        """
        Build the hourglass backbone and the contribution-map / occlusion heads.

        Args:
            block_expansion, max_features, num_blocks: forwarded unchanged to ``Hourglass``.
            num_tps: number K of TPS transformations; K + 1 contribution maps are predicted.
            num_channels: number of channels of the source image.
            scale_factor: the source image is resampled by this factor before
                motion estimation (no resampling when it equals 1).
            bg: stored on the instance as ``self.bg``; not otherwise used in this class.
            multi_mask: if True, ``occlusion_num`` (4) occlusion masks are predicted,
                otherwise a single one.
            kp_variance: variance passed to ``kp2gaussian`` for the keypoint heatmaps.
        """
        super().__init__()

        if scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, scale_factor)
        self.scale_factor = scale_factor
        self.multi_mask = multi_mask

        # Input channels: (K + 1) deformed copies of the image, K * 5 keypoint
        # heatmaps (keypoints are grouped by 5 per TPS) and one all-zero heatmap.
        self.hourglass = Hourglass(block_expansion=block_expansion,
                                   in_features=(num_channels * (num_tps + 1) + num_tps * 5 + 1),
                                   max_features=max_features, num_blocks=num_blocks)

        hourglass_output_size = self.hourglass.out_channels
        self.maps = nn.Conv2d(hourglass_output_size[-1], num_tps + 1, kernel_size=(7, 7), padding=(3, 3))

        if multi_mask:
            # Number of upsampling stages; log2 is exact for power-of-two ratios,
            # so int() cannot truncate a value that is a hair below an integer.
            self.up_nums = int(math.log2(1 / scale_factor))
            self.occlusion_num = 4

            up_channels = [hourglass_output_size[-1] // (2 ** i) for i in range(self.up_nums)]
            self.up = nn.ModuleList(
                [UpBlock2d(width, width // 2, kernel_size=3, padding=1) for width in up_channels]
            )

            # The first (occlusion_num - up_nums) masks read hourglass outputs
            # directly; the remaining ones read the outputs of the up blocks.
            direct_heads = self.occlusion_num - self.up_nums
            channel = [hourglass_output_size[-i - 1] for i in range(direct_heads)[::-1]]
            # NOTE(review): the body of this loop was missing in the extracted
            # text; restored so that channel[i] matches each UpBlock2d output width.
            for i in range(self.up_nums):
                channel.append(hourglass_output_size[-1] // (2 ** (i + 1)))

            self.occlusion = nn.ModuleList(
                [nn.Conv2d(channel[i], 1, kernel_size=(7, 7), padding=(3, 3))
                 for i in range(self.occlusion_num)]
            )
        else:
            self.occlusion = nn.ModuleList(
                [nn.Conv2d(hourglass_output_size[-1], 1, kernel_size=(7, 7), padding=(3, 3))]
            )

        self.num_tps = num_tps
        self.bg = bg
        self.kp_variance = kp_variance

    def create_heatmap_representations(self, source_image, kp_driving, kp_source):
        """
        Return the difference of driving and source keypoint heatmaps, with one
        extra all-zero channel prepended along dim 1.

        ``kp_driving`` and ``kp_source`` are dicts whose ``'fg_kp'`` entry is
        passed to ``kp2gaussian``; the heatmap size is ``source_image.shape[2:]``.
        """
        spatial_size = source_image.shape[2:]
        gaussian_driving = kp2gaussian(kp_driving['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
        gaussian_source = kp2gaussian(kp_source['fg_kp'], spatial_size=spatial_size, kp_variance=self.kp_variance)
        heatmap = gaussian_driving - gaussian_source

        # Same dtype and device as the heatmap.
        zeros = heatmap.new_zeros((heatmap.shape[0], 1, spatial_size[0], spatial_size[1]))
        return torch.cat([zeros, heatmap], dim=1)

    def create_transformations(self, source_image, kp_driving, kp_source, bg_param):
        """
        Build the stack of sampling grids: one identity/background grid followed
        by the TPS grids, concatenated along dim 1.

        ``bg_param`` may be None; otherwise it is viewed as one 3x3 matrix per
        batch element and applied to the identity grid in homogeneous coordinates.
        """
        # K TPS transformations
        bs, _, h, w = source_image.shape
        kp_1 = kp_driving['fg_kp'].view(bs, -1, 5, 2)
        kp_2 = kp_source['fg_kp'].view(bs, -1, 5, 2)
        trans = TPS(mode='kp', bs=bs, kp_1=kp_1, kp_2=kp_2)
        driving_to_source = trans.transform_frame(source_image)

        identity_grid = make_coordinate_grid((h, w), type=kp_1.type()).to(kp_1.device)
        identity_grid = identity_grid.view(1, 1, h, w, 2)
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)

        # affine background transformation
        if bg_param is not None:
            identity_grid = to_homogeneous(identity_grid)
            identity_grid = torch.matmul(bg_param.view(bs, 1, 1, 1, 3, 3), identity_grid.unsqueeze(-1)).squeeze(-1)
            identity_grid = from_homogeneous(identity_grid)

        return torch.cat([identity_grid, driving_to_source], dim=1)

    def create_deformed_source_image(self, source_image, transformations):
        """
        Warp the source image with each of the ``num_tps + 1`` sampling grids.

        Args:
            source_image: tensor of shape (bs, C, h, w).
            transformations: sampling grids viewable as (bs * (num_tps + 1), h, w, 2),
                in the normalized coordinates expected by ``F.grid_sample``.

        Returns:
            Tensor of shape (bs, num_tps + 1, C, h, w).
        """
        bs, _, h, w = source_image.shape
        copies = self.num_tps + 1
        source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, copies, 1, 1, 1, 1)
        source_repeat = source_repeat.view(bs * copies, -1, h, w)
        transformations = transformations.view((bs * copies, h, w, -1))
        deformed = F.grid_sample(source_repeat, transformations, align_corners=True)
        return deformed.view((bs, copies, -1, h, w))

    def dropout_softmax(self, X, P):
        """
        Dropout for TPS transformations. Eq(7) and Eq(8) in the paper.

        Each channel of ``X`` except channel 0 is dropped, per batch element,
        with probability ``P``; a softmax over dim 1 is then taken over the
        surviving channels. ``X`` itself is not modified.
        """
        drop = (torch.rand(X.shape[0], X.shape[1]) < (1 - P)).type(X.type()).to(X.device)
        drop[..., 0] = 1  # channel 0 is never dropped
        # Broadcast the per-(batch, channel) decision over the spatial dims.
        drop = drop.repeat(X.shape[2], X.shape[3], 1, 1).permute(2, 3, 0, 1)

        # Subtract the per-pixel max before exponentiating (numerical stability).
        maxx = X.max(1).values.unsqueeze_(1)
        X = X - maxx
        X_exp = X.exp()
        # NOTE(review): this rescales X after X_exp has already been computed,
        # so it does not influence the returned value. Kept as in the original;
        # confirm whether the scaling was meant to apply to X_exp.
        X[:, 1:, ...] /= (1 - P)
        mask_bool = (drop == 0)
        X_exp = X_exp.masked_fill(mask_bool, 0)
        partition = X_exp.sum(dim=1, keepdim=True) + 1e-6
        return X_exp / partition

    def forward(self, source_image, kp_driving, kp_source, bg_param=None, dropout_flag=False, dropout_p=0):
        """
        Estimate the dense motion from the driving to the source frame.

        Args:
            source_image: tensor of shape (bs, C, H, W); resampled by
                ``scale_factor`` first when that factor is not 1.
            kp_driving, kp_source: dicts with an ``'fg_kp'`` entry.
            bg_param: optional background affine parameters (see
                ``create_transformations``).
            dropout_flag: if True, contribution maps come from
                ``dropout_softmax``; otherwise from a plain softmax.
            dropout_p: drop probability used when ``dropout_flag`` is True.

        Returns:
            dict with keys ``'deformed_source'``, ``'contribution_maps'``,
            ``'deformation'`` (the combined sampling grid / optical flow) and
            ``'occlusion_map'`` (a list of occlusion masks).
        """
        if self.scale_factor != 1:
            source_image = self.down(source_image)

        bs, _, h, w = source_image.shape

        out_dict = dict()
        heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
        transformations = self.create_transformations(source_image, kp_driving, kp_source, bg_param)
        deformed_source = self.create_deformed_source_image(source_image, transformations)
        out_dict['deformed_source'] = deformed_source
        deformed_source = deformed_source.view(bs, -1, h, w)
        hourglass_input = torch.cat([heatmap_representation, deformed_source], dim=1)
        hourglass_input = hourglass_input.view(bs, -1, h, w)

        # With mode=1 the hourglass result is indexed like a sequence of feature maps.
        prediction = self.hourglass(hourglass_input, mode=1)

        contribution_maps = self.maps(prediction[-1])
        if dropout_flag:
            contribution_maps = self.dropout_softmax(contribution_maps, dropout_p)
        else:
            contribution_maps = F.softmax(contribution_maps, dim=1)
        out_dict['contribution_maps'] = contribution_maps

        # Combine the K+1 transformations
        # Eq(6) in the paper
        contribution_maps = contribution_maps.unsqueeze(2)
        transformations = transformations.permute(0, 1, 4, 2, 3)
        deformation = (transformations * contribution_maps).sum(dim=1)
        deformation = deformation.permute(0, 2, 3, 1)

        out_dict['deformation'] = deformation  # Optical Flow

        # NOTE(review): the loop bodies below were missing in the extracted
        # text and have been restored to match the heads built in __init__;
        # the sigmoid activation is inferred — confirm against upstream.
        occlusion_map = []
        if self.multi_mask:
            direct_heads = self.occlusion_num - self.up_nums
            # Masks predicted directly from the last `direct_heads` hourglass outputs.
            for i in range(direct_heads):
                features = prediction[self.up_nums - self.occlusion_num + i]
                occlusion_map.append(torch.sigmoid(self.occlusion[i](features)))
            # Remaining masks are predicted after each successive upsampling stage.
            prediction = prediction[-1]
            for i in range(self.up_nums):
                prediction = self.up[i](prediction)
                occlusion_map.append(torch.sigmoid(self.occlusion[i + direct_heads](prediction)))
        else:
            occlusion_map.append(torch.sigmoid(self.occlusion[0](prediction[-1])))

        out_dict['occlusion_map'] = occlusion_map  # Multi-resolution Occlusion Masks
        return out_dict