# Copyright Niantic 2021. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the ManyDepth licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict


class ConvBlock(nn.Module):
    """Layer to perform a convolution followed by ELU

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        """Return ``ELU(conv(x))``; the ELU is applied in place on the conv output."""
        return self.nonlin(self.conv(x))


class Conv3x3(nn.Module):
    """Layer to pad and convolve input

    The input is padded by one pixel on every side and then passed through
    an unpadded 3x3 convolution, so height and width are unchanged.

    Args:
        in_channels: number of input channels (anything ``int()`` accepts).
        out_channels: number of output channels (anything ``int()`` accepts).
        use_refl: pad by reflection when true, with zeros otherwise.
    """

    def __init__(self, in_channels, out_channels, use_refl=True):
        super().__init__()

        self.pad = nn.ReflectionPad2d(1) if use_refl else nn.ZeroPad2d(1)
        # int(): callers in this file pass numpy integers as channel counts.
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        """Pad ``x`` by one pixel per side, then apply the 3x3 convolution."""
        return self.conv(self.pad(x))


def upsample(x, mode="nearest"):
    """Upsample input tensor by a factor of 2

    Args:
        x: tensor accepted by ``torch.nn.functional.interpolate``.
        mode: interpolation mode forwarded to ``interpolate``. Defaults to
            ``"nearest"``, which is what this function always used before
            the parameter existed.

    Returns:
        The interpolated tensor.
    """
    return F.interpolate(x, scale_factor=2, mode=mode)


class DepthDecoder(nn.Module):
    """Five-stage decoder producing sigmoid maps at the requested scales.

    Stages run from ``i = 4`` down to ``i = 0``. Each stage applies a
    ``ConvBlock``, upsamples by 2, optionally concatenates the encoder
    feature ``input_features[i - 1]`` along dim 1 (for ``i > 0``), and
    applies a second ``ConvBlock``. For every ``i`` in ``scales`` a
    ``Conv3x3`` followed by a sigmoid is evaluated on that stage's output.

    Args:
        num_ch_enc: channel counts of the encoder features, indexed like
            ``input_features`` in ``forward``.
        scales: stage indices (0..4) at which to emit an output.
        num_output_channels: channels of each emitted map.
        use_skips: whether to concatenate encoder features in stages 4..1.

    Raises:
        IndexError: if ``scales`` contains a value outside 0..4.
    """

    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
        super().__init__()

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # decoder
        # Layers live in a tuple-keyed dict for readable lookup in forward();
        # they are registered through the ModuleList below. Insertion order
        # fixes the ModuleList indices (and hence state-dict keys), so it
        # must not be changed.
        self.convs = OrderedDict()
        last_stage = len(self.num_ch_dec) - 1
        for i in range(last_stage, -1, -1):
            stage_ch = self.num_ch_dec[i]

            # upconv_0: from the previous (deeper) stage, or from the last
            # encoder feature for the first stage processed.
            if i == last_stage:
                ch_in = self.num_ch_enc[-1]
            else:
                ch_in = self.num_ch_dec[i + 1]
            self.convs[("upconv", i, 0)] = ConvBlock(ch_in, stage_ch)

            # upconv_1: sees the upsampled tensor plus, when skips are on,
            # the concatenated encoder feature.
            ch_in = stage_ch
            if self.use_skips and i > 0:
                ch_in = ch_in + self.num_ch_enc[i - 1]
            self.convs[("upconv", i, 1)] = ConvBlock(ch_in, stage_ch)

        for s in self.scales:
            # A negative index would silently wrap around num_ch_dec and
            # register a layer that forward() never evaluates. Reject it
            # with the same exception type an index past the end raises.
            if not 0 <= s <= last_stage:
                raise IndexError(
                    "scale {} is out of range; expected 0..{}".format(s, last_stage))
            self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)

        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        """Decode encoder features into per-scale sigmoid maps.

        Args:
            input_features: indexable collection of encoder tensors.
                ``input_features[-1]`` seeds the decoder; with skips
                enabled, ``input_features[i - 1]`` is concatenated at
                stage ``i`` and must therefore match the upsampled tensor
                in every dimension except dim 1.

        Returns:
            Dict mapping ``("disp", i)`` to the sigmoid output of stage
            ``i`` for each ``i`` in ``scales``. The same dict is also kept
            on ``self.outputs``.
        """
        self.outputs = {}

        # decoder
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            parts = [upsample(x, mode=self.upsample_mode)]
            if self.use_skips and i > 0:
                parts.append(input_features[i - 1])
            x = torch.cat(parts, 1)
            x = self.convs[("upconv", i, 1)](x)
            if i in self.scales:
                disp = self.convs[("dispconv", i)](x)
                self.outputs[("disp", i)] = self.sigmoid(disp)

        return self.outputs