yinwentao
DockerFile
8d34f50
import torch
import torch.nn as nn
import torch.nn.functional as F
def mesh_grid(B, H, W):
    """Return integer pixel coordinates as a (B, 2, H, W) tensor.

    Channel 0 holds the column index (x, 0..W-1) and channel 1 the row
    index (y, 0..H-1). The dtype is torch.arange's default (int64); callers
    cast it, e.g. with ``type_as``.
    """
    cols = torch.arange(0, W).view(1, 1, W).expand(B, H, W)  # x varies along width
    rows = torch.arange(0, H).view(1, H, 1).expand(B, H, W)  # y varies along height
    # stack materialises a fresh contiguous tensor from the expanded views
    return torch.stack((cols, rows), dim=1)  # B2HW
def norm_grid(v_grid):
    """Convert pixel coordinates to the [-1, 1] range used by grid_sample.

    Pixel 0 maps to -1 and pixel (size - 1) maps to +1 on each axis, which is
    the ``align_corners=True`` convention of ``grid_sample``.

    :param v_grid: unnormalized coordinates B x 2 x H x W (channel 0 = x,
        channel 1 = y)
    :return: normalized grid B x H x W x 2, ready to pass to grid_sample
    """
    _, _, H, W = v_grid.size()
    # A size-1 axis would make the divisor 0 and turn the grid into NaN/inf;
    # clamping the divisor to 1 keeps the output finite.
    x_den = max(W - 1, 1)
    y_den = max(H - 1, 1)
    v_grid_norm = torch.zeros_like(v_grid)
    v_grid_norm[:, 0, :, :] = 2.0 * v_grid[:, 0, :, :] / x_den - 1.0
    v_grid_norm[:, 1, :, :] = 2.0 * v_grid[:, 1, :, :] / y_den - 1.0
    return v_grid_norm.permute(0, 2, 3, 1)  # BHW2
def get_corresponding_map(data: torch.Tensor) -> torch.Tensor:
    """
    Forward-splat every source location onto the integer pixel grid and
    accumulate its bilinear weights into a per-pixel map.

    Each of the H*W input locations ``(x, y) = data[b, :, i, j]`` spreads its
    bilinear weights over its four neighbouring integer pixels. Neighbours
    that lie outside the image get weight 0, so mass from out-of-bounds
    coordinates is dropped rather than piled onto the border. Target pixels
    that no location lands near stay 0.

    :param data: unnormalized coordinates Bx2xHxW (channel 0 = x / column,
        channel 1 = y / row)
    :return: Bx1xHxW map of accumulated bilinear weights
    """
    B, _, H, W = data.size()
    # Coordinates are deliberately NOT clamped; corners that fall outside the
    # image are detected and masked out further below instead.
    x = data[:, 0, :, :].view(B, -1)  # BxN (N=H*W)
    y = data[:, 1, :, :].view(B, -1)
    # Integer neighbours. Naming note: x1/y1 are the floors and x0/y0 are
    # floor + 1 (the "ceil" variables below are floor + 1 even when the input
    # is already an integer). Each is clamped so it is a valid scatter index.
    x1 = torch.floor(x)
    x_floor = x1.clamp(0, W - 1)
    y1 = torch.floor(y)
    y_floor = y1.clamp(0, H - 1)
    x0 = x1 + 1
    x_ceil = x0.clamp(0, W - 1)
    y0 = y1 + 1
    y_ceil = y0.clamp(0, H - 1)
    # A neighbour is outside the image exactly when clamping moved it.
    x_ceil_out = x0 != x_ceil
    y_ceil_out = y0 != y_ceil
    x_floor_out = x1 != x_floor
    y_floor_out = y1 != y_floor
    # Per-corner invalid flags, in the same corner order as `indices` and
    # `values` below: (ceil, ceil), (ceil, floor), (floor, ceil), (floor, floor).
    invalid = torch.cat([x_ceil_out | y_ceil_out,
                         x_ceil_out | y_floor_out,
                         x_floor_out | y_ceil_out,
                         x_floor_out | y_floor_out], dim=1)  # Bx4N
    # encode coordinates as flat indices x + y * W, since the scatter function
    # can only index along one axis
    corresponding_map = torch.zeros(B, H * W).type_as(data)
    indices = torch.cat([x_ceil + y_ceil * W,
                         x_ceil + y_floor * W,
                         x_floor + y_ceil * W,
                         x_floor + y_floor * W], 1).long()  # Bx4N (N=H*W)
    # Bilinear weight of each corner: (1 - |x - corner_x|) * (1 - |y - corner_y|).
    values = torch.cat([(1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_floor)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_floor))],
                       1)
    # Out-of-image corners were clamped onto the border for indexing; zero
    # their weight so they contribute nothing.
    values[invalid] = 0
    # Sum every weight that targets the same flat pixel index.
    corresponding_map.scatter_add_(1, indices, values)
    # decode coordinates: flat index back to an HxW layout
    corresponding_map = corresponding_map.view(B, H, W)
    return corresponding_map.unsqueeze(1)
def flow_warp(x, flow12, pad='border', mode='bilinear'):
    """Backward-warp ``x`` with the pixel-unit flow field ``flow12``.

    Output pixel (i, j) is ``x`` sampled at ``(j + u, i + v)`` where
    ``(u, v) = flow12[:, :, i, j]``, so a zero flow reproduces ``x``.

    :param x: tensor to warp, B x C x H x W
    :param flow12: flow in pixels, B x 2 x H x W (channel 0 = x, 1 = y)
    :param pad: ``padding_mode`` forwarded to ``grid_sample``
    :param mode: interpolation ``mode`` forwarded to ``grid_sample``
    :return: warped tensor, B x C x H x W
    """
    B, _, H, W = x.size()
    base_grid = mesh_grid(B, H, W).type_as(x)  # B2HW
    v_grid = norm_grid(base_grid + flow12)  # BHW2
    # norm_grid maps pixel 0 -> -1 and pixel W-1 -> +1, i.e. the
    # align_corners=True convention. grid_sample defaults to
    # align_corners=False, under which the same grid is interpreted relative
    # to pixel edges: sampling positions would be shifted/scaled and a zero
    # flow would no longer be an identity warp.
    im1_recons = nn.functional.grid_sample(
        x, v_grid, mode=mode, padding_mode=pad, align_corners=True)
    return im1_recons
def get_occu_mask_bidirection(flow12, flow21, mask, scale=1, bias=0.5):
    """Forward-backward consistency check for occlusion.

    ``flow21`` is sampled (zero padding) at the positions ``flow12`` points
    to. Where the two flows are consistent they roughly cancel, so a pixel is
    flagged when the squared norm of their sum exceeds
    ``scale * (|flow12|^2 + |flow21_warped|^2) + bias``.

    :param flow12: forward flow, B x 2 x H x W
    :param flow21: backward flow, same layout as flow12
    :param mask: not used by this function; kept for interface compatibility
    :param scale: multiplier on the summed squared flow magnitudes
    :param bias: constant added to the threshold
    :return: boolean tensor B x 1 x H x W, True where the check fails
    """
    def _sq_norm(t):
        # squared L2 norm over the flow-channel dimension, kept as B1HW
        return (t * t).sum(1, keepdim=True)

    back_flow = flow_warp(flow21, flow12, pad='zeros')
    residual = _sq_norm(flow12 + back_flow)
    threshold = scale * (_sq_norm(flow12) + _sq_norm(back_flow)) + bias
    return residual > threshold
def get_occu_mask_backward(flow21, th=0.2):
    """Occlusion mask from forward-splatting the backward flow.

    Every pixel is pushed along ``flow21`` and its bilinear weights are
    accumulated on the target grid (see get_corresponding_map). Pixels whose
    accumulated weight, clamped to [0, 1], stays below ``th`` are flagged.

    :param flow21: backward flow, B x 2 x H x W
    :param th: weight threshold below which a pixel is flagged
    :return: float tensor B x 1 x H x W, 1.0 where flagged and 0.0 elsewhere
    """
    B, _, H, W = flow21.size()
    targets = mesh_grid(B, H, W).type_as(flow21) + flow21  # B2HW
    coverage = get_corresponding_map(targets).clamp(min=0., max=1.)  # B1HW
    return (coverage < th).float()
def get_ssv_weights(cycle_corres, input, mask, scale_value):
    """Per-pixel weights exp(-d) from the mismatch between ``input`` and
    ``input`` resampled at ``cycle_corres``.

    ``cycle_corres * scale_value - 1`` is used directly as the
    ``grid_sample`` grid (align_corners=True, border padding), so it is
    expected to land in [-1, 1]. The squared differences of the first three
    channels (scaled by 1/255, treated as color) and of the remaining
    channels (treated as depth) are summed over channels, multiplied by
    ``mask``, and mapped through exp(-d).

    :param cycle_corres: correspondence field, N x 2 x H x W
    :param input: N x C x H x W tensor; channels 0-2 color, 3+ depth
    :param mask: tensor broadcastable to N x 1 x H x W; cast to float
    :param scale_value: multiplier applied to cycle_corres before the -1 shift
    :return: weights of shape N x 1 x H x W, equal to 1 where mask is 0
    """
    sample_grid = cycle_corres.mul(scale_value).sub(1.0).permute(0, 2, 3, 1)
    resampled = nn.functional.grid_sample(
        input, sample_grid, align_corners=True, padding_mode='border')
    delta = input - resampled
    color_term = (delta[:, :3, :, :] / 255.0).pow(2).sum(1, keepdim=True)
    depth_term = delta[:, 3:, :, :].pow(2).sum(1, keepdim=True)
    penalty = mask.float() * (color_term + depth_term)  # (N, 1, H, W)
    return penalty.neg().exp()