# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/flow_estimation
import torch
import torch.nn as nn
import torch.nn.functional as F
from .warplayer import warp
from .refine import *


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    # Convolution (3x3 by default) followed by a PReLU activation.
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


class Head(nn.Module):
    def __init__(self, in_planes, scale, c, in_else=17):
        super(Head, self).__init__()
        # Two PixelShuffle(2) stages: 4x spatial upsampling, 16x channel reduction.
        self.upsample = nn.Sequential(nn.PixelShuffle(2), nn.PixelShuffle(2))
        self.scale = scale
        self.conv = nn.Sequential(
            conv(in_planes * 2 // (4 * 4) + in_else, c),
            conv(c, c),
            conv(c, 5),  # 4 flow channels + 1 mask channel
        )

    def forward(self, motion_feature, x, flow):  # /16 /8 /4
        motion_feature = self.upsample(motion_feature)  # /4 /2 /1
        if self.scale != 4:
            x = F.interpolate(x, scale_factor=4. / self.scale, mode="bilinear",
                              align_corners=False)
        if flow is not None:
            if self.scale != 4:
                # Rescale the flow field and its displacement magnitudes together.
                flow = F.interpolate(flow, scale_factor=4. / self.scale, mode="bilinear",
                                     align_corners=False) * 4. / self.scale
            x = torch.cat((x, flow), 1)
        x = self.conv(torch.cat([motion_feature, x], 1))
        if self.scale != 4:
            x = F.interpolate(x, scale_factor=self.scale // 4, mode="bilinear",
                              align_corners=False)
            flow = x[:, :4] * (self.scale // 4)
        else:
            flow = x[:, :4]
        mask = x[:, 4:5]
        return flow, mask
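
# --- Editor's note: hedged usage sketch, not part of the original EMA-VFI
# code. A minimal shape check for Head under assumed dimensions (in_planes=128,
# scale=4, in_else=6, i.e. a first-stage head that only sees the two frames).
# 2 * in_planes must be divisible by 16 for the two PixelShuffle(2) stages.
def _demo_head():
    B, H, W = 2, 64, 64
    head = Head(in_planes=128, scale=4, c=64, in_else=6)
    motion_feature = torch.randn(B, 2 * 128, H // 4, W // 4)  # pre-upsample grid
    x = torch.randn(B, 6, H, W)  # img0 and img1 concatenated along channels
    flow, mask = head(motion_feature, x, None)
    assert flow.shape == (B, 4, H, W) and mask.shape == (B, 1, H, W)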


class MultiScaleFlow(nn.Module):
    def __init__(self, backbone, **kargs):
        super(MultiScaleFlow, self).__init__()
        self.flow_num_stage = len(kargs['hidden_dims'])
        self.feature_bone = backbone
        # One Head per coarse-to-fine stage. The first stage only sees the two
        # input frames (6 channels); later stages additionally receive the two
        # warped frames, the mask, and the current flow (17 channels in total).
        self.block = nn.ModuleList([
            Head(kargs['motion_dims'][-1-i] * kargs['depths'][-1-i] + kargs['embed_dims'][-1-i],
                 kargs['scales'][-1-i],
                 kargs['hidden_dims'][-1-i],
                 6 if i == 0 else 17)
            for i in range(self.flow_num_stage)])
        self.unet = Unet(kargs['c'] * 2)

    def warp_features(self, xs, flow):
        y0 = []
        y1 = []
        B = xs[0].size(0) // 2
        for x in xs:
            y0.append(warp(x[:B], flow[:, 0:2]))
            y1.append(warp(x[B:], flow[:, 2:4]))
            # Halve both the grid and the displacement magnitudes for the next
            # (coarser) feature level; see the _demo_flow_pyramid sketch below.
            flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                                 align_corners=False, recompute_scale_factor=False) * 0.5
        return y0, y1
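
    # --- Editor's note: hedged sketch, not from the original repo. It shows
    # why the loop above halves the flow for each coarser level: displacements
    # are measured in pixels of the current grid, so halving the grid must
    # also halve the displacement magnitudes. Dummy tensors only.
    @staticmethod
    def _demo_flow_pyramid():
        flow = torch.full((1, 4, 8, 8), 4.0)  # uniform 4-pixel displacement
        flow_half = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                                  align_corners=False,
                                  recompute_scale_factor=False) * 0.5
        assert flow_half.shape == (1, 4, 4, 4)
        assert torch.allclose(flow_half, torch.full_like(flow_half, 2.0))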

    def calculate_flow(self, imgs, timestep, af=None, mf=None):
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        B = img0.size(0)
        flow, mask = None, None
        # appearance features & motion features
        if (af is None) or (mf is None):
            af, mf = self.feature_bone(img0, img1)
        for i in range(self.flow_num_stage):
            t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float,
                           device=imgs.device)
            if flow is not None:
                warped_img0 = warp(img0, flow[:, :2])
                warped_img1 = warp(img1, flow[:, 2:4])
                flow_, mask_ = self.block[i](
                    torch.cat([t * mf[-1-i][:B], (1 - t) * mf[-1-i][B:],
                               af[-1-i][:B], af[-1-i][B:]], 1),
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                    flow
                )
                # Later stages predict residual updates to the flow and mask.
                flow = flow + flow_
                mask = mask + mask_
            else:
                flow, mask = self.block[i](
                    torch.cat([t * mf[-1-i][:B], (1 - t) * mf[-1-i][B:],
                               af[-1-i][:B], af[-1-i][B:]], 1),
                    torch.cat((img0, img1), 1),
                    None
                )
        return flow, mask

    def coraseWarp_and_Refine(self, imgs, af, flow, mask):
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        c0, c1 = self.warp_features(af, flow)
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1  # map the refinement residual from [0, 1] to [-1, 1]
        mask_ = torch.sigmoid(mask)
        merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
        pred = torch.clamp(merged + res, 0, 1)
        return pred
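
    # --- Editor's note: hedged sketch, not from the original repo. It
    # restates the blend performed above on dummy tensors: a sigmoid mask
    # mixes the two warped frames, a residual in [-1, 1] is added, and the
    # result is clamped back to the valid image range.
    @staticmethod
    def _demo_blend():
        warped_img0 = torch.rand(1, 3, 8, 8)
        warped_img1 = torch.rand(1, 3, 8, 8)
        mask = torch.randn(1, 1, 8, 8)        # raw logits, as inside the model
        res = torch.rand(1, 3, 8, 8) * 2 - 1  # residual mapped to [-1, 1]
        m = torch.sigmoid(mask)
        pred = torch.clamp(warped_img0 * m + warped_img1 * (1 - m) + res, 0, 1)
        assert pred.min() >= 0 and pred.max() <= 1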

    # Equivalent to running 'calculate_flow' followed by 'coraseWarp_and_Refine',
    # while also collecting the per-stage flows, masks, and merged frames.
    def forward(self, x, timestep=0.5):
        img0, img1 = x[:, :3], x[:, 3:6]
        B = x.size(0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        # appearance features & motion features
        af, mf = self.feature_bone(img0, img1)
        for i in range(self.flow_num_stage):
            t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float,
                           device=x.device)
            if flow is not None:
                flow_d, mask_d = self.block[i](
                    torch.cat([t * mf[-1-i][:B], (1 - t) * mf[-1-i][B:],
                               af[-1-i][:B], af[-1-i][B:]], 1),
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                    flow)
                flow = flow + flow_d
                mask = mask + mask_d
            else:
                flow, mask = self.block[i](
                    torch.cat([t * mf[-1-i][:B], (1 - t) * mf[-1-i][B:],
                               af[-1-i][:B], af[-1-i][B:]], 1),
                    torch.cat((img0, img1), 1),
                    None)
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))

        c0, c1 = self.warp_features(af, flow)
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1  # refinement residual in [-1, 1]
        pred = torch.clamp(merged[-1] + res, 0, 1)
        return flow_list, mask_list, merged, pred
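

# --- Editor's note: hedged smoke test, not part of the original repo. The
# full MultiScaleFlow additionally needs a feature backbone and the Unet from
# .refine, which are not stubbed here; only the self-contained demos run.
# Because of the relative imports, run as a module, e.g.
# `python -m <package>.flow_estimation` (package name is an assumption).
if __name__ == "__main__":
    _demo_head()
    MultiScaleFlow._demo_flow_pyramid()
    MultiScaleFlow._demo_blend()
    print("demo checks passed")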