# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import os

os.environ["MKL_NUM_THREADS"] = "1"  # noqa F402
os.environ["NUMEXPR_NUM_THREADS"] = "1"  # noqa F402
os.environ["OMP_NUM_THREADS"] = "1"  # noqa F402

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.utils.model_zoo as model_zoo


class BackprojectDepth(nn.Module):
    """Layer to transform a depth image into a point cloud
    """
    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
                                      requires_grad=False)

        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)

        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
                                       requires_grad=False)

    def forward(self, depth, inv_K):
        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points
        cam_points = torch.cat([cam_points, self.ones], 1)

        return cam_points


class Project3D(nn.Module):
    """Layer which projects 3D points into a camera with intrinsics K and at position T
    """
    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3D, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        P = torch.matmul(K, T)[:, :3, :]

        cam_points = torch.matmul(P, points)

        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
        pix_coords = pix_coords.permute(0, 2, 3, 1)
        pix_coords[..., 0] /= self.width - 1
        pix_coords[..., 1] /= self.height - 1
        pix_coords = (pix_coords - 0.5) * 2
        return pix_coords


class ResNetMultiImageInput(models.ResNet):
    """Constructs a resnet model with varying number of input images.
    Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
    """
    def __init__(self, block, layers, num_classes=1000, num_input_images=1):
        super(ResNetMultiImageInput, self).__init__(block, layers)
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
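
# The geometry layers above are composed throughout this file: BackprojectDepth lifts a
# depth map to homogeneous 3D points, Project3D reprojects those points into another
# camera, and the resulting grid is fed to F.grid_sample. The sketch below is a hedged,
# illustrative example only (dummy intrinsics, identity pose, arbitrary shapes) and is
# not part of the original ManyDepth code.
def _example_backproject_and_reproject():
    batch, height, width = 1, 48, 160
    backproject = BackprojectDepth(batch, height, width)
    project = Project3D(batch, height, width)

    # assumed pinhole intrinsics expressed at this resolution
    K = torch.eye(4).unsqueeze(0)
    K[:, 0, 0], K[:, 0, 2] = width / 2, width / 2
    K[:, 1, 1], K[:, 1, 2] = height / 2, height / 2
    inv_K = torch.inverse(K)

    depth = torch.full((batch, 1, height, width), 10.0)  # one fronto-parallel depth plane
    T = torch.eye(4).unsqueeze(0)                        # identity relative pose

    cam_points = backproject(depth, inv_K)               # B x 4 x (H*W) homogeneous points
    pix_coords = project(cam_points, K, T)               # B x H x W x 2, normalised to [-1, 1]

    lookup_image = torch.rand(batch, 3, height, width)
    warped = F.grid_sample(lookup_image, pix_coords, mode='bilinear',
                           padding_mode='zeros', align_corners=True)
    return warped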

def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
    """Constructs a ResNet model.
    Args:
        num_layers (int): Number of resnet layers. Must be 18 or 50
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        num_input_images (int): Number of frames stacked as input
    """
    assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
    blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
    block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
    model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)

    if pretrained:
        loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
        loaded['conv1.weight'] = torch.cat(
            [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
        model.load_state_dict(loaded)
    return model


class ResnetEncoderMatching(nn.Module):
    """Resnet encoder adapted to include a cost volume after the 2nd block.

    Setting adaptive_bins=True will recompute the depth bins used for matching upon each
    forward pass - this is required for training from monocular video as there is an
    unknown scale.
    """

    def __init__(self, num_layers, pretrained, input_height, input_width,
                 min_depth_bin=0.1, max_depth_bin=20.0, num_depth_bins=96,
                 adaptive_bins=False, depth_binning='linear'):

        super(ResnetEncoderMatching, self).__init__()

        self.adaptive_bins = adaptive_bins
        self.depth_binning = depth_binning
        self.set_missing_to_max = True

        self.num_ch_enc = np.array([64, 64, 128, 256, 512])
        self.num_depth_bins = num_depth_bins
        # we build the cost volume at 1/4 resolution
        self.matching_height, self.matching_width = input_height // 4, input_width // 4

        self.is_cuda = False
        self.warp_depths = None
        self.depth_bins = None

        resnets = {18: models.resnet18,
                   34: models.resnet34,
                   50: models.resnet50,
                   101: models.resnet101,
                   152: models.resnet152}

        if num_layers not in resnets:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))

        encoder = resnets[num_layers](pretrained)
        self.layer0 = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
        self.layer1 = nn.Sequential(encoder.maxpool, encoder.layer1)
        self.layer2 = encoder.layer2
        self.layer3 = encoder.layer3
        self.layer4 = encoder.layer4

        if num_layers > 34:
            self.num_ch_enc[1:] *= 4

        self.backprojector = BackprojectDepth(batch_size=self.num_depth_bins,
                                              height=self.matching_height,
                                              width=self.matching_width)
        self.projector = Project3D(batch_size=self.num_depth_bins,
                                   height=self.matching_height,
                                   width=self.matching_width)

        self.compute_depth_bins(min_depth_bin, max_depth_bin)

        self.prematching_conv = nn.Sequential(nn.Conv2d(64, out_channels=16,
                                                        kernel_size=1, stride=1, padding=0),
                                              nn.ReLU(inplace=True))

        self.reduce_conv = nn.Sequential(nn.Conv2d(self.num_ch_enc[1] + self.num_depth_bins,
                                                   out_channels=self.num_ch_enc[1],
                                                   kernel_size=3, stride=1, padding=1),
                                         nn.ReLU(inplace=True))
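
    # A worked example of the two binning modes used by compute_depth_bins below
    # (illustrative numbers, not taken from any training configuration): with
    # min_depth_bin=0.1, max_depth_bin=20.0 and num_depth_bins=4,
    #   'linear'  gives depths of approximately [0.1, 6.73, 13.37, 20.0]
    #   'inverse' gives depths of approximately [0.1, 0.15, 0.30, 20.0]
    #             (uniform in 1/depth, then reversed so the bins stay ordered
    #              from nearest to farthest)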

    def compute_depth_bins(self, min_depth_bin, max_depth_bin):
        """Compute the depth bins used to build the cost volume.

        Bins will depend upon self.depth_binning, to either be linear in depth (linear)
        or linear in inverse depth (inverse)"""

        if self.depth_binning == 'inverse':
            self.depth_bins = 1 / np.linspace(1 / max_depth_bin,
                                              1 / min_depth_bin,
                                              self.num_depth_bins)[::-1]  # maintain depth order

        elif self.depth_binning == 'linear':
            self.depth_bins = np.linspace(min_depth_bin, max_depth_bin, self.num_depth_bins)

        else:
            raise NotImplementedError

        self.depth_bins = torch.from_numpy(self.depth_bins).float()

        self.warp_depths = []
        for depth in self.depth_bins:
            depth = torch.ones((1, self.matching_height, self.matching_width)) * depth
            self.warp_depths.append(depth)
        self.warp_depths = torch.stack(self.warp_depths, 0).float()
        if self.is_cuda:
            self.warp_depths = self.warp_depths.cuda()

    def match_features(self, current_feats, lookup_feats, relative_poses, K, invK):
        """Compute a cost volume based on L1 difference between current_feats and lookup_feats.

        We backward-warp the lookup_feats into the current frame using the estimated relative
        pose, known intrinsics and the hypothesised depths self.warp_depths (which are either
        linear in depth or linear in inverse depth).

        If relative_pose == 0 then this indicates that the lookup frame is missing (i.e. we are
        at the start of a sequence), and so we skip it"""

        batch_cost_volume = []  # store all cost volumes of the batch
        cost_volume_masks = []  # store locations of '0's in cost volume for confidence

        for batch_idx in range(len(current_feats)):

            volume_shape = (self.num_depth_bins, self.matching_height, self.matching_width)
            cost_volume = torch.zeros(volume_shape, dtype=torch.float, device=current_feats.device)
            counts = torch.zeros(volume_shape, dtype=torch.float, device=current_feats.device)

            # select an item from batch of ref feats
            _lookup_feats = lookup_feats[batch_idx:batch_idx + 1]
            _lookup_poses = relative_poses[batch_idx:batch_idx + 1]

            _K = K[batch_idx:batch_idx + 1]
            _invK = invK[batch_idx:batch_idx + 1]
            world_points = self.backprojector(self.warp_depths, _invK)

            # loop through ref images adding to the current cost volume
            for lookup_idx in range(_lookup_feats.shape[1]):

                lookup_feat = _lookup_feats[:, lookup_idx]  # 1 x C x H x W
                lookup_pose = _lookup_poses[:, lookup_idx]

                # ignore missing images
                if lookup_pose.sum() == 0:
                    continue

                lookup_feat = lookup_feat.repeat([self.num_depth_bins, 1, 1, 1])
                pix_locs = self.projector(world_points, _K, lookup_pose)
                warped = F.grid_sample(lookup_feat, pix_locs, padding_mode='zeros',
                                       mode='bilinear', align_corners=True)

                # mask values landing outside the image (and near the border)
                # we want to ignore edge pixels of the lookup images and the current image
                # because of zero padding in ResNet
                # Masking of ref image border
                x_vals = (pix_locs[..., 0].detach() / 2 + 0.5) * (
                    self.matching_width - 1)  # convert from (-1, 1) to pixel values
                y_vals = (pix_locs[..., 1].detach() / 2 + 0.5) * (self.matching_height - 1)

                edge_mask = (x_vals >= 2.0) * (x_vals <= self.matching_width - 2) * \
                            (y_vals >= 2.0) * (y_vals <= self.matching_height - 2)
                edge_mask = edge_mask.float()

                # masking of current image
                current_mask = torch.zeros_like(edge_mask)
                current_mask[:, 2:-2, 2:-2] = 1.0
                edge_mask = edge_mask * current_mask

                diffs = torch.abs(warped - current_feats[batch_idx:batch_idx + 1]).mean(
                    1) * edge_mask

                # integrate into cost volume
                cost_volume = cost_volume + diffs
                counts = counts + (diffs > 0).float()

            # average over lookup images
            cost_volume = cost_volume / (counts + 1e-7)

            # if some missing values for a pixel location (i.e. some depths landed outside)
            # then set to max of existing values
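            # (note: filling these unobserved bins with the per-pixel maximum cost keeps them
            # from looking like good matches, and the confidence mask computed in forward()
            # zeroes out pixels for which any bin was missing)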
            missing_val_mask = (cost_volume == 0).float()
            if self.set_missing_to_max:
                cost_volume = cost_volume * (1 - missing_val_mask) + \
                    cost_volume.max(0)[0].unsqueeze(0) * missing_val_mask
            batch_cost_volume.append(cost_volume)
            cost_volume_masks.append(missing_val_mask)

        batch_cost_volume = torch.stack(batch_cost_volume, 0)
        cost_volume_masks = torch.stack(cost_volume_masks, 0)

        return batch_cost_volume, cost_volume_masks

    def feature_extraction(self, image, return_all_feats=False):
        """ Run feature extraction on an image - first 2 blocks of ResNet"""

        image = (image - 0.45) / 0.225  # imagenet normalisation
        feats_0 = self.layer0(image)
        feats_1 = self.layer1(feats_0)

        if return_all_feats:
            return [feats_0, feats_1]
        else:
            return feats_1

    def indices_to_disparity(self, indices):
        """Convert cost volume indices to 1/depth for visualisation"""
        batch, height, width = indices.shape
        depth = self.depth_bins[indices.reshape(-1).cpu()]
        disp = 1 / depth.reshape((batch, height, width))
        return disp

    def compute_confidence_mask(self, cost_volume, num_bins_threshold=None):
        """ Returns a 'confidence' mask based on how many times a depth bin was observed"""

        if num_bins_threshold is None:
            num_bins_threshold = self.num_depth_bins
        confidence_mask = ((cost_volume > 0).sum(1) == num_bins_threshold).float()

        return confidence_mask

    def forward(self, current_image, lookup_images, poses, K, invK,
                min_depth_bin=None, max_depth_bin=None):

        # feature extraction
        self.features = self.feature_extraction(current_image, return_all_feats=True)
        current_feats = self.features[-1]

        # feature extraction on lookup images - disable gradients to save memory
        with torch.no_grad():
            if self.adaptive_bins:
                self.compute_depth_bins(min_depth_bin, max_depth_bin)

            batch_size, num_frames, chns, height, width = lookup_images.shape
            lookup_images = lookup_images.reshape(batch_size * num_frames, chns, height, width)
            lookup_feats = self.feature_extraction(lookup_images,
                                                   return_all_feats=False)
            _, chns, height, width = lookup_feats.shape
            lookup_feats = lookup_feats.reshape(batch_size, num_frames, chns, height, width)

            # warp features to find cost volume
            cost_volume, missing_mask = \
                self.match_features(current_feats, lookup_feats, poses, K, invK)
            confidence_mask = self.compute_confidence_mask(cost_volume.detach() *
                                                           (1 - missing_mask.detach()))

        # for visualisation - ignore 0s in cost volume for minimum
        viz_cost_vol = cost_volume.clone().detach()
        viz_cost_vol[viz_cost_vol == 0] = 100
        mins, argmin = torch.min(viz_cost_vol, 1)
        lowest_cost = self.indices_to_disparity(argmin)

        # mask the cost volume based on the confidence
        cost_volume *= confidence_mask.unsqueeze(1)

        post_matching_feats = self.reduce_conv(torch.cat([self.features[-1], cost_volume], 1))

        self.features.append(self.layer2(post_matching_feats))
        self.features.append(self.layer3(self.features[-1]))
        self.features.append(self.layer4(self.features[-1]))

        return self.features, lowest_cost, confidence_mask

    def cuda(self):
        super().cuda()
        self.backprojector.cuda()
        self.projector.cuda()
        self.is_cuda = True
        if self.warp_depths is not None:
            self.warp_depths = self.warp_depths.cuda()

    def cpu(self):
        super().cpu()
        self.backprojector.cpu()
        self.projector.cpu()
        self.is_cuda = False
        if self.warp_depths is not None:
            self.warp_depths = self.warp_depths.cpu()

    def to(self, device):
        if str(device) == 'cpu':
            self.cpu()
        elif str(device) == 'cuda':
            self.cuda()
        else:
            raise NotImplementedError
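
# A minimal, hedged usage sketch for the matching encoder (not part of the original
# ManyDepth training pipeline): dummy images, a single lookup frame with an identity
# relative pose, and illustrative intrinsics expressed at the 1/4 cost-volume
# resolution. It only demonstrates the expected input shapes and return values.
def _example_matching_encoder():
    height, width = 96, 320
    encoder = ResnetEncoderMatching(num_layers=18, pretrained=False,
                                    input_height=height, input_width=width,
                                    adaptive_bins=False, num_depth_bins=32)

    batch, num_lookup = 1, 1
    current_image = torch.rand(batch, 3, height, width)
    lookup_images = torch.rand(batch, num_lookup, 3, height, width)
    relative_poses = torch.eye(4).view(1, 1, 4, 4).repeat(batch, num_lookup, 1, 1)

    # assumed intrinsics at the matching (1/4) resolution
    matching_w, matching_h = width // 4, height // 4
    K = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    K[:, 0, 0], K[:, 0, 2] = matching_w / 2, matching_w / 2
    K[:, 1, 1], K[:, 1, 2] = matching_h / 2, matching_h / 2
    invK = torch.inverse(K)

    features, lowest_cost, confidence_mask = encoder(current_image, lookup_images,
                                                     relative_poses, K, invK)
    # features: list of 5 feature maps; lowest_cost: B x H/4 x W/4 disparity from the
    # cost volume argmin; confidence_mask: B x H/4 x W/4 mask of fully observed pixels
    return features, lowest_cost, confidence_mask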

class ResnetEncoder(nn.Module):
    """Pytorch module for a resnet encoder
    """

    def __init__(self, num_layers, pretrained, num_input_images=1, **kwargs):
        super(ResnetEncoder, self).__init__()

        self.num_ch_enc = np.array([64, 64, 128, 256, 512])

        resnets = {18: models.resnet18,
                   34: models.resnet34,
                   50: models.resnet50,
                   101: models.resnet101,
                   152: models.resnet152}

        if num_layers not in resnets:
            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))

        if num_input_images > 1:
            self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
        else:
            self.encoder = resnets[num_layers](pretrained)

        if num_layers > 34:
            self.num_ch_enc[1:] *= 4

    def forward(self, input_image):
        self.features = []
        x = (input_image - 0.45) / 0.225
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        self.features.append(self.encoder.relu(x))
        self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
        self.features.append(self.encoder.layer2(self.features[-1]))
        self.features.append(self.encoder.layer3(self.features[-1]))
        self.features.append(self.encoder.layer4(self.features[-1]))

        return self.features
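
# Hedged usage sketch (illustrative shapes only, not part of the original repository):
# a single forward pass through the standard encoder, printing the five feature map
# shapes it returns.
if __name__ == '__main__':
    example_encoder = ResnetEncoder(num_layers=18, pretrained=False)
    example_feats = example_encoder(torch.rand(2, 3, 192, 640))
    for scale, feat in enumerate(example_feats):
        print('scale {}: {}'.format(scale, tuple(feat.shape)))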