Spaces:
Runtime error
Runtime error
File size: 1,136 Bytes
e04dce3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from lib import network_auxi as network
from lib.net_tools import get_func
import torch
import torch.nn as nn
class RelDepthModel(nn.Module):
    """Relative-depth wrapper that builds a ``DepthModel`` from a short backbone name.

    Args:
        backbone: Short backbone name, one of ``'resnet50'`` or ``'resnext101'``.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    # Short backbone name -> encoder factory name handed to DepthModel.
    _ENCODERS = {
        'resnet50': 'resnet50_stride32',
        'resnext101': 'resnext101_stride32x8d',
    }

    def __init__(self, backbone='resnet50'):
        super(RelDepthModel, self).__init__()
        if backbone not in self._ENCODERS:
            # Previously an unknown name fell through the if/elif and crashed
            # later with UnboundLocalError on `encoder`; fail clearly instead.
            raise ValueError(
                f"Unsupported backbone {backbone!r}; "
                f"expected one of {sorted(self._ENCODERS)}"
            )
        self.depth_model = DepthModel(self._ENCODERS[backbone])

    def inference(self, rgb):
        """Run the depth network on ``rgb`` without tracking gradients.

        The input is moved to the device holding this model's parameters, so
        a model kept on CPU (or on a non-default GPU) works. If the model has
        no parameters, the input goes to CUDA as it always did before.

        Args:
            rgb: Input image tensor. NOTE(review): layout/normalization is
                whatever the encoder expects (presumably NCHW) — confirm.

        Returns:
            The depth network's raw output tensor. No shift or normalization
            is applied.
        """
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')
        with torch.no_grad():
            depth = self.depth_model(rgb.to(device))
        return depth
class DepthModel(nn.Module):
    """Encoder-decoder network: a named encoder feeding ``network.Decoder``.

    Args:
        encoder: Name of an encoder factory in the ``network`` module
            (for example ``'resnet50_stride32'``), resolved through ``get_func``.
    """

    def __init__(self, encoder):
        super(DepthModel, self).__init__()
        # Build "<leaf module name>.<encoder>", e.g. "network_auxi.resnet50_stride32".
        # NOTE(review): get_func presumably resolves this dotted path to a
        # zero-argument factory — confirm against lib.net_tools.
        module_leaf = network.__name__.rsplit('.', 1)[-1]
        make_encoder = get_func(module_leaf + '.' + encoder)
        self.encoder_modules = make_encoder()
        self.decoder_modules = network.Decoder()

    def forward(self, x):
        """Encode ``x`` and return the decoder's output for those features."""
        features = self.encoder_modules(x)
        return self.decoder_modules(features)