# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
import os
os.environ["MKL_NUM_THREADS"] = "1" # noqa F402
os.environ["NUMEXPR_NUM_THREADS"] = "1" # noqa F402
os.environ["OMP_NUM_THREADS"] = "1" # noqa F402
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.utils.model_zoo as model_zoo
class BackprojectDepth(nn.Module):
"""Layer to transform a depth image into a point cloud
"""
def __init__(self, batch_size, height, width):
super(BackprojectDepth, self).__init__()
self.batch_size = batch_size
self.height = height
self.width = width
meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
requires_grad=False)
self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
requires_grad=False)
self.pix_coords = torch.unsqueeze(torch.stack(
[self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
requires_grad=False)
def forward(self, depth, inv_K):
cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
cam_points = depth.view(self.batch_size, 1, -1) * cam_points
cam_points = torch.cat([cam_points, self.ones], 1)
return cam_points
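# Usage sketch for BackprojectDepth (illustrative; the batch size and 192 x 640
# resolution below are arbitrary assumptions, and inv_K is a placeholder identity):
def _example_backproject_depth():
    backproject = BackprojectDepth(batch_size=2, height=192, width=640)
    depth = torch.rand(2, 1, 192, 640)                 # per-pixel depth predictions
    inv_K = torch.eye(4).unsqueeze(0).repeat(2, 1, 1)  # placeholder inverse intrinsics
    return backproject(depth, inv_K)                   # homogeneous points, (2, 4, 192 * 640)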
class Project3D(nn.Module):
"""Layer which projects 3D points into a camera with intrinsics K and at position T
"""
def __init__(self, batch_size, height, width, eps=1e-7):
super(Project3D, self).__init__()
self.batch_size = batch_size
self.height = height
self.width = width
self.eps = eps
def forward(self, points, K, T):
P = torch.matmul(K, T)[:, :3, :]
cam_points = torch.matmul(P, points)
pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
pix_coords = pix_coords.permute(0, 2, 3, 1)
pix_coords[..., 0] /= self.width - 1
pix_coords[..., 1] /= self.height - 1
pix_coords = (pix_coords - 0.5) * 2
return pix_coords
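# Usage sketch chaining the two layers (illustrative; identity K, inv_K and T stand
# in for real intrinsics and a real relative pose):
def _example_backproject_then_project():
    batch, height, width = 2, 192, 640
    backproject = BackprojectDepth(batch, height, width)
    project = Project3D(batch, height, width)
    depth = torch.ones(batch, 1, height, width)
    K = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    inv_K = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    T = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    cam_points = backproject(depth, inv_K)   # (batch, 4, height * width)
    pix_coords = project(cam_points, K, T)   # (batch, height, width, 2), normalised for grid_sample
    return pix_coords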
class ResNetMultiImageInput(models.ResNet):
"""Constructs a resnet model with varying number of input images.
Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
"""
def __init__(self, block, layers, num_classes=1000, num_input_images=1):
super(ResNetMultiImageInput, self).__init__(block, layers)
self.inplanes = 64
self.conv1 = nn.Conv2d(
num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
"""Constructs a ResNet model.
Args:
num_layers (int): Number of resnet layers. Must be 18 or 50
pretrained (bool): If True, returns a model pre-trained on ImageNet
num_input_images (int): Number of frames stacked as input
"""
assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
if pretrained:
loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
loaded['conv1.weight'] = torch.cat(
[loaded['conv1.weight']] * num_input_images, 1) / num_input_images
model.load_state_dict(loaded)
return model
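# Usage sketch (illustrative): a two-frame encoder whose first convolution accepts
# the frames concatenated along the channel dimension.
def _example_resnet_multiimage_input():
    model = resnet_multiimage_input(num_layers=18, pretrained=False, num_input_images=2)
    frames = torch.rand(1, 2 * 3, 192, 640)   # two stacked RGB frames
    return model.conv1(frames)                # first conv now takes 6 input channels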
class ResnetEncoderMatching(nn.Module):
"""Resnet encoder adapted to include a cost volume after the 2nd block.
Setting adaptive_bins=True will recompute the depth bins used for matching upon each
forward pass - this is required for training from monocular video as there is an unknown scale.
"""
def __init__(self, num_layers, pretrained, input_height, input_width,
min_depth_bin=0.1, max_depth_bin=20.0, num_depth_bins=96,
adaptive_bins=False, depth_binning='linear'):
super(ResnetEncoderMatching, self).__init__()
self.adaptive_bins = adaptive_bins
self.depth_binning = depth_binning
self.set_missing_to_max = True
self.num_ch_enc = np.array([64, 64, 128, 256, 512])
self.num_depth_bins = num_depth_bins
# we build the cost volume at 1/4 resolution
self.matching_height, self.matching_width = input_height // 4, input_width // 4
self.is_cuda = False
self.warp_depths = None
self.depth_bins = None
resnets = {18: models.resnet18,
34: models.resnet34,
50: models.resnet50,
101: models.resnet101,
152: models.resnet152}
if num_layers not in resnets:
raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
encoder = resnets[num_layers](pretrained)
self.layer0 = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
self.layer1 = nn.Sequential(encoder.maxpool, encoder.layer1)
self.layer2 = encoder.layer2
self.layer3 = encoder.layer3
self.layer4 = encoder.layer4
if num_layers > 34:
self.num_ch_enc[1:] *= 4
self.backprojector = BackprojectDepth(batch_size=self.num_depth_bins,
height=self.matching_height,
width=self.matching_width)
self.projector = Project3D(batch_size=self.num_depth_bins,
height=self.matching_height,
width=self.matching_width)
self.compute_depth_bins(min_depth_bin, max_depth_bin)
self.prematching_conv = nn.Sequential(nn.Conv2d(64, out_channels=16,
kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True)
)
self.reduce_conv = nn.Sequential(nn.Conv2d(self.num_ch_enc[1] + self.num_depth_bins,
out_channels=self.num_ch_enc[1],
kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True)
)
def compute_depth_bins(self, min_depth_bin, max_depth_bin):
"""Compute the depths bins used to build the cost volume. Bins will depend upon
self.depth_binning, to either be linear in depth (linear) or linear in inverse depth
(inverse)"""
if self.depth_binning == 'inverse':
self.depth_bins = 1 / np.linspace(1 / max_depth_bin,
1 / min_depth_bin,
self.num_depth_bins)[::-1] # maintain depth order
elif self.depth_binning == 'linear':
self.depth_bins = np.linspace(min_depth_bin, max_depth_bin, self.num_depth_bins)
else:
raise NotImplementedError
self.depth_bins = torch.from_numpy(self.depth_bins).float()
self.warp_depths = []
for depth in self.depth_bins:
depth = torch.ones((1, self.matching_height, self.matching_width)) * depth
self.warp_depths.append(depth)
self.warp_depths = torch.stack(self.warp_depths, 0).float()
if self.is_cuda:
self.warp_depths = self.warp_depths.cuda()
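# Worked example of the binning above (using the constructor defaults of
# min_depth_bin=0.1, max_depth_bin=20.0, num_depth_bins=96):
#   'linear'  -> np.linspace(0.1, 20.0, 96)                    evenly spaced in depth
#   'inverse' -> 1 / np.linspace(1 / 20.0, 1 / 0.1, 96)[::-1]  evenly spaced in inverse depth,
#                reversed so the resulting depths remain in ascending order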
def match_features(self, current_feats, lookup_feats, relative_poses, K, invK):
"""Compute a cost volume based on L1 difference between current_feats and lookup_feats.
We backwards warp the lookup_feats into the current frame using the estimated relative
pose, known intrinsics and using hypothesised depths self.warp_depths (which are either
linear in depth or linear in inverse depth).
If relative_pose == 0 then this indicates that the lookup frame is missing (i.e. we are
at the start of a sequence), and so we skip it"""
batch_cost_volume = [] # store all cost volumes of the batch
cost_volume_masks = [] # store locations of '0's in cost volume for confidence
for batch_idx in range(len(current_feats)):
volume_shape = (self.num_depth_bins, self.matching_height, self.matching_width)
cost_volume = torch.zeros(volume_shape, dtype=torch.float, device=current_feats.device)
counts = torch.zeros(volume_shape, dtype=torch.float, device=current_feats.device)
# select an item from batch of ref feats
_lookup_feats = lookup_feats[batch_idx:batch_idx + 1]
_lookup_poses = relative_poses[batch_idx:batch_idx + 1]
_K = K[batch_idx:batch_idx + 1]
_invK = invK[batch_idx:batch_idx + 1]
world_points = self.backprojector(self.warp_depths, _invK)
# loop through ref images adding to the current cost volume
for lookup_idx in range(_lookup_feats.shape[1]):
lookup_feat = _lookup_feats[:, lookup_idx] # 1 x C x H x W
lookup_pose = _lookup_poses[:, lookup_idx]
# ignore missing images
if lookup_pose.sum() == 0:
continue
lookup_feat = lookup_feat.repeat([self.num_depth_bins, 1, 1, 1])
pix_locs = self.projector(world_points, _K, lookup_pose)
warped = F.grid_sample(lookup_feat, pix_locs, padding_mode='zeros', mode='bilinear',
align_corners=True)
# mask values landing outside the image (and near the border)
# we want to ignore edge pixels of the lookup images and the current image
# because of zero padding in ResNet
# Masking of ref image border
x_vals = (pix_locs[..., 0].detach() / 2 + 0.5) * (
self.matching_width - 1) # convert from (-1, 1) to pixel values
y_vals = (pix_locs[..., 1].detach() / 2 + 0.5) * (self.matching_height - 1)
edge_mask = (x_vals >= 2.0) * (x_vals <= self.matching_width - 2) * \
(y_vals >= 2.0) * (y_vals <= self.matching_height - 2)
edge_mask = edge_mask.float()
# masking of current image
current_mask = torch.zeros_like(edge_mask)
current_mask[:, 2:-2, 2:-2] = 1.0
edge_mask = edge_mask * current_mask
diffs = torch.abs(warped - current_feats[batch_idx:batch_idx + 1]).mean(
1) * edge_mask
# integrate into cost volume
cost_volume = cost_volume + diffs
counts = counts + (diffs > 0).float()
# average over lookup images
cost_volume = cost_volume / (counts + 1e-7)
# if some missing values for a pixel location (i.e. some depths landed outside) then
# set to max of existing values
missing_val_mask = (cost_volume == 0).float()
if self.set_missing_to_max:
cost_volume = cost_volume * (1 - missing_val_mask) + \
cost_volume.max(0)[0].unsqueeze(0) * missing_val_mask
batch_cost_volume.append(cost_volume)
cost_volume_masks.append(missing_val_mask)
batch_cost_volume = torch.stack(batch_cost_volume, 0)
cost_volume_masks = torch.stack(cost_volume_masks, 0)
return batch_cost_volume, cost_volume_masks
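# Shape sketch (illustrative): for a 192 x 640 input and the default 96 depth bins,
# batch_cost_volume is (batch, 96, 48, 160) -- one averaged L1 matching cost per
# hypothesised depth at every 1/4-resolution pixel -- and cost_volume_masks flags
# locations that never received a valid (non-zero) match from any lookup frame.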
def feature_extraction(self, image, return_all_feats=False):
""" Run feature extraction on an image - first 2 blocks of ResNet"""
image = (image - 0.45) / 0.225 # imagenet normalisation
feats_0 = self.layer0(image)
feats_1 = self.layer1(feats_0)
if return_all_feats:
return [feats_0, feats_1]
else:
return feats_1
def indices_to_disparity(self, indices):
"""Convert cost volume indices to 1/depth for visualisation"""
batch, height, width = indices.shape
depth = self.depth_bins[indices.reshape(-1).cpu()]
disp = 1 / depth.reshape((batch, height, width))
return disp
def compute_confidence_mask(self, cost_volume, num_bins_threshold=None):
""" Returns a 'confidence' mask based on how many times a depth bin was observed"""
if num_bins_threshold is None:
num_bins_threshold = self.num_depth_bins
confidence_mask = ((cost_volume > 0).sum(1) == num_bins_threshold).float()
return confidence_mask
def forward(self, current_image, lookup_images, poses, K, invK,
min_depth_bin=None, max_depth_bin=None
):
# feature extraction
self.features = self.feature_extraction(current_image, return_all_feats=True)
current_feats = self.features[-1]
# feature extraction on lookup images - disable gradients to save memory
with torch.no_grad():
if self.adaptive_bins:
self.compute_depth_bins(min_depth_bin, max_depth_bin)
batch_size, num_frames, chns, height, width = lookup_images.shape
lookup_images = lookup_images.reshape(batch_size * num_frames, chns, height, width)
lookup_feats = self.feature_extraction(lookup_images,
return_all_feats=False)
_, chns, height, width = lookup_feats.shape
lookup_feats = lookup_feats.reshape(batch_size, num_frames, chns, height, width)
# warp features to find cost volume
cost_volume, missing_mask = \
self.match_features(current_feats, lookup_feats, poses, K, invK)
confidence_mask = self.compute_confidence_mask(cost_volume.detach() *
(1 - missing_mask.detach()))
# for visualisation - ignore 0s in cost volume for minimum
viz_cost_vol = cost_volume.clone().detach()
viz_cost_vol[viz_cost_vol == 0] = 100
mins, argmin = torch.min(viz_cost_vol, 1)
lowest_cost = self.indices_to_disparity(argmin)
# mask the cost volume based on the confidence
cost_volume *= confidence_mask.unsqueeze(1)
post_matching_feats = self.reduce_conv(torch.cat([self.features[-1], cost_volume], 1))
self.features.append(self.layer2(post_matching_feats))
self.features.append(self.layer3(self.features[-1]))
self.features.append(self.layer4(self.features[-1]))
return self.features, lowest_cost, confidence_mask
def cuda(self):
super().cuda()
self.backprojector.cuda()
self.projector.cuda()
self.is_cuda = True
if self.warp_depths is not None:
self.warp_depths = self.warp_depths.cuda()
def cpu(self):
super().cpu()
self.backprojector.cpu()
self.projector.cpu()
self.is_cuda = False
if self.warp_depths is not None:
self.warp_depths = self.warp_depths.cpu()
def to(self, device):
if str(device) == 'cpu':
self.cpu()
elif str(device) == 'cuda':
self.cuda()
else:
raise NotImplementedError
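# End-to-end usage sketch for the matching encoder (illustrative; the resolution,
# batch size and identity intrinsics/poses are placeholder assumptions -- in practice
# K / invK should describe the 1/4-resolution matching stage and poses are the
# estimated relative camera poses to the lookup frames):
def _example_matching_encoder():
    encoder = ResnetEncoderMatching(num_layers=18, pretrained=False,
                                    input_height=192, input_width=640,
                                    adaptive_bins=False)
    current = torch.rand(1, 3, 192, 640)
    lookups = torch.rand(1, 1, 3, 192, 640)   # (batch, num_frames, 3, H, W)
    poses = torch.eye(4).view(1, 1, 4, 4)     # identity pose: non-zero, so not treated as missing
    K = torch.eye(4).unsqueeze(0)             # placeholder intrinsics
    inv_K = torch.eye(4).unsqueeze(0)         # placeholder inverse intrinsics
    features, lowest_cost, confidence = encoder(current, lookups, poses, K, inv_K)
    # features: five encoder feature maps; lowest_cost: argmin disparity, (1, 48, 160);
    # confidence: mask of pixels where every depth bin was observed, (1, 48, 160)
    return features, lowest_cost, confidence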
class ResnetEncoder(nn.Module):
"""Pytorch module for a resnet encoder
"""
def __init__(self, num_layers, pretrained, num_input_images=1, **kwargs):
super(ResnetEncoder, self).__init__()
self.num_ch_enc = np.array([64, 64, 128, 256, 512])
resnets = {18: models.resnet18,
34: models.resnet34,
50: models.resnet50,
101: models.resnet101,
152: models.resnet152}
if num_layers not in resnets:
raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
if num_input_images > 1:
self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
else:
self.encoder = resnets[num_layers](pretrained)
if num_layers > 34:
self.num_ch_enc[1:] *= 4
def forward(self, input_image):
self.features = []
x = (input_image - 0.45) / 0.225
x = self.encoder.conv1(x)
x = self.encoder.bn1(x)
self.features.append(self.encoder.relu(x))
self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
self.features.append(self.encoder.layer2(self.features[-1]))
self.features.append(self.encoder.layer3(self.features[-1]))
self.features.append(self.encoder.layer4(self.features[-1]))
return self.features
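# Usage sketch (illustrative resolution): the plain encoder returns five feature maps
# at strides 2, 4, 8, 16 and 32 of the input.
def _example_resnet_encoder():
    encoder = ResnetEncoder(num_layers=18, pretrained=False)
    image = torch.rand(1, 3, 192, 640)
    features = encoder(image)   # features[0]: (1, 64, 96, 320) ... features[-1]: (1, 512, 6, 20)
    return features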