# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import os

os.environ["MKL_NUM_THREADS"] = "1"  # noqa F402
os.environ["NUMEXPR_NUM_THREADS"] = "1"  # noqa F402
os.environ["OMP_NUM_THREADS"] = "1"  # noqa F402

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.utils.model_zoo as model_zoo

class BackprojectDepth(nn.Module):
    """Layer to transform a depth image into a point cloud
    """
    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
                                      requires_grad=False)

        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)

        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
                                       requires_grad=False)

    def forward(self, depth, inv_K):
        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points
        cam_points = torch.cat([cam_points, self.ones], 1)

        return cam_points

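# For reference: for each pixel (u, v) with hypothesised depth D(u, v), BackprojectDepth computes
#     X = D(u, v) * K^{-1} [u, v, 1]^T
# and returns homogeneous camera-space points of shape (batch, 4, H*W).
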
class Project3D(nn.Module):
    """Layer which projects 3D points into a camera with intrinsics K and at position T
    """
    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3D, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        P = torch.matmul(K, T)[:, :3, :]

        cam_points = torch.matmul(P, points)

        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
        pix_coords = pix_coords.permute(0, 2, 3, 1)
        pix_coords[..., 0] /= self.width - 1
        pix_coords[..., 1] /= self.height - 1
        pix_coords = (pix_coords - 0.5) * 2
        return pix_coords

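# Illustrative sketch (not part of the original file): the two layers above compose into a
# view-synthesis warp; `depth`, `source_image` and `T` below are placeholder names:
#
#     backproject = BackprojectDepth(batch_size, height, width)
#     project = Project3D(batch_size, height, width)
#     cam_points = backproject(depth, inv_K)     # (B, 4, H*W) homogeneous 3D points
#     pix_coords = project(cam_points, K, T)     # (B, H, W, 2), normalised to [-1, 1]
#     warped = F.grid_sample(source_image, pix_coords, padding_mode='zeros',
#                            mode='bilinear', align_corners=True)
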
class ResNetMultiImageInput(models.ResNet):
    """Constructs a resnet model with a varying number of input images.
    Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
    """

    def __init__(self, block, layers, num_classes=1000, num_input_images=1):
        super(ResNetMultiImageInput, self).__init__(block, layers)
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
    """Constructs a ResNet model.
    Args:
        num_layers (int): Number of resnet layers. Must be 18 or 50
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        num_input_images (int): Number of frames stacked as input
    """
    assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
    blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
    block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
    model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)

    if pretrained:
        # relies on the legacy torchvision API (models.resnet.model_urls); the pretrained
        # conv1 weights are tiled across the stacked input frames and rescaled so the
        # expected activation magnitude is preserved
        loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
        loaded['conv1.weight'] = torch.cat(
            [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
        model.load_state_dict(loaded)
    return model

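# Illustrative usage sketch (assumed, not from the original file): a monodepth-style pose
# encoder takes two RGB frames concatenated along the channel axis, e.g.
#
#     pose_encoder = ResnetEncoder(18, pretrained=True, num_input_images=2)
#     pose_feats = pose_encoder(torch.cat([frame_0, frame_1], 1))  # 6 input channels
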
class ResnetEncoderMatching(nn.Module):
    """Resnet encoder adapted to include a cost volume after the 2nd block.
    Setting adaptive_bins=True will recompute the depth bins used for matching upon each
    forward pass - this is required for training from monocular video as there is an unknown scale.
    """

    def __init__(self, num_layers, pretrained, input_height, input_width,
                 min_depth_bin=0.1, max_depth_bin=20.0, num_depth_bins=96,
                 adaptive_bins=False, depth_binning='linear'):

        super(ResnetEncoderMatching, self).__init__()

        self.adaptive_bins = adaptive_bins
        self.depth_binning = depth_binning
        self.set_missing_to_max = True

        self.num_ch_enc = np.array([64, 64, 128, 256, 512])
        self.num_depth_bins = num_depth_bins
        # we build the cost volume at 1/4 resolution
        self.matching_height, self.matching_width = input_height // 4, input_width // 4

        self.is_cuda = False
        self.warp_depths = None
        self.depth_bins = None

        resnets = {18: models.resnet18,
                   34: models.resnet34,
                   50: models.resnet50,
                   101: models.resnet101,
                   152: models.resnet152}

        if num_layers not in resnets:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))

        encoder = resnets[num_layers](pretrained)
        self.layer0 = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
        self.layer1 = nn.Sequential(encoder.maxpool, encoder.layer1)
        self.layer2 = encoder.layer2
        self.layer3 = encoder.layer3
        self.layer4 = encoder.layer4

        if num_layers > 34:
            self.num_ch_enc[1:] *= 4

        self.backprojector = BackprojectDepth(batch_size=self.num_depth_bins,
                                              height=self.matching_height,
                                              width=self.matching_width)
        self.projector = Project3D(batch_size=self.num_depth_bins,
                                   height=self.matching_height,
                                   width=self.matching_width)

        self.compute_depth_bins(min_depth_bin, max_depth_bin)

        self.prematching_conv = nn.Sequential(nn.Conv2d(64, out_channels=16,
                                                        kernel_size=1, stride=1, padding=0),
                                              nn.ReLU(inplace=True)
                                              )

        self.reduce_conv = nn.Sequential(nn.Conv2d(self.num_ch_enc[1] + self.num_depth_bins,
                                                   out_channels=self.num_ch_enc[1],
                                                   kernel_size=3, stride=1, padding=1),
                                         nn.ReLU(inplace=True)
                                         )

    def compute_depth_bins(self, min_depth_bin, max_depth_bin):
        """Compute the depth bins used to build the cost volume. Bins will depend upon
        self.depth_binning, to either be linear in depth (linear) or linear in inverse depth
        (inverse)"""

        if self.depth_binning == 'inverse':
            self.depth_bins = 1 / np.linspace(1 / max_depth_bin,
                                              1 / min_depth_bin,
                                              self.num_depth_bins)[::-1]  # maintain depth order

        elif self.depth_binning == 'linear':
            self.depth_bins = np.linspace(min_depth_bin, max_depth_bin, self.num_depth_bins)

        else:
            raise NotImplementedError

        self.depth_bins = torch.from_numpy(self.depth_bins).float()

        self.warp_depths = []
        for depth in self.depth_bins:
            depth = torch.ones((1, self.matching_height, self.matching_width)) * depth
            self.warp_depths.append(depth)
        self.warp_depths = torch.stack(self.warp_depths, 0).float()
        if self.is_cuda:
            self.warp_depths = self.warp_depths.cuda()
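
    # With the default range (min_depth_bin=0.1, max_depth_bin=20.0, 96 bins), 'linear'
    # binning spaces the hypothesised depths roughly 0.2m apart, whereas 'inverse' binning
    # is uniform in 1/depth and so concentrates most bins at near depths.
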
    def match_features(self, current_feats, lookup_feats, relative_poses, K, invK):
        """Compute a cost volume based on L1 difference between current_feats and lookup_feats.
        We backwards warp the lookup_feats into the current frame using the estimated relative
        pose, known intrinsics and using hypothesised depths self.warp_depths (which are either
        linear in depth or linear in inverse depth).
        If relative_pose == 0 then this indicates that the lookup frame is missing (i.e. we are
        at the start of a sequence), and so we skip it"""

        batch_cost_volume = []  # store all cost volumes of the batch
        cost_volume_masks = []  # store locations of '0's in cost volume for confidence

        for batch_idx in range(len(current_feats)):

            volume_shape = (self.num_depth_bins, self.matching_height, self.matching_width)
            cost_volume = torch.zeros(volume_shape, dtype=torch.float, device=current_feats.device)
            counts = torch.zeros(volume_shape, dtype=torch.float, device=current_feats.device)

            # select an item from batch of ref feats
            _lookup_feats = lookup_feats[batch_idx:batch_idx + 1]
            _lookup_poses = relative_poses[batch_idx:batch_idx + 1]

            _K = K[batch_idx:batch_idx + 1]
            _invK = invK[batch_idx:batch_idx + 1]
            world_points = self.backprojector(self.warp_depths, _invK)

            # loop through ref images adding to the current cost volume
            for lookup_idx in range(_lookup_feats.shape[1]):

                lookup_feat = _lookup_feats[:, lookup_idx]  # 1 x C x H x W
                lookup_pose = _lookup_poses[:, lookup_idx]

                # ignore missing images
                if lookup_pose.sum() == 0:
                    continue

                lookup_feat = lookup_feat.repeat([self.num_depth_bins, 1, 1, 1])
                pix_locs = self.projector(world_points, _K, lookup_pose)
                warped = F.grid_sample(lookup_feat, pix_locs, padding_mode='zeros', mode='bilinear',
                                       align_corners=True)

                # mask values landing outside the image (and near the border)
                # we want to ignore edge pixels of the lookup images and the current image
                # because of zero padding in ResNet
                # Masking of ref image border
                x_vals = (pix_locs[..., 0].detach() / 2 + 0.5) * (
                    self.matching_width - 1)  # convert from (-1, 1) to pixel values
                y_vals = (pix_locs[..., 1].detach() / 2 + 0.5) * (self.matching_height - 1)

                edge_mask = (x_vals >= 2.0) * (x_vals <= self.matching_width - 2) * \
                            (y_vals >= 2.0) * (y_vals <= self.matching_height - 2)
                edge_mask = edge_mask.float()

                # masking of current image
                current_mask = torch.zeros_like(edge_mask)
                current_mask[:, 2:-2, 2:-2] = 1.0
                edge_mask = edge_mask * current_mask

                diffs = torch.abs(warped - current_feats[batch_idx:batch_idx + 1]).mean(
                    1) * edge_mask

                # integrate into cost volume
                cost_volume = cost_volume + diffs
                counts = counts + (diffs > 0).float()

            # average over lookup images
            cost_volume = cost_volume / (counts + 1e-7)

            # if some missing values for a pixel location (i.e. some depths landed outside) then
            # set to max of existing values
            missing_val_mask = (cost_volume == 0).float()
            if self.set_missing_to_max:
                cost_volume = cost_volume * (1 - missing_val_mask) + \
                    cost_volume.max(0)[0].unsqueeze(0) * missing_val_mask
            batch_cost_volume.append(cost_volume)
            cost_volume_masks.append(missing_val_mask)

        batch_cost_volume = torch.stack(batch_cost_volume, 0)
        cost_volume_masks = torch.stack(cost_volume_masks, 0)

        return batch_cost_volume, cost_volume_masks
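
    # Note: the returned cost volume has shape (batch, num_depth_bins, H/4, W/4); a lower
    # value means a better feature match for that depth hypothesis, so the argmin over the
    # depth-bin axis gives the best-matching depth index per pixel.
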
    def feature_extraction(self, image, return_all_feats=False):
        """ Run feature extraction on an image - first 2 blocks of ResNet"""

        image = (image - 0.45) / 0.225  # imagenet normalisation
        feats_0 = self.layer0(image)
        feats_1 = self.layer1(feats_0)

        if return_all_feats:
            return [feats_0, feats_1]
        else:
            return feats_1

    def indices_to_disparity(self, indices):
        """Convert cost volume indices to 1/depth for visualisation"""
        batch, height, width = indices.shape
        depth = self.depth_bins[indices.reshape(-1).cpu()]
        disp = 1 / depth.reshape((batch, height, width))
        return disp

    def compute_confidence_mask(self, cost_volume, num_bins_threshold=None):
        """ Returns a 'confidence' mask based on how many times a depth bin was observed"""

        if num_bins_threshold is None:
            num_bins_threshold = self.num_depth_bins
        confidence_mask = ((cost_volume > 0).sum(1) == num_bins_threshold).float()

        return confidence_mask
    def forward(self, current_image, lookup_images, poses, K, invK,
                min_depth_bin=None, max_depth_bin=None
                ):

        # feature extraction
        self.features = self.feature_extraction(current_image, return_all_feats=True)
        current_feats = self.features[-1]
        # print('current_feats:', current_feats.shape)

        # feature extraction on lookup images - disable gradients to save memory
        with torch.no_grad():
            if self.adaptive_bins:
                self.compute_depth_bins(min_depth_bin, max_depth_bin)

            batch_size, num_frames, chns, height, width = lookup_images.shape
            lookup_images = lookup_images.reshape(batch_size * num_frames, chns, height, width)
            lookup_feats = self.feature_extraction(lookup_images,
                                                   return_all_feats=False)
            _, chns, height, width = lookup_feats.shape
            lookup_feats = lookup_feats.reshape(batch_size, num_frames, chns, height, width)
            # print('lookup_feats:', lookup_feats.shape)

            # warp features to find cost volume
            cost_volume, missing_mask = \
                self.match_features(current_feats, lookup_feats, poses, K, invK)
            confidence_mask = self.compute_confidence_mask(cost_volume.detach() *
                                                           (1 - missing_mask.detach()))

        # for visualisation - ignore 0s in cost volume for minimum
        viz_cost_vol = cost_volume.clone().detach()
        viz_cost_vol[viz_cost_vol == 0] = 100
        mins, argmin = torch.min(viz_cost_vol, 1)
        lowest_cost = self.indices_to_disparity(argmin)

        # mask the cost volume based on the confidence
        cost_volume *= confidence_mask.unsqueeze(1)

        post_matching_feats = self.reduce_conv(torch.cat([self.features[-1], cost_volume], 1))
        # print('post_matching_feats:', post_matching_feats.shape)

        self.features.append(self.layer2(post_matching_feats))
        self.features.append(self.layer3(self.features[-1]))
        self.features.append(self.layer4(self.features[-1]))

        return self.features, lowest_cost, confidence_mask

    def cuda(self):
        super().cuda()
        self.backprojector.cuda()
        self.projector.cuda()
        self.is_cuda = True
        if self.warp_depths is not None:
            self.warp_depths = self.warp_depths.cuda()

    def cpu(self):
        super().cpu()
        self.backprojector.cpu()
        self.projector.cpu()
        self.is_cuda = False
        if self.warp_depths is not None:
            self.warp_depths = self.warp_depths.cpu()

    def to(self, device):
        if str(device) == 'cpu':
            self.cpu()
        elif str(device) == 'cuda':
            self.cuda()
        else:
            raise NotImplementedError

class ResnetEncoder(nn.Module):
    """Pytorch module for a resnet encoder
    """

    def __init__(self, num_layers, pretrained, num_input_images=1, **kwargs):
        super(ResnetEncoder, self).__init__()

        self.num_ch_enc = np.array([64, 64, 128, 256, 512])

        resnets = {18: models.resnet18,
                   34: models.resnet34,
                   50: models.resnet50,
                   101: models.resnet101,
                   152: models.resnet152}

        if num_layers not in resnets:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))

        if num_input_images > 1:
            self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
        else:
            self.encoder = resnets[num_layers](pretrained)

        if num_layers > 34:
            self.num_ch_enc[1:] *= 4

    def forward(self, input_image):
        self.features = []
        x = (input_image - 0.45) / 0.225
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        self.features.append(self.encoder.relu(x))
        self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
        self.features.append(self.encoder.layer2(self.features[-1]))
        self.features.append(self.encoder.layer3(self.features[-1]))
        self.features.append(self.encoder.layer4(self.features[-1]))

        return self.features
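
# Minimal smoke test (illustrative sketch only - the resolution, intrinsics and identity
# lookup pose below are assumptions for a quick shape check, not values from the ManyDepth
# training configuration).
if __name__ == '__main__':
    batch_size, height, width = 1, 192, 640
    encoder = ResnetEncoderMatching(num_layers=18, pretrained=False,
                                    input_height=height, input_width=width,
                                    adaptive_bins=False)

    # 4x4 intrinsics at the 1/4 matching resolution, plus their inverse
    K = torch.eye(4).unsqueeze(0)
    K[:, 0, 0] = 0.58 * width / 4
    K[:, 1, 1] = 1.92 * height / 4
    K[:, 0, 2] = 0.5 * width / 4
    K[:, 1, 2] = 0.5 * height / 4
    invK = torch.inverse(K)

    current_image = torch.rand(batch_size, 3, height, width)
    lookup_images = torch.rand(batch_size, 1, 3, height, width)  # one lookup frame
    poses = torch.eye(4).view(1, 1, 4, 4).repeat(batch_size, 1, 1, 1)

    features, lowest_cost, confidence_mask = encoder(
        current_image, lookup_images, poses, K, invK)
    print([f.shape for f in features])
    print(lowest_cost.shape, confidence_mask.shape)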