import copy
import math
import warnings
from functools import partial

import torch
import torch.nn as nn

from .ops import resize


def _imdenormalize(img, mean, std, to_bgr=True):
    """Reverse channel-wise normalization: multiply by std, add mean, and
    optionally flip the channel order from RGB to BGR."""
    import numpy as np

    mean = mean.reshape(1, -1).astype(np.float64)
    std = std.reshape(1, -1).astype(np.float64)
    img = (img * std) + mean
    if to_bgr:
        img = img[..., ::-1]  # flip the channel axis, not the image rows
    return img
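
# Illustrative usage (hypothetical values): undo ImageNet-style normalization
# on a float32 HWC image and convert it to BGR for visualization:
#
#   denorm = _imdenormalize(
#       norm_img,
#       np.array([123.675, 116.28, 103.53]),
#       np.array([58.395, 57.12, 57.375]),
#       to_bgr=True,
#   )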


class DepthBaseDecodeHead(nn.Module):
    """Base class for depth estimation decode heads.

    Args:
        in_channels (List): Input channels.
        channels (int): Channels after modules, before conv_depth. Default: 96.
        conv_layer (nn.Module): Conv layers. Default: None.
        act_layer (nn.Module): Activation layers. Default: nn.ReLU.
        loss_decode (dict): Config of decode loss. Default: ().
        sampler (dict|None): The config of depth map sampler. Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        min_depth (float): Min depth in dataset setting. Default: 1e-3.
        max_depth (float|None): Max depth in dataset setting. Default: None.
        norm_layer (dict|None): Norm layers. Default: None.
        classify (bool): Whether to predict depth in a classification-regression
            manner. Default: False.
        n_bins (int): Number of bins used in the classification step.
            Default: 256.
        bins_strategy (str): Discretization strategy used in the classification
            step, either 'UD' (uniform) or 'SID' (spacing-increasing).
            Default: 'UD'.
        norm_strategy (str): Normalization strategy applied to the bin
            probability distribution. Default: 'linear'.
        scale_up (bool): Whether to predict depth in a scale-up manner.
            Default: False.
    """

    def __init__(
        self,
        in_channels,
        conv_layer=None,
        act_layer=nn.ReLU,
        channels=96,
        loss_decode=(),
        sampler=None,
        align_corners=False,
        min_depth=1e-3,
        max_depth=None,
        norm_layer=None,
        classify=False,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        scale_up=False,
    ):
        super(DepthBaseDecodeHead, self).__init__()

        self.in_channels = in_channels
        self.channels = channels
        self.conv_layer = conv_layer
        self.act_layer = act_layer
        self.loss_decode = loss_decode
        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.norm_layer = norm_layer
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up

        if self.classify:
            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"

            self.bins_strategy = bins_strategy
            self.norm_strategy = norm_strategy
            self.softmax = nn.Softmax(dim=1)
            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
        else:
            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    def forward(self, inputs, img_metas):
        """Placeholder of forward function."""
        pass

    def forward_train(self, img, inputs, img_metas, depth_gt):
        """Forward function for training.

        Args:
            img (Tensor): Input images.
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): GT depth map.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        depth_pred = self.forward(inputs, img_metas)
        losses = self.losses(depth_pred, depth_gt)

        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
        losses.update(**log_imgs)

        return losses

    def forward_test(self, inputs, img_metas):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.

        Returns:
            Tensor: Output depth map.
        """
        return self.forward(inputs, img_metas)

    def depth_pred(self, feat):
        """Predict depth for each pixel."""
        if self.classify:
            logit = self.conv_depth(feat)

            if self.bins_strategy == "UD":
                # uniform discretization: evenly spaced bin centers
                bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
            elif self.bins_strategy == "SID":
                # spacing-increasing discretization: log-spaced bin centers.
                # torch.logspace expects exponents, so pass log10 of the depth range.
                bins = torch.logspace(
                    math.log10(self.min_depth), math.log10(self.max_depth), self.n_bins, device=feat.device
                )

            # normalize the per-pixel bin scores into a probability distribution
            if self.norm_strategy == "linear":
                logit = torch.relu(logit)
                eps = 0.1
                logit = logit + eps
                logit = logit / logit.sum(dim=1, keepdim=True)
            elif self.norm_strategy == "softmax":
                logit = torch.softmax(logit, dim=1)
            elif self.norm_strategy == "sigmoid":
                logit = torch.sigmoid(logit)
                logit = logit / logit.sum(dim=1, keepdim=True)

            # soft-argmax: per-pixel expectation over the bin centers
            output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
        else:
            if self.scale_up:
                output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
            else:
                output = self.relu(self.conv_depth(feat)) + self.min_depth
        return output
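
    # A minimal sketch of the soft-argmax used above (shapes and values are
    # hypothetical):
    #
    #   p = torch.softmax(torch.randn(2, 256, 8, 8), dim=1)           # (N, n_bins, H, W)
    #   bins = torch.linspace(1e-3, 10.0, 256)                        # (n_bins,)
    #   depth = torch.einsum("ikmn,k->imn", [p, bins]).unsqueeze(1)   # (N, 1, H, W)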

    def losses(self, depth_pred, depth_gt):
        """Compute depth loss."""
        loss = dict()
        depth_pred = resize(
            input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
        )
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
            else:
                loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
        return loss

    def log_images(self, img, depth_pred, depth_gt, img_meta):
        import numpy as np

        show_img = copy.deepcopy(img.detach().cpu().permute(1, 2, 0))
        show_img = show_img.numpy().astype(np.float32)
        show_img = _imdenormalize(
            show_img,
            img_meta["img_norm_cfg"]["mean"],
            img_meta["img_norm_cfg"]["std"],
            img_meta["img_norm_cfg"]["to_rgb"],
        )
        show_img = np.clip(show_img, 0, 255)
        show_img = show_img.astype(np.uint8)
        show_img = show_img[:, :, ::-1]
        show_img = show_img.transpose(2, 0, 1)  # HWC -> CHW for image logging

        depth_pred = depth_pred / torch.max(depth_pred)
        depth_gt = depth_gt / torch.max(depth_gt)

        depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
        depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())

        return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}


class BNHead(DepthBaseDecodeHead):
    """Just a batchnorm."""

    def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
        super().__init__(**kwargs)
        self.input_transform = input_transform
        self.in_index = in_index
        self.upsample = upsample

        if self.classify:
            self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
        else:
            self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
        Returns:
            Tensor: The transformed inputs
        """
        if "concat" in self.input_transform:
            inputs = [inputs[i] for i in self.in_index]
            if "resize" in self.input_transform:
                inputs = [
                    resize(
                        input=x,
                        size=[s * self.upsample for s in inputs[0].shape[2:]],
                        mode="bilinear",
                        align_corners=self.align_corners,
                    )
                    for x in inputs
                ]
            inputs = torch.cat(inputs, dim=1)
        elif self.input_transform == "multiple_select":
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def _forward_feature(self, inputs, img_metas=None, **kwargs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is the feature map of the last decoder layer.
        """
        inputs = list(inputs)
        for i, x in enumerate(inputs):
            # backbone stages may return (patch_tokens, cls_token) pairs
            if isinstance(x, (list, tuple)) and len(x) == 2:
                x, cls_token = x[0], x[1]
                if len(x.shape) == 2:
                    x = x[:, :, None, None]
                cls_token = cls_token[:, :, None, None].expand_as(x)
                inputs[i] = torch.cat((x, cls_token), 1)
            else:
                x = x[0]
                if len(x.shape) == 2:
                    x = x[:, :, None, None]
                inputs[i] = x
        x = self._transform_inputs(inputs)

        return x

    def forward(self, inputs, img_metas=None, **kwargs):
        """Forward function."""
        output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
        output = self.depth_pred(output)
        return output
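
    # Illustrative usage (hypothetical values): each backbone stage yields a
    # (feature_map, cls_token) pair; the cls_token is broadcast and concatenated
    # onto its feature map, all levels are resized to a common size and stacked,
    # and a 1x1 conv regresses depth; ``channels`` must match the concatenated
    # channel count:
    #
    #   head = BNHead(in_channels=[2 * 384] * 4, channels=4 * 2 * 384, max_depth=10.0)
    #   depth = head(stage_outputs)  # -> (N, 1, H, W)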


class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    It is based upon three build methods: `build_conv_layer()`,
    `build_norm_layer()` and `build_activation_layer()`.

    Besides, we add some additional features in this module.
    1. Automatically set `bias` of the conv layer.
    2. Spectral norm is supported.
    3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
       supported zero and circular padding, so we add a "reflect" padding mode.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as that in ``nn._ConvNd``.
        bias (bool | str): If specified as `auto`, it will be decided by the
            norm_layer. Bias will be set as True if `norm_layer` is None,
            otherwise False. Default: "auto".
        conv_layer (nn.Module): Convolution layer. Default: nn.Conv2d.
        norm_layer (nn.Module): Normalization layer. Default: None.
        act_layer (nn.Module): Activation layer. Default: nn.ReLU.
        inplace (bool): Whether to use inplace mode for activation.
            Default: True.
        with_spectral_norm (bool): Whether to use spectral norm in the conv
            module. Default: False.
        padding_mode (str): If the `padding_mode` has not been supported by
            current `Conv2d` in PyTorch, we will use our own padding layer
            instead. Currently, we support ['zeros', 'circular'] with official
            implementation and ['reflect'] with our own implementation.
            Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
    """

    _abbr_ = "conv_block"

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias="auto",
        conv_layer=nn.Conv2d,
        norm_layer=None,
        act_layer=nn.ReLU,
        inplace=True,
        with_spectral_norm=False,
        padding_mode="zeros",
        order=("conv", "norm", "act"),
    ):
        super(ConvModule, self).__init__()
        official_padding_mode = ["zeros", "circular"]
        self.conv_layer = conv_layer
        self.norm_layer = norm_layer
        self.act_layer = act_layer
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(["conv", "norm", "act"])

        self.with_norm = norm_layer is not None
        self.with_activation = act_layer is not None

        # if the conv layer comes before a norm layer, bias is unnecessary
        if bias == "auto":
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_explicit_padding:
            # "zeros" and "circular" are handled by nn.Conv2d itself; only
            # "reflect" needs an explicit padding layer
            if padding_mode == "reflect":
                padding_layer = nn.ReflectionPad2d
            else:
                raise AssertionError(f"Unsupported padding mode: {padding_mode}")
            self.pad = padding_layer(padding)

        # reset padding to 0 for the conv layer when an explicit pad layer is used
        conv_padding = 0 if self.with_explicit_padding else padding

        self.conv = self.conv_layer(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layer
        if self.with_norm:
            # norm layer is after conv layer in the default order
            if order.index("norm") > order.index("conv"):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            norm = norm_layer(num_features=norm_channels)
            self.norm_name = "norm"
            self.add_module(self.norm_name, norm)
            if self.with_bias:
                from torch.nn.modules.batchnorm import _BatchNorm
                from torch.nn.modules.instancenorm import _InstanceNorm

                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                    warnings.warn("Unnecessary conv bias before batch/instance norm")
        else:
            self.norm_name = None

        # build activation layer
        if self.with_activation:
            # nn.Tanh, nn.PReLU, nn.Sigmoid and nn.GELU take no 'inplace' argument
            if not issubclass(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)):
                act_layer = partial(act_layer, inplace=inplace)
            self.activate = act_layer()

        self.init_weights()

    @property
    def norm(self):
        if self.norm_name:
            return getattr(self, self.norm_name)
        else:
            return None

    def init_weights(self):
        # 1. Customized conv layers may define their own ``init_weights()``;
        #    in that case initialization is left to them and not overridden here.
        # 2. Conv layers without ``init_weights()`` (including PyTorch's own)
        #    are initialized here with Kaiming init by default.
        if not hasattr(self.conv, "init_weights"):
            if self.with_activation and issubclass(self.act_layer, nn.LeakyReLU):
                nonlinearity = "leaky_relu"
                a = 0.01
            else:
                nonlinearity = "relu"
                a = 0
            if hasattr(self.conv, "weight") and self.conv.weight is not None:
                nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
            if hasattr(self.conv, "bias") and self.conv.bias is not None:
                nn.init.constant_(self.conv.bias, 0)
        if self.with_norm:
            if hasattr(self.norm, "weight") and self.norm.weight is not None:
                nn.init.constant_(self.norm.weight, 1)
            if hasattr(self.norm, "bias") and self.norm.bias is not None:
                nn.init.constant_(self.norm.bias, 0)

    def forward(self, x, activate=True, norm=True):
        for layer in self.order:
            if layer == "conv":
                if self.with_explicit_padding:
                    x = self.pad(x)
                x = self.conv(x)
            elif layer == "norm" and norm and self.with_norm:
                x = self.norm(x)
            elif layer == "act" and activate and self.with_activation:
                x = self.activate(x)
        return x
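
    # Illustrative usage (hypothetical values): a 3x3 conv followed by
    # BatchNorm and ReLU, matching the default ("conv", "norm", "act") order;
    # bias is disabled automatically because a norm layer is present:
    #
    #   block = ConvModule(16, 32, 3, padding=1, norm_layer=nn.BatchNorm2d)
    #   y = block(torch.randn(1, 16, 64, 64))  # -> (1, 32, 64, 64)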


class Interpolate(nn.Module):
    def __init__(self, scale_factor, mode, align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
        return x


class HeadDepth(nn.Module):
    def __init__(self, features):
        super(HeadDepth, self).__init__()
        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.head(x)
        return x


class ReassembleBlocks(nn.Module):
    """ViTPostProcessBlock: processes the cls_token in ViT backbone outputs and
    rearranges the feature vectors into feature maps.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): Output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
    """

    def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
        super(ReassembleBlocks, self).__init__()

        assert readout_type in ["ignore", "add", "project"]
        self.readout_type = readout_type
        self.patch_size = patch_size

        self.projects = nn.ModuleList(
            [
                ConvModule(
                    in_channels=in_channels,
                    out_channels=out_channel,
                    kernel_size=1,
                    act_layer=None,
                )
                for out_channel in out_channels
            ]
        )

        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )
        if self.readout_type == "project":
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))

    def forward(self, inputs):
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == "project":
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == "add":
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            else:
                pass
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out
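
    # With the default patch_size=16, the ViT feature maps enter at 1/16 of the
    # input resolution; the four resize layers then emit maps at strides
    # 4, 8, 16 and 32 (4x up, 2x up, identity, 2x down), giving the DPT decoder
    # a feature pyramid to fuse.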


class PreActResidualConvUnit(nn.Module):
    """ResidualConvUnit, pre-activate residual unit.

    Args:
        in_channels (int): number of channels in the input feature map.
        act_layer (nn.Module): activation layer.
        norm_layer (nn.Module): norm layer.
        stride (int): stride of the first block. Default: 1.
        dilation (int): dilation rate for conv layers. Default: 1.
    """

    def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
        super(PreActResidualConvUnit, self).__init__()

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )

        self.conv2 = ConvModule(
            in_channels,
            in_channels,
            3,
            padding=1,
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )

    def forward(self, inputs):
        inputs_ = inputs.clone()
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x + inputs_


class FeatureFusionBlock(nn.Module):
    """FeatureFusionBlock, merge feature maps from different stages.

    Args:
        in_channels (int): Input channels.
        act_layer (nn.Module): activation layer for ResidualConvUnit.
        norm_layer (nn.Module): normalization layer.
        expand (bool): Whether to expand the channels in the post process
            block. Default: False.
        align_corners (bool): align_corners setting for bilinear upsample.
            Default: True.
    """

    def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
        super(FeatureFusionBlock, self).__init__()

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        self.out_channels = in_channels
        if self.expand:
            self.out_channels = in_channels // 2

        self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)

        self.res_conv_unit1 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
        )
        self.res_conv_unit2 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
        )

    def forward(self, *inputs):
        x = inputs[0]
        if len(inputs) == 2:
            if x.shape != inputs[1].shape:
                res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
            else:
                res = inputs[1]
            x = x + self.res_conv_unit1(res)
        x = self.res_conv_unit2(x)
        x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
        x = self.project(x)
        return x
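
    # Illustrative data flow (hypothetical shapes): the second input is the
    # skip feature from a finer stage; it is resized to match, refined by a
    # residual unit, added in, refined again, upsampled 2x, and projected:
    #
    #   ffb = FeatureFusionBlock(256, nn.ReLU, None)
    #   out = ffb(torch.randn(1, 256, 8, 8), torch.randn(1, 256, 8, 8))  # -> (1, 256, 16, 16)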


class DPTHead(DepthBaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is an implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether to expand the channels in the post
            process block. Default: False.
    """

    def __init__(
        self,
        embed_dims=768,
        post_process_channels=[96, 192, 384, 768],
        readout_type="ignore",
        patch_size=16,
        expand_channels=False,
        **kwargs,
    ):
        super(DPTHead, self).__init__(**kwargs)

        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)

        self.post_process_channels = [
            int(channel * math.pow(2, i)) if expand_channels else channel
            for i, channel in enumerate(post_process_channels)
        ]
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels
        self.conv_depth = HeadDepth(self.channels)

    def forward(self, inputs, img_metas):
        assert len(inputs) == self.num_reassemble_blocks
        x = list(inputs)
        x = self.reassemble_blocks(x)
        x = [self.convs[i](feature) for i, feature in enumerate(x)]
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, x[-(i + 1)])
        out = self.project(out)
        out = self.depth_pred(out)
        return out
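
# Illustrative end-to-end sketch (hypothetical shapes, ViT-Base-style values):
# four transformer stages each yield (tokens, cls_token); the head reassembles
# them into a stride-4/8/16/32 pyramid, fuses top-down, and regresses depth:
#
#   head = DPTHead(embed_dims=768, channels=256, in_channels=[768] * 4, max_depth=10.0)
#   feats = [(torch.randn(1, 768, 14, 14), torch.randn(1, 768)) for _ in range(4)]
#   depth = head(feats, img_metas=None)  # -> (1, 1, H, W) after the depth head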