yinwentao
DockerFile
8d34f50
import sys, os
import numpy as np
import torch
def make_conv(n_in, n_out, n_blocks, kernel=3, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
    """Stack ``n_blocks`` units of 3-D convolution -> normalization -> activation.

    Every convolution uses stride 1 and ``kernel // 2`` zero padding, so for an
    odd ``kernel`` the spatial size is preserved. Only the first unit maps
    ``n_in`` to ``n_out`` channels; later units map ``n_out`` to ``n_out``.

    Args:
        n_in: channels entering the first unit.
        n_out: channels produced by every unit.
        n_blocks: number of units (0 or less gives an empty ``Sequential``).
        kernel: convolution kernel size.
        normalization: layer class, constructed as ``normalization(n_out)``.
        activation: layer class, constructed as ``activation(inplace=True)``.

    Returns:
        ``torch.nn.Sequential`` whose children are each
        ``Sequential(Conv3d, normalization, activation)``.
    """
    pad = kernel // 2

    def _unit(channels_in):
        # Conv, norm and activation are created in this order for every unit.
        return torch.nn.Sequential(
            torch.nn.Conv3d(channels_in, n_out, kernel_size=kernel, stride=1, padding=pad),
            normalization(n_out),
            activation(inplace=True),
        )

    channel_plan = [n_in if idx == 0 else n_out for idx in range(n_blocks)]
    return torch.nn.Sequential(*[_unit(c) for c in channel_plan])
def make_conv_2d(n_in, n_out, n_blocks, kernel=3, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
    """Stack ``n_blocks`` units of 2-D convolution -> normalization -> activation.

    This is the 2-D counterpart of ``make_conv``. Convolutions use stride 1 and
    ``kernel // 2`` padding, so an odd ``kernel`` keeps the spatial size. The
    first unit takes ``n_in`` channels; all units output ``n_out`` channels.

    Returns:
        ``torch.nn.Sequential`` of ``Sequential(Conv2d, normalization(n_out),
        activation(inplace=True))`` children (empty if ``n_blocks`` <= 0).
    """
    units = []
    channels_in = n_in
    for _ in range(n_blocks):
        units.append(torch.nn.Sequential(
            torch.nn.Conv2d(channels_in, n_out, kernel_size=kernel, stride=1, padding=kernel // 2),
            normalization(n_out),
            activation(inplace=True),
        ))
        # Every unit after the first consumes the previous unit's output.
        channels_in = n_out
    return torch.nn.Sequential(*units)
def make_downscale(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
    """Build a stride-2 3-D conv -> normalization -> activation unit.

    The padding is ``(kernel - 2) // 2``, so for an even ``kernel`` every even
    spatial size ``s`` becomes ``s / 2``.

    Returns:
        ``torch.nn.Sequential(Conv3d, normalization(n_out), activation(inplace=True))``.
    """
    halving_conv = torch.nn.Conv3d(n_in, n_out, kernel_size=kernel, stride=2, padding=(kernel - 2) // 2)
    return torch.nn.Sequential(halving_conv, normalization(n_out), activation(inplace=True))
def make_downscale_2d(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
    """Build a stride-2 2-D conv -> normalization -> activation unit.

    The padding is ``(kernel - 2) // 2``; with an even ``kernel``, even spatial
    sizes are halved.

    Returns:
        ``torch.nn.Sequential(Conv2d, normalization(n_out), activation(inplace=True))``.
    """
    pad = (kernel - 2) // 2
    layers = [
        torch.nn.Conv2d(n_in, n_out, kernel_size=kernel, stride=2, padding=pad),
        normalization(n_out),
        activation(inplace=True),
    ]
    return torch.nn.Sequential(*layers)
def make_upscale(n_in, n_out, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
    """Build a 3-D transposed-conv -> normalization -> activation unit.

    The transposed convolution uses kernel 6, stride 2 and padding 2. Its output
    size is ``(s - 1) * 2 - 4 + 6 = 2 * s``, so every spatial dimension doubles.

    Returns:
        ``torch.nn.Sequential(ConvTranspose3d, normalization(n_out), activation(inplace=True))``.
    """
    doubling_deconv = torch.nn.ConvTranspose3d(n_in, n_out, kernel_size=6, stride=2, padding=2)
    norm_layer = normalization(n_out)
    act_layer = activation(inplace=True)
    return torch.nn.Sequential(doubling_deconv, norm_layer, act_layer)
def make_upscale_2d(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
    """Build a 2-D transposed-conv -> normalization -> activation unit.

    The transposed convolution uses stride 2 and padding ``(kernel - 2) // 2``.
    For an even ``kernel`` the output size is ``2 * s`` for input size ``s``.

    Returns:
        ``torch.nn.Sequential(ConvTranspose2d, normalization(n_out), activation(inplace=True))``.
    """
    deconv_kwargs = dict(kernel_size=kernel, stride=2, padding=(kernel - 2) // 2)
    return torch.nn.Sequential(
        torch.nn.ConvTranspose2d(n_in, n_out, **deconv_kwargs),
        normalization(n_out),
        activation(inplace=True),
    )
class ResBlock(torch.nn.Module):
    """3-D residual block: ``act(norm(conv(act(norm(conv(x))))) + x)``.

    Both convolutions keep ``n_out`` channels, use stride 1 and ``kernel // 2``
    padding. For an odd ``kernel`` the input and output shapes therefore match,
    and the skip connection can be added directly.

    Args:
        n_out: number of input and output channels.
        kernel: convolution kernel size (should be odd so shapes line up).
        normalization: layer class, constructed as ``normalization(n_out)``.
        activation: activation class. It is used after the first convolution
            (constructed with ``inplace=True``) and after the residual addition
            (constructed with no arguments).
    """

    def __init__(self, n_out, kernel=3, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
        super().__init__()
        self.block0 = torch.nn.Sequential(
            torch.nn.Conv3d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel // 2)),
            normalization(n_out),
            activation(inplace=True),
        )
        self.block1 = torch.nn.Sequential(
            torch.nn.Conv3d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel // 2)),
            normalization(n_out),
        )
        # Use the configured activation rather than a hard-coded ReLU, so the
        # `activation` argument also governs the post-residual non-linearity.
        # With the default this is exactly torch.nn.ReLU() (not in-place).
        self.block2 = activation()

    def forward(self, x0):
        # block0 starts with a convolution, so its in-place activation never
        # touches x0; the skip connection still sees the original input.
        x = self.block0(x0)
        x = self.block1(x)
        return self.block2(x + x0)
class ResBlock2d(torch.nn.Module):
    """2-D residual block: ``act(norm(conv(act(norm(conv(x))))) + x)``.

    Both convolutions keep ``n_out`` channels, use stride 1 and ``kernel // 2``
    padding. For an odd ``kernel`` the input and output shapes therefore match,
    and the skip connection can be added directly.

    Args:
        n_out: number of input and output channels.
        kernel: convolution kernel size (should be odd so shapes line up).
        normalization: layer class, constructed as ``normalization(n_out)``.
        activation: activation class. It is used after the first convolution
            (constructed with ``inplace=True``) and after the residual addition
            (constructed with no arguments).
    """

    def __init__(self, n_out, kernel=3, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
        super().__init__()
        self.block0 = torch.nn.Sequential(
            torch.nn.Conv2d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel // 2)),
            normalization(n_out),
            activation(inplace=True),
        )
        self.block1 = torch.nn.Sequential(
            torch.nn.Conv2d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel // 2)),
            normalization(n_out),
        )
        # Use the configured activation rather than a hard-coded ReLU, so the
        # `activation` argument also governs the post-residual non-linearity.
        # With the default this is exactly torch.nn.ReLU() (not in-place).
        self.block2 = activation()

    def forward(self, x0):
        # block0 starts with a convolution, so its in-place activation never
        # touches x0; the skip connection still sees the original input.
        x = self.block0(x0)
        x = self.block1(x)
        return self.block2(x + x0)
class Identity(torch.nn.Module):
    """No-op module whose ``forward`` returns its input object unchanged.

    Any positional or keyword constructor arguments are accepted and ignored.
    It can therefore stand in for a layer class that would be called with
    arguments (for instance a ``normalization`` factory).
    """

    def __init__(self, *_ignored_args, **_ignored_kwargs):
        super(Identity, self).__init__()

    def forward(self, x):
        # The very same object is handed back; no copy is made.
        passthrough = x
        return passthrough
def downscale_gt_flow(flow_gt, flow_mask, image_height, image_width):
    """Build a pyramid of ground-truth flow and flow-mask tensors.

    The flow is divided by 20.0. It and the mask are then resized with
    nearest-neighbour interpolation to 1/4, 1/8, 1/16, 1/32 and 1/64 of
    ``(image_height, image_width)``.

    Args:
        flow_gt: flow tensor accepted by ``torch.nn.functional.interpolate``
            with a 2-D ``size`` (presumably ``(B, 2, H, W)`` -- TODO confirm).
        flow_mask: mask tensor with the same layout; anything ``.float()``
            accepts (e.g. bool).
        image_height: full-resolution height; must be divisible by 64.
        image_width: full-resolution width; must be divisible by 64.

    Returns:
        ``(flows, masks)``: two lists of five tensors each, ordered from the
        finest (1/4) to the coarsest (1/64) level. Flows are floating point;
        masks are bool.

    Raises:
        AssertionError: if either image dimension is not divisible by 64.
    """
    assert image_height % 64 == 0 and image_width % 64 == 0, (
        f"image size ({image_height}, {image_width}) must be divisible by 64"
    )

    # No clone() is needed: the division allocates a new tensor, and
    # interpolate never modifies its input (even when .float() returns the
    # mask itself because it is already float).
    scaled_flow = flow_gt / 20.0
    mask_float = flow_mask.float()

    flows, masks = [], []
    for factor in (4, 8, 16, 32, 64):
        size = (image_height // factor, image_width // factor)
        flows.append(torch.nn.functional.interpolate(input=scaled_flow, size=size, mode='nearest'))
        masks.append(torch.nn.functional.interpolate(input=mask_float, size=size, mode='nearest').bool())
    return flows, masks
def compute_baseline_mask_gt(
    xy_coords_warped,
    target_matches, valid_target_matches,
    source_points, valid_source_points,
    scene_flow_gt, scene_flow_mask, target_boundary_mask,
    max_pos_flowed_source_to_target_dist, min_neg_flowed_source_to_target_dist
):
    """Label correspondences as positive / negative from ground-truth scene flow.

    A correspondence is POSITIVE when all of the following hold:
      * ground-truth flow exists (channel 0 of ``scene_flow_mask``),
      * the source point and the target match are both valid,
      * the target match does not land on a boundary pixel, and
      * the source point moved by the ground-truth flow is within
        ``max_pos_flowed_source_to_target_dist`` (L2, over dim 1) of the match.

    A correspondence is NEGATIVE when ground-truth flow exists, the source point
    and target match are valid, and at least one of these holds:
      * the flowed source point is farther than
        ``min_neg_flowed_source_to_target_dist`` from the match, or
      * the match lands on a boundary pixel.

    Everything else is undecided and is meant to be masked out of the loss.

    Returns:
        ``(mask_gt, valid_mask_pixels)``:
          * ``mask_gt`` is the positive mask cast to float32, so undecided
            entries are 0.
          * ``valid_mask_pixels`` is True where the correspondence is either
            positive or negative.
    """
    has_gt_flow = scene_flow_mask[:, 0].type(torch.bool)
    # Correspondences that may receive any label at all.
    eligible = has_gt_flow & valid_source_points & valid_target_matches

    # Read the "not boundary" map at every warped coordinate. Nearest sampling
    # suffices because the boundary computation already marks any of the 4
    # neighbouring pixels as boundary. With padding_mode='zeros', samples that
    # fall outside the image read 0 and are therefore treated as boundary.
    interior_map = (~target_boundary_mask).type(torch.float32)
    sampled_interior = torch.nn.functional.grid_sample(
        interior_map, xy_coords_warped,
        padding_mode='zeros', mode='nearest', align_corners=False,
    )
    match_is_interior = sampled_interior[:, 0, :, :] >= 0.999

    # Oracle (groundtruth) check: distance from the flowed source point to its match.
    oracle_targets = source_points + scene_flow_gt
    dist = torch.norm(oracle_targets - target_matches, p=2, dim=1)

    positive = (dist <= max_pos_flowed_source_to_target_dist) & eligible & match_is_interior
    # (far OR boundary) AND eligible, i.e. the two negative cases share `eligible`.
    negative = ((dist > min_neg_flowed_source_to_target_dist) | ~match_is_interior) & eligible

    return positive.type(torch.float32), positive | negative
def compute_deformed_points_gt(
    source_points, scene_flow_gt,
    valid_solve, valid_correspondences,
    deformed_points_idxs, deformed_points_subsampled
):
    """Gather ground-truth deformed points into a zero-padded batch tensor.

    For each batch element ``b`` whose ``valid_solve[b]`` is truthy:
      1. The points ``source_points[b] + scene_flow_gt[b]`` (channels first) are
         taken at the non-zero entries of ``valid_correspondences[b]``, in
         row-major order.
      2. If ``deformed_points_subsampled[b]`` is truthy, they are further
         indexed by ``deformed_points_idxs[b]``.
      3. The resulting points fill the first rows of the output slot.
    Elements with a falsy ``valid_solve[b]`` stay all-zero.

    Returns:
        ``(deformed_points_gt, deformed_points_mask)``, both of shape
        ``(batch, deformed_points_idxs.shape[1], 3)`` with ``source_points``'
        dtype and device. The mask is 1 on filled rows and 0 elsewhere.

    NOTE(review): when not subsampled, the number of valid correspondences must
    not exceed ``deformed_points_idxs.shape[1]`` or the copy fails; presumably
    the caller guarantees this -- confirm.
    """
    batch_size = source_points.shape[0]
    out_shape = (batch_size, deformed_points_idxs.shape[1], 3)
    tensor_opts = dict(dtype=source_points.dtype, device=source_points.device)
    deformed_points_gt = torch.zeros(out_shape, **tensor_opts)
    deformed_points_mask = torch.zeros(out_shape, **tensor_opts)

    for b in range(batch_size):
        if not valid_solve[b]:
            continue
        hits = torch.nonzero(valid_correspondences[b], as_tuple=True)

        # Ground-truth deformed positions, channels moved last: (H, W, 3).
        moved = (source_points[b] + scene_flow_gt[b]).permute(1, 2, 0)
        selected = moved[hits[0], hits[1], :].view(-1, 3)

        if deformed_points_subsampled[b]:
            selected = selected[deformed_points_idxs[b]]

        count = selected.shape[0]
        deformed_points_gt[b, :count, :] = selected
        deformed_points_mask[b, :count, :] = 1

    return deformed_points_gt, deformed_points_mask