import torch
import torch.nn as nn
from mono.utils.comm import get_func


class EncoderDecoder(nn.Module):
    """Encoder -> decoder -> depth-output-head pipeline assembled from a config.

    The backbone (encoder) and decode head (decoder) are looked up by dotted
    path through ``get_func``; the path is ``'mono.model.'`` followed by the
    ``prefix`` and ``type`` entries of the corresponding config node.
    """

    def __init__(self, cfg):
        """Instantiate the encoder, decoder and depth output head.

        Args:
            cfg: Project config object. This constructor reads
                ``cfg.model.backbone`` (``prefix``, ``type``; the whole node is
                also unpacked as keyword arguments for the backbone),
                ``cfg.model.decode_head`` (``prefix``, ``type``) and
                ``cfg.model.depth_out_head.method``. ``cfg`` itself is unpacked
                as keyword arguments for the depth output head, so it must
                support ``**`` unpacking.
        """
        super().__init__()

        backbone_cfg = cfg.model.backbone
        head_cfg = cfg.model.decode_head

        # Resolve both factories by dotted path. The backbone receives its own
        # config node as keyword arguments; the decode head receives the full
        # config object positionally.
        encoder_factory = get_func('mono.model.' + backbone_cfg.prefix + backbone_cfg.type)
        self.encoder = encoder_factory(**cfg.model.backbone)

        decoder_factory = get_func('mono.model.' + head_cfg.prefix + head_cfg.type)
        self.decoder = decoder_factory(cfg)

        # NOTE(review): DepthOutHead is neither imported nor defined in the
        # visible part of this file -- confirm it is brought into scope
        # elsewhere, otherwise this line raises NameError.
        self.depth_out_head = DepthOutHead(
            method=cfg.model.depth_out_head.method,
            **cfg
        )

        # nn.Module.__init__ already initialises ``training`` to True; the
        # explicit assignment is retained from the original code.
        self.training = True

    def forward(self, input, **kwargs):
        """Run the pipeline on ``input`` and package the head outputs.

        Args:
            input: Passed unchanged to the encoder.
            **kwargs: Accepted but not used by this method.

        Returns:
            dict with keys ``prediction``, ``confidence``, ``pred_logit`` and
            ``bins_edges`` (element 0 of the respective depth-head outputs),
            plus ``auxi_pred`` and ``auxi_logit_list``, which are always
            ``None`` here.
        """
        # Encoder output, per the original author's note: [f_32, f_16, f_8, f_4]
        encoded = self.encoder(input)

        # Decoder output, per the original author's note:
        # [x_32, x_16, x_8, x_4, x, ...]
        decoded = self.decoder(encoded)

        # Only the entry at index 4 of the decoder output is forwarded to the
        # depth head, wrapped in a single-element list.
        head_inputs = [decoded[4]]
        pred, conf, logit, bins_edges = self.depth_out_head(head_inputs)

        # This pipeline produces no auxiliary outputs; the keys are still
        # present in the result with value None.
        return {
            'prediction': pred[0],
            'confidence': conf[0],
            'pred_logit': logit[0],
            'auxi_pred': None,
            'auxi_logit_list': None,
            'bins_edges': bins_edges[0],
        }