import torch
import torch.nn as nn
import torch.nn.functional as F


def mesh_grid(B, H, W):
    """Build a batch of integer pixel-coordinate grids.

    :param B: batch size
    :param H: grid height
    :param W: grid width
    :return: int64 tensor of shape Bx2xHxW; channel 0 holds the column (x)
        index and channel 1 holds the row (y) index of every pixel
    """
    x_base = torch.arange(0, W).repeat(B, H, 1)  # BxHxW, increases along W
    y_base = torch.arange(0, H).repeat(B, W, 1).transpose(1, 2)  # BxHxW, increases along H
    base_grid = torch.stack([x_base, y_base], 1)  # Bx2xHxW
    return base_grid


def norm_grid(v_grid):
    """Map absolute pixel coordinates to the [-1, 1] range used by ``grid_sample``.

    Pixel 0 maps to -1 and pixel ``size - 1`` maps to +1, which is the
    ``align_corners=True`` convention of ``torch.nn.functional.grid_sample``.

    :param v_grid: Bx2xHxW tensor of (x, y) pixel coordinates
    :return: BxHxWx2 tensor of normalized (x, y) coordinates
    """
    _, _, H, W = v_grid.size()
    # A 1-pixel-wide/tall grid would otherwise divide by zero (0/0 -> NaN).
    x_den = max(W - 1, 1)
    y_den = max(H - 1, 1)
    v_grid_norm = torch.zeros_like(v_grid)
    v_grid_norm[:, 0, :, :] = 2.0 * v_grid[:, 0, :, :] / x_den - 1.0
    v_grid_norm[:, 1, :, :] = 2.0 * v_grid[:, 1, :, :] / y_den - 1.0
    return v_grid_norm.permute(0, 2, 3, 1)  # BHW2


def get_corresponding_map(data):
    """Forward-splat a unit weight from every source pixel onto its target location.

    Each source pixel sends weight 1 to the sub-pixel position stored in
    ``data``; that weight is split bilinearly over the four surrounding integer
    pixels. Neighbours that fall outside the image are discarded, so pixels near
    the border may receive less than the full weight.

    :param data: unnormalized coordinates Bx2xHxW (channel 0 = x, channel 1 = y)
    :return: Bx1xHxW map holding the total weight received by each pixel
    """
    B, _, H, W = data.size()

    # reshape (not view) so non-contiguous coordinate tensors are accepted too.
    x = data[:, 0, :, :].reshape(B, -1)  # BxN (N=H*W)
    y = data[:, 1, :, :].reshape(B, -1)

    # Integer neighbours of every target position, clamped into the image.
    x1 = torch.floor(x)
    x_floor = x1.clamp(0, W - 1)
    y1 = torch.floor(y)
    y_floor = y1.clamp(0, H - 1)
    x0 = x1 + 1
    x_ceil = x0.clamp(0, W - 1)
    y0 = y1 + 1
    y_ceil = y0.clamp(0, H - 1)

    # A neighbour is out of the image exactly when clamping moved it.
    x_ceil_out = x0 != x_ceil
    y_ceil_out = y0 != y_ceil
    x_floor_out = x1 != x_floor
    y_floor_out = y1 != y_floor
    invalid = torch.cat([x_ceil_out | y_ceil_out,
                         x_ceil_out | y_floor_out,
                         x_floor_out | y_ceil_out,
                         x_floor_out | y_floor_out], dim=1)  # Bx4N

    # scatter_add_ indexes along a single axis, so flatten (x, y) into x + y * W.
    corresponding_map = data.new_zeros(B, H * W)
    indices = torch.cat([x_ceil + y_ceil * W,
                         x_ceil + y_floor * W,
                         x_floor + y_ceil * W,
                         x_floor + y_floor * W], 1).long()  # Bx4N
    values = torch.cat([(1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_floor)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_floor))],
                       1)  # Bx4N bilinear weights
    values[invalid] = 0

    corresponding_map.scatter_add_(1, indices, values)
    # Unflatten back to the image layout.
    corresponding_map = corresponding_map.view(B, H, W)
    return corresponding_map.unsqueeze(1)


def flow_warp(x, flow12, pad='border', mode='bilinear'):
    """Backward-warp ``x`` with a dense displacement field.

    Output pixel (i, j) is sampled from ``x`` at ``(j, i) + flow12[:, :, i, j]``.

    :param x: BxCxHxW tensor to sample from
    :param flow12: Bx2xHxW displacement in pixels (channel 0 = x, channel 1 = y)
    :param pad: ``padding_mode`` forwarded to ``grid_sample``
    :param mode: interpolation ``mode`` forwarded to ``grid_sample``
    :return: BxCxHxW warped tensor
    """
    B, _, H, W = x.size()

    base_grid = mesh_grid(B, H, W).type_as(x)  # B2HW
    v_grid = norm_grid(base_grid + flow12)  # BHW2

    # norm_grid places pixel centres 0 and size-1 at -1 and +1, i.e. the
    # align_corners=True convention. The library default (False) would shift
    # samples by a fraction of a pixel so that even a zero flow is not identity.
    im1_recons = nn.functional.grid_sample(x, v_grid, mode=mode, padding_mode=pad,
                                           align_corners=True)
    return im1_recons


def get_occu_mask_bidirection(flow12, flow21, mask, scale=1, bias=0.5):
    """Forward-backward consistency occlusion test.

    ``flow21`` is warped into frame 1 with ``flow12``; a pixel is flagged when
    ``|flow12 + warped flow21|^2`` exceeds
    ``scale * (|flow12|^2 + |warped flow21|^2) + bias``.

    :param flow12: Bx2xHxW flow from frame 1 to frame 2
    :param flow21: Bx2xHxW flow from frame 2 to frame 1
    :param mask: not used by this function (kept for interface compatibility)
    :param scale: multiplier on the flow-magnitude term of the threshold
    :param bias: constant term of the threshold
    :return: Bx1xHxW boolean tensor, True where the pixel is considered occluded
    """
    flow21_warped = flow_warp(flow21, flow12, pad='zeros')
    flow12_diff = flow12 + flow21_warped
    mag = (flow12 * flow12).sum(1, keepdim=True) + \
          (flow21_warped * flow21_warped).sum(1, keepdim=True)
    occ_thresh = scale * mag + bias
    occu = (flow12_diff * flow12_diff).sum(1, keepdim=True) > occ_thresh
    return occu


def get_occu_mask_backward(flow21, th=0.2):
    """Occlusion mask from how much forward-splatted weight each pixel receives.

    Every pixel of the second frame is splatted along ``flow21``; target pixels
    that receive a (clamped) total weight below ``th`` are marked as occluded.

    :param flow21: Bx2xHxW flow in pixels
    :param th: minimum received weight for a pixel to count as visible
    :return: Bx1xHxW float tensor, 1.0 where occluded and 0.0 elsewhere
    """
    B, _, H, W = flow21.size()
    base_grid = mesh_grid(B, H, W).type_as(flow21)  # B2HW

    corr_map = get_corresponding_map(base_grid + flow21)  # Bx1xHxW
    occu_mask = corr_map.clamp(min=0., max=1.) < th
    return occu_mask.float()


def get_ssv_weights(cycle_corres, input, mask, scale_value):
    """Per-pixel weights from the appearance change along a cycle correspondence.

    ``input`` is resampled at the locations given by ``cycle_corres``. The
    squared difference to the original is summed over channels 0-2 (divided by
    255 first) and over channels 3+ (named depth here), multiplied by ``mask``,
    and mapped through ``exp(-diff)``.

    :param cycle_corres: Bx2xHxW correspondences; ``cycle_corres * scale_value - 1``
        is used directly as a normalized ``grid_sample`` grid, so it presumably
        lies in ``[0, 2 / scale_value]`` -- confirm with caller
    :param input: BxCxHxW tensor; presumably colour in 0-255 for channels 0-2
        and depth in the remaining channels -- TODO confirm
    :param mask: multiplied into the difference (converted with ``.float()``);
        assumed broadcastable to Bx1xHxW
    :param scale_value: scale applied to ``cycle_corres`` before the -1 shift
    :return: Bx1xHxW weights in (0, 1], equal to 1 where the masked difference is 0
    """
    vgrid = (cycle_corres.mul(scale_value) - 1.0).permute(0, 2, 3, 1)
    new_input = nn.functional.grid_sample(input, vgrid, align_corners=True,
                                          padding_mode='border')
    color_diff = (((input[:, :3, :, :] - new_input[:, :3, :, :]) / 255.0) ** 2).sum(
        1, keepdim=True)
    depth_diff = ((input[:, 3:, :, :] - new_input[:, 3:, :, :]) ** 2).sum(1, keepdim=True)
    diff = torch.mul(mask.float(), color_diff + depth_diff)  # (N, 1, H, W)
    return torch.exp(-diff)