Spaces:
Runtime error
Runtime error
from lib import network_auxi as network | |
from lib.net_tools import get_func | |
import torch | |
import torch.nn as nn | |
class RelDepthModel(nn.Module):
    """Wrapper that builds a ``DepthModel`` from a short backbone name.

    Args:
        backbone: ``'resnet50'`` or ``'resnext101'``; selects the encoder
            constructor name handed to ``DepthModel``.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    # Short backbone name -> encoder constructor name passed to DepthModel.
    _ENCODERS = {
        'resnet50': 'resnet50_stride32',
        'resnext101': 'resnext101_stride32x8d',
    }

    def __init__(self, backbone='resnet50'):
        super(RelDepthModel, self).__init__()
        try:
            encoder = self._ENCODERS[backbone]
        except (KeyError, TypeError):
            # Fail with a clear message instead of an UnboundLocalError on
            # the `encoder` variable further down.
            raise ValueError(
                'Unsupported backbone {!r}; expected one of {}'.format(
                    backbone, sorted(self._ENCODERS))
            ) from None
        self.depth_model = DepthModel(encoder)

    def inference(self, rgb):
        """Run ``depth_model`` on ``rgb`` with gradient tracking disabled.

        The input is moved to the device that holds the model's parameters,
        so the call works on both CUDA and CPU-only hosts. The model output
        is returned as-is (no shifting or normalisation is applied).

        NOTE(review): this method does not switch the module to eval mode;
        presumably the caller does that -- confirm.

        Args:
            rgb: Input tensor accepted by ``depth_model``.

        Returns:
            The tensor produced by ``depth_model``.
        """
        try:
            device = next(self.depth_model.parameters()).device
        except StopIteration:
            # Parameter-less model: nothing to align with, keep input in place.
            device = rgb.device
        with torch.no_grad():
            return self.depth_model(rgb.to(device))
class DepthModel(nn.Module):
    """Encoder/decoder pair assembled from the ``network`` module.

    Args:
        encoder: Name of the encoder constructor; it is resolved through
            ``get_func`` after being prefixed with the last component of
            the ``network`` module's dotted name.
    """

    def __init__(self, encoder):
        super().__init__()
        # Only the trailing component of the module path is used as prefix.
        module_leaf = network.__name__.rsplit('.', 1)[-1]
        encoder_factory = get_func('.'.join((module_leaf, encoder)))
        self.encoder_modules = encoder_factory()
        self.decoder_modules = network.Decoder()

    def forward(self, x):
        """Feed ``x`` through the encoder, then the decoder.

        Args:
            x: Input passed to ``encoder_modules``.

        Returns:
            Whatever ``decoder_modules`` returns for the encoder's output.
        """
        return self.decoder_modules(self.encoder_modules(x))