|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
def mesh_grid(B, H, W):
    """Build a batch of integer pixel-coordinate grids.

    :param B: batch size
    :param H: grid height
    :param W: grid width
    :return: integer tensor of shape Bx2xHxW; channel 0 holds the column
        index (x) and channel 1 holds the row index (y) of every pixel
    """
    # Broadcast a single row / column of indices over the full BxHxW volume.
    col_idx = torch.arange(0, W).view(1, 1, W).expand(B, H, W)
    row_idx = torch.arange(0, H).view(1, H, 1).expand(B, H, W)

    # Stacking materialises the broadcast views into one contiguous tensor.
    return torch.stack((col_idx, row_idx), dim=1)
|
|
|
|
|
def norm_grid(v_grid):
    """Normalise pixel coordinates to the [-1, 1] range used by grid_sample.

    Pixel 0 maps to -1 and pixel ``size - 1`` maps to +1, i.e. the
    ``align_corners=True`` convention of ``torch.nn.functional.grid_sample``.

    :param v_grid: coordinates of shape BxCxHxW; channel 0 is x, channel 1
        is y. Any further channels are left as zeros in the output.
    :return: normalised coordinates permuted to BxHxWxC
    """
    _, _, H, W = v_grid.size()

    # A spatial dimension of size 1 would divide by zero and yield NaN/inf;
    # clamping the denominator to 1 keeps the result finite in that case
    # and leaves every other size unchanged.
    x_span = max(W - 1, 1)
    y_span = max(H - 1, 1)

    v_grid_norm = torch.zeros_like(v_grid)
    v_grid_norm[:, 0, :, :] = 2.0 * v_grid[:, 0, :, :] / x_span - 1.0
    v_grid_norm[:, 1, :, :] = 2.0 * v_grid[:, 1, :, :] / y_span - 1.0

    # grid_sample expects the coordinate pair in the last dimension.
    return v_grid_norm.permute(0, 2, 3, 1)
|
|
|
|
|
def get_corresponding_map(data):
    """Splat every coordinate onto its four neighbouring pixels.

    Each location in ``data`` distributes a total weight of at most 1 over
    the four integer pixels surrounding it, using bilinear weights. Corners
    that fall outside the HxW image contribute nothing. The result counts,
    for every pixel, how much weight landed on it.

    :param data: unnormalized coordinates Bx2xHxW (channel 0 = x, 1 = y)
    :return: Bx1xHxW accumulated weights
    """
    B, _, H, W = data.size()

    # reshape (rather than view) so non-contiguous inputs, e.g. transposed
    # or channels-last tensors, are accepted instead of raising.
    x = data[:, 0, :, :].reshape(B, -1)
    y = data[:, 1, :, :].reshape(B, -1)

    # Integer neighbours before clamping.
    x_lo = torch.floor(x)
    y_lo = torch.floor(y)
    x_hi = x_lo + 1
    y_hi = y_lo + 1

    # Clamp to the image so the flat indices below are always valid.
    x_floor = x_lo.clamp(0, W - 1)
    y_floor = y_lo.clamp(0, H - 1)
    x_ceil = x_hi.clamp(0, W - 1)
    y_ceil = y_hi.clamp(0, H - 1)

    # A neighbour that had to be clamped was outside the image.
    x_floor_out = x_lo != x_floor
    y_floor_out = y_lo != y_floor
    x_ceil_out = x_hi != x_ceil
    y_ceil_out = y_hi != y_ceil

    # The four corners as (x, y, is_outside); order matches across tensors.
    corners = (
        (x_ceil, y_ceil, x_ceil_out | y_ceil_out),
        (x_ceil, y_floor, x_ceil_out | y_floor_out),
        (x_floor, y_ceil, x_floor_out | y_ceil_out),
        (x_floor, y_floor, x_floor_out | y_floor_out),
    )

    invalid = torch.cat([outside for _, _, outside in corners], dim=1)
    # Row-major flat index of each corner inside the HxW image.
    indices = torch.cat([cx + cy * W for cx, cy, _ in corners], 1).long()
    # Bilinear weight of each corner.
    values = torch.cat(
        [(1 - torch.abs(x - cx)) * (1 - torch.abs(y - cy))
         for cx, cy, _ in corners],
        1)

    # Out-of-image corners were redirected to a border pixel by the clamp;
    # zero their weight so they do not pollute that pixel.
    values[invalid] = 0

    # Allocate directly with the dtype/device of the input.
    corresponding_map = data.new_zeros(B, H * W)
    corresponding_map.scatter_add_(1, indices, values)

    return corresponding_map.view(B, H, W).unsqueeze(1)
|
|
|
|
|
def flow_warp(x, flow12, pad='border', mode='bilinear'):
    """Warp ``x`` by sampling it at pixel positions displaced by ``flow12``.

    For every output pixel ``p`` the result holds ``x`` sampled at
    ``p + flow12[p]``.

    :param x: tensor to sample from, BxCxHxW
    :param flow12: per-pixel displacement in pixels, Bx2xHxW
        (channel 0 = x, channel 1 = y)
    :param pad: ``padding_mode`` forwarded to ``grid_sample``
    :param mode: interpolation ``mode`` forwarded to ``grid_sample``
    :return: warped tensor, BxCxHxW
    """
    B, _, H, W = x.size()

    # Absolute pixel coordinates of every location, in x's dtype/device.
    base_grid = mesh_grid(B, H, W).type_as(x)

    # Displaced coordinates, normalised to [-1, 1] and laid out as BxHxWx2.
    v_grid = norm_grid(base_grid + flow12)

    # norm_grid maps pixel 0 to -1 and pixel size-1 to +1, which is the
    # align_corners=True convention. The grid_sample default (False) treats
    # -1/+1 as the outer pixel edges, so a zero flow would not reproduce x.
    im1_recons = F.grid_sample(
        x, v_grid, mode=mode, padding_mode=pad, align_corners=True)
    return im1_recons
|
|
|
|
|
def get_occu_mask_bidirection(flow12, flow21, mask, scale=1, bias=0.5):
    """Flag pixels whose two flows disagree by more than a threshold.

    ``flow21`` is warped by ``flow12`` (with zero padding) and added to
    ``flow12``. A pixel is flagged when the squared norm of that sum exceeds
    ``scale * (|flow12|^2 + |warped flow21|^2) + bias``.

    :param flow12: flow tensor, Bx2xHxW
    :param flow21: flow tensor in the opposite direction, Bx2xHxW
    :param mask: unused; kept so existing callers keep working
    :param scale: multiplier on the summed squared flow magnitudes
    :param bias: constant added to the threshold
    :return: boolean tensor Bx1xHxW, True where the threshold is exceeded
    """
    def _sq_norm(t):
        # Squared L2 norm over the channel dimension.
        return (t * t).sum(1, keepdim=True)

    flow21_warped = flow_warp(flow21, flow12, pad='zeros')

    mismatch = _sq_norm(flow12 + flow21_warped)
    magnitude = _sq_norm(flow12) + _sq_norm(flow21_warped)

    return mismatch > scale * magnitude + bias
|
|
|
|
|
def get_occu_mask_backward(flow21, th=0.2):
    """Flag pixels that receive little weight when ``flow21`` is splatted.

    Every pixel is displaced by ``flow21`` and splatted with
    ``get_corresponding_map``; pixels whose accumulated weight (clamped to
    [0, 1]) is below ``th`` are flagged.

    :param flow21: flow tensor, Bx2xHxW
    :param th: threshold on the clamped accumulated weight
    :return: float tensor Bx1xHxW, 1.0 where flagged and 0.0 elsewhere
    """
    B, _, H, W = flow21.size()

    # Target coordinates of every pixel after applying the flow.
    pixel_coords = mesh_grid(B, H, W).type_as(flow21)
    hits = get_corresponding_map(pixel_coords + flow21)

    below_threshold = torch.clamp(hits, min=0., max=1.) < th
    return below_threshold.float()
|
|
|
def get_ssv_weights(cycle_corres, input, mask, scale_value):
    """Compute per-pixel weights from the error of resampling ``input``.

    ``input`` is resampled at the locations given by ``cycle_corres`` and
    compared with itself. The first three channels are compared after
    division by 255; all remaining channels (called "depth" in this code)
    are compared as-is. The summed squared error is multiplied by ``mask``
    and mapped through ``exp(-error)``, so the weight is 1 where the mask
    is zero or the resampled values match.

    :param cycle_corres: Bx2xHxW coordinates; ``cycle_corres * scale_value
        - 1`` is used directly as the normalised grid_sample grid
    :param input: BxCxHxW tensor to resample
    :param mask: tensor multiplied into the error after a cast to float
    :param scale_value: factor that brings ``cycle_corres`` into [0, 2]
    :return: ``exp(-masked_error)`` with a single channel (keepdim sum)
    """
    # Normalised sampling grid with the coordinate pair in the last dim.
    grid = cycle_corres.mul(scale_value).sub(1.0).permute(0, 2, 3, 1)
    resampled = F.grid_sample(
        input, grid, padding_mode='border', align_corners=True)

    residual = input - resampled
    color_err = (residual[:, :3, :, :] / 255.0).pow(2).sum(1, keepdim=True)
    depth_err = residual[:, 3:, :, :].pow(2).sum(1, keepdim=True)

    masked_err = mask.float() * (color_err + depth_err)
    return torch.exp(-masked_err)
|
|