|
import sys, os |
|
import numpy as np |
|
import torch |
|
|
|
|
|
def make_conv(n_in, n_out, n_blocks, kernel=3, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
    """Build a chain of ``n_blocks`` 3D convolution stages.

    Every stage is ``Conv3d -> normalization(n_out) -> activation(inplace=True)``
    with stride 1 and padding ``kernel // 2`` (which keeps the spatial size
    unchanged for odd kernel sizes).

    Args:
        n_in: Input channels of the first stage.
        n_out: Output channels of every stage (and input channels of every
            stage after the first).
        n_blocks: Number of stages; ``0`` yields an empty ``Sequential``.
        kernel: Convolution kernel size.
        normalization: Callable taking the channel count and returning a module.
        activation: Callable accepting ``inplace=True`` and returning a module.

    Returns:
        A ``torch.nn.Sequential`` whose children are the per-stage
        ``torch.nn.Sequential`` modules.
    """
    nn = torch.nn
    pad = kernel // 2

    def stage(channels_in):
        # One conv/norm/activation unit; modules are created in this order.
        return nn.Sequential(
            nn.Conv3d(channels_in, n_out, kernel_size=kernel, stride=1, padding=pad),
            normalization(n_out),
            activation(inplace=True),
        )

    # Only the very first stage sees n_in channels.
    return nn.Sequential(*(stage(n_in if idx == 0 else n_out) for idx in range(n_blocks)))
|
|
|
|
|
def make_conv_2d(n_in, n_out, n_blocks, kernel=3, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
    """Build a chain of ``n_blocks`` 2D convolution stages.

    Every stage is ``Conv2d -> normalization(n_out) -> activation(inplace=True)``
    with stride 1 and padding ``kernel // 2`` (which keeps the spatial size
    unchanged for odd kernel sizes).

    Args:
        n_in: Input channels of the first stage.
        n_out: Output channels of every stage (and input channels of every
            stage after the first).
        n_blocks: Number of stages; ``0`` yields an empty ``Sequential``.
        kernel: Convolution kernel size.
        normalization: Callable taking the channel count and returning a module.
        activation: Callable accepting ``inplace=True`` and returning a module.

    Returns:
        A ``torch.nn.Sequential`` whose children are the per-stage
        ``torch.nn.Sequential`` modules.
    """
    nn = torch.nn
    pad = kernel // 2

    def stage(channels_in):
        # One conv/norm/activation unit; modules are created in this order.
        return nn.Sequential(
            nn.Conv2d(channels_in, n_out, kernel_size=kernel, stride=1, padding=pad),
            normalization(n_out),
            activation(inplace=True),
        )

    # Only the very first stage sees n_in channels.
    return nn.Sequential(*(stage(n_in if idx == 0 else n_out) for idx in range(n_blocks)))
|
|
|
|
|
def make_downscale(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
    """Build a stride-2 3D downscaling stage.

    The stage is ``Conv3d -> normalization(n_out) -> activation(inplace=True)``.
    The convolution uses stride 2 and padding ``(kernel - 2) // 2``, so for an
    even ``kernel`` each spatial dimension is halved (rounded down).

    Args:
        n_in: Input channel count.
        n_out: Output channel count.
        kernel: Convolution kernel size.
        normalization: Callable taking the channel count and returning a module.
        activation: Callable accepting ``inplace=True`` and returning a module.

    Returns:
        The stage as a ``torch.nn.Sequential``.
    """
    strided_conv = torch.nn.Conv3d(
        n_in,
        n_out,
        kernel_size=kernel,
        stride=2,
        padding=(kernel - 2) // 2,
    )
    norm_layer = normalization(n_out)
    act_layer = activation(inplace=True)
    return torch.nn.Sequential(strided_conv, norm_layer, act_layer)
|
|
|
|
|
def make_downscale_2d(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
    """Build a stride-2 2D downscaling stage.

    The stage is ``Conv2d -> normalization(n_out) -> activation(inplace=True)``.
    The convolution uses stride 2 and padding ``(kernel - 2) // 2``, so for an
    even ``kernel`` each spatial dimension is halved (rounded down).

    Args:
        n_in: Input channel count.
        n_out: Output channel count.
        kernel: Convolution kernel size.
        normalization: Callable taking the channel count and returning a module.
        activation: Callable accepting ``inplace=True`` and returning a module.

    Returns:
        The stage as a ``torch.nn.Sequential``.
    """
    strided_conv = torch.nn.Conv2d(
        n_in,
        n_out,
        kernel_size=kernel,
        stride=2,
        padding=(kernel - 2) // 2,
    )
    norm_layer = normalization(n_out)
    act_layer = activation(inplace=True)
    return torch.nn.Sequential(strided_conv, norm_layer, act_layer)
|
|
|
|
|
def make_upscale(n_in, n_out, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU, kernel=6):
    """Build a stride-2 3D upscaling stage.

    The stage is ``ConvTranspose3d -> normalization(n_out) ->
    activation(inplace=True)``. The transposed convolution uses stride 2 and
    padding ``(kernel - 2) // 2``, so for an even ``kernel`` each spatial
    dimension is exactly doubled.

    Args:
        n_in: Input channel count.
        n_out: Output channel count.
        normalization: Callable taking the channel count and returning a module.
        activation: Callable accepting ``inplace=True`` and returning a module.
        kernel: Transposed-convolution kernel size. Defaults to 6, which gives
            padding 2 (the values this function previously hard-coded). It is
            the last parameter so existing positional calls keep working.

    Returns:
        The stage as a ``torch.nn.Sequential``.
    """
    block = torch.nn.Sequential(
        # Same padding rule as make_upscale_2d: (6 - 2) // 2 == 2 by default.
        torch.nn.ConvTranspose3d(n_in, n_out, kernel_size=kernel, stride=2, padding=(kernel - 2) // 2),
        normalization(n_out),
        activation(inplace=True),
    )
    return block
|
|
|
|
|
def make_upscale_2d(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
    """Build a stride-2 2D upscaling stage.

    The stage is ``ConvTranspose2d -> normalization(n_out) ->
    activation(inplace=True)``. The transposed convolution uses stride 2 and
    padding ``(kernel - 2) // 2``, so for an even ``kernel`` each spatial
    dimension is exactly doubled.

    Args:
        n_in: Input channel count.
        n_out: Output channel count.
        kernel: Transposed-convolution kernel size.
        normalization: Callable taking the channel count and returning a module.
        activation: Callable accepting ``inplace=True`` and returning a module.

    Returns:
        The stage as a ``torch.nn.Sequential``.
    """
    transposed_conv = torch.nn.ConvTranspose2d(
        n_in,
        n_out,
        kernel_size=kernel,
        stride=2,
        padding=(kernel - 2) // 2,
    )
    norm_layer = normalization(n_out)
    act_layer = activation(inplace=True)
    return torch.nn.Sequential(transposed_conv, norm_layer, act_layer)
|
|
|
|
|
class ResBlock(torch.nn.Module):
    """3D residual block that keeps the channel count at ``n_out``.

    Computes ``block2(block1(block0(x)) + x)`` where

    * ``block0`` is ``Conv3d -> normalization -> activation(inplace=True)``,
    * ``block1`` is ``Conv3d -> normalization``,
    * ``block2`` is a plain ``torch.nn.ReLU``.

    Both convolutions use stride 1 and padding ``kernel // 2``.
    """

    def __init__(self, n_out, kernel=3, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
        """Create the block.

        Args:
            n_out: Number of input and output channels.
            kernel: Convolution kernel size.
            normalization: Callable taking the channel count and returning a module.
            activation: Callable accepting ``inplace=True``; used inside
                ``block0`` only.
        """
        super().__init__()
        nn = torch.nn
        pad = kernel // 2

        def conv():
            # Channel-preserving, stride-1 convolution shared by both branches.
            return nn.Conv3d(n_out, n_out, kernel_size=kernel, stride=1, padding=pad)

        self.block0 = nn.Sequential(conv(), normalization(n_out), activation(inplace=True))
        self.block1 = nn.Sequential(conv(), normalization(n_out))
        # NOTE: the post-addition non-linearity is always ReLU; the
        # `activation` argument does not affect it.
        self.block2 = nn.ReLU()

    def forward(self, x0):
        """Apply both conv branches, add the skip connection, then ReLU."""
        branch = self.block1(self.block0(x0))
        return self.block2(branch + x0)
|
|
|
|
|
class ResBlock2d(torch.nn.Module):
    """2D residual block that keeps the channel count at ``n_out``.

    Computes ``block2(block1(block0(x)) + x)`` where

    * ``block0`` is ``Conv2d -> normalization -> activation(inplace=True)``,
    * ``block1`` is ``Conv2d -> normalization``,
    * ``block2`` is a plain ``torch.nn.ReLU``.

    Both convolutions use stride 1 and padding ``kernel // 2``.
    """

    def __init__(self, n_out, kernel=3, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
        """Create the block.

        Args:
            n_out: Number of input and output channels.
            kernel: Convolution kernel size.
            normalization: Callable taking the channel count and returning a module.
            activation: Callable accepting ``inplace=True``; used inside
                ``block0`` only.
        """
        super().__init__()
        nn = torch.nn
        pad = kernel // 2

        def conv():
            # Channel-preserving, stride-1 convolution shared by both branches.
            return nn.Conv2d(n_out, n_out, kernel_size=kernel, stride=1, padding=pad)

        self.block0 = nn.Sequential(conv(), normalization(n_out), activation(inplace=True))
        self.block1 = nn.Sequential(conv(), normalization(n_out))
        # NOTE: the post-addition non-linearity is always ReLU; the
        # `activation` argument does not affect it.
        self.block2 = nn.ReLU()

    def forward(self, x0):
        """Apply both conv branches, add the skip connection, then ReLU."""
        branch = self.block1(self.block0(x0))
        return self.block2(branch + x0)
|
|
|
|
|
class Identity(torch.nn.Module):
    """Pass-through module: ``forward`` returns its input object unchanged."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments.

        Because arbitrary positional and keyword arguments are accepted, the
        class can be instantiated with whatever arguments another layer
        constructor would have received (e.g. a channel count or
        ``inplace=True``).
        """
        super(Identity, self).__init__()

    def forward(self, x):
        """Return ``x`` itself (no copy, no computation)."""
        return x
|
|
|
|
|
def downscale_gt_flow(flow_gt, flow_mask, image_height, image_width):
    """Build five-level pyramids of a scaled ground-truth flow and its mask.

    The flow is divided by 20.0 and the mask is converted to float, then both
    are resized with nearest-neighbour interpolation to ``1/4``, ``1/8``,
    ``1/16``, ``1/32`` and ``1/64`` of ``(image_height, image_width)``. Each
    resized mask is converted to ``bool``. The inputs are not modified.

    Args:
        flow_gt: Ground-truth flow tensor in a layout accepted by
            ``torch.nn.functional.interpolate`` with a 2-element ``size``
            (i.e. ``(N, C, H, W)``).
        flow_mask: Mask tensor in the same layout as ``flow_gt``.
        image_height: Full-resolution height; must be a multiple of 64.
        image_width: Full-resolution width; must be a multiple of 64.

    Returns:
        A pair ``(flows, masks)`` of lists with five tensors each, ordered
        from the 1/4 level down to the 1/64 level.

    Raises:
        AssertionError: If either dimension is not a multiple of 64.
    """
    # Check the sizes up front so that no tensor work is done for bad input.
    assert image_height % 64 == 0 and image_width % 64 == 0

    # Division and .float() never write into their inputs and interpolate
    # allocates its own output, so no defensive clone of the arguments is needed.
    flow_scaled = flow_gt / 20.0
    mask_float = flow_mask.float()

    interpolate = torch.nn.functional.interpolate
    flow_pyramid = []
    mask_pyramid = []
    for factor in (4, 8, 16, 32, 64):
        level_size = (image_height // factor, image_width // factor)
        flow_pyramid.append(interpolate(input=flow_scaled, size=level_size, mode='nearest'))
        mask_pyramid.append(interpolate(input=mask_float, size=level_size, mode='nearest').bool())

    return flow_pyramid, mask_pyramid
|
|
|
|
|
def compute_baseline_mask_gt(
    xy_coords_warped,
    target_matches, valid_target_matches,
    source_points, valid_source_points,
    scene_flow_gt, scene_flow_mask, target_boundary_mask,
    max_pos_flowed_source_to_target_dist, min_neg_flowed_source_to_target_dist
):
    """Derive ground-truth positive/negative correspondence masks.

    A pixel is considered only if channel 0 of ``scene_flow_mask``,
    ``valid_source_points`` and ``valid_target_matches`` are all set. Among
    those pixels:

    * positive: the L2 distance (over dim 1) between
      ``source_points + scene_flow_gt`` and ``target_matches`` is at most
      ``max_pos_flowed_source_to_target_dist`` AND the match does not land on
      a target boundary;
    * negative: that distance exceeds
      ``min_neg_flowed_source_to_target_dist`` OR the match lands on a target
      boundary.

    "Lands on a boundary" is decided by sampling the inverted
    ``target_boundary_mask`` at ``xy_coords_warped`` with nearest-neighbour
    ``grid_sample`` and zero padding, so out-of-range coordinates count as
    boundary.

    Returns:
        ``(mask_gt, valid_mask_pixels)``: ``mask_gt`` is the positive mask as
        ``float32``; ``valid_mask_pixels`` marks pixels that are positive or
        negative.
    """
    # Non-boundary indicator of the target, looked up at the warped coordinates.
    nonboundary_float = (~target_boundary_mask).type(torch.float32)
    sampled_nonboundary = torch.nn.functional.grid_sample(
        nonboundary_float,
        xy_coords_warped,
        padding_mode='zeros',
        mode='nearest',
        align_corners=False,
    )
    match_off_boundary = sampled_nonboundary[:, 0, :, :] >= 0.999

    # Distance between the flowed source points and their matched targets.
    residual = (source_points + scene_flow_gt) - target_matches
    dist = torch.norm(residual, p=2, dim=1)

    # Validity term shared by the positive and the negative mask.
    usable = scene_flow_mask[:, 0].type(torch.bool) & valid_source_points & valid_target_matches

    is_close = dist <= max_pos_flowed_source_to_target_dist
    is_far = dist > min_neg_flowed_source_to_target_dist

    mask_pos_gt = usable & is_close & match_off_boundary
    mask_neg_gt = usable & (is_far | ~match_off_boundary)

    valid_mask_pixels = mask_pos_gt | mask_neg_gt
    return mask_pos_gt.type(torch.float32), valid_mask_pixels
|
|
|
|
|
def compute_deformed_points_gt(
    source_points, scene_flow_gt,
    valid_solve, valid_correspondences,
    deformed_points_idxs, deformed_points_subsampled
):
    """Gather ground-truth deformed points into fixed-size, zero-padded buffers.

    For every batch element ``i`` with ``valid_solve[i]`` set, the points
    ``source_points[i] + scene_flow_gt[i]`` are read at the pixels where
    ``valid_correspondences[i]`` is non-zero. If
    ``deformed_points_subsampled[i]`` is set, those points are additionally
    indexed with ``deformed_points_idxs[i]``. The resulting ``num_points``
    points are written to the first ``num_points`` rows of the outputs.

    Returns:
        ``(deformed_points_gt, deformed_points_mask)``, both of shape
        ``(batch_size, deformed_points_idxs.shape[1], 3)`` with the dtype and
        device of ``source_points``. The mask is 1 on rows holding a point
        and 0 elsewhere; batch elements without a valid solve stay all-zero.
    """
    batch_size = source_points.shape[0]
    max_warped_points = deformed_points_idxs.shape[1]

    out_shape = (batch_size, max_warped_points, 3)
    alloc_kwargs = dict(dtype=source_points.dtype, device=source_points.device)
    deformed_points_gt = torch.zeros(out_shape, **alloc_kwargs)
    deformed_points_mask = torch.zeros(out_shape, **alloc_kwargs)

    for b in range(batch_size):
        if not valid_solve[b]:
            continue

        valid_idxs = torch.where(valid_correspondences[b])
        rows, cols = valid_idxs[0], valid_idxs[1]

        # Move the coordinate axis last so that pixels can be indexed as [row, col].
        warped = (source_points[b] + scene_flow_gt[b]).permute(1, 2, 0)
        selected = warped[rows, cols, :].view(-1, 3, 1)

        if deformed_points_subsampled[b]:
            selected = selected[deformed_points_idxs[b]]

        count = selected.shape[0]
        deformed_points_gt[b, :count, :] = selected.view(1, count, 3)
        deformed_points_mask[b, :count, :] = 1

    return deformed_points_gt, deformed_points_mask
|
|