Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from mmcv.runner import load_checkpoint | |
from models.mit import mit_b4 | |
class GLPDepth(nn.Module):
    """Depth-prediction network: ``mit_b4`` encoder, fusion decoder, depth head.

    The encoder yields four feature maps; the decoder merges them into a
    single 64-channel map, and a two-convolution head reduces that to one
    channel. The head output is squashed with a sigmoid and scaled by
    ``max_depth``.

    Args:
        max_depth: Scale applied to the sigmoid output, i.e. the upper bound
            of the predicted values.
        is_train: When true, pre-trained encoder weights are loaded from
            ``./models/weights/mit_b4.pth`` (downloading them first if the
            local load fails).
    """

    def __init__(self, max_depth=10.0, is_train=False):
        super().__init__()
        self.max_depth = max_depth

        self.encoder = mit_b4()
        if is_train:
            self._load_pretrained_encoder('./models/weights/mit_b4.pth')

        # Channel counts of the three encoder stages that pass through a
        # 1x1 projection in the decoder, listed deepest stage first.
        channels_in = [512, 320, 128]
        channels_out = 64

        self.decoder = Decoder(channels_in, channels_out)

        self.last_layer_depth = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels_out, 1, kernel_size=3, stride=1, padding=1))

    def _load_pretrained_encoder(self, ckpt_path):
        """Load encoder weights from ``ckpt_path``, downloading them if needed."""
        try:
            load_checkpoint(self.encoder, ckpt_path, logger=None)
        except Exception:
            # Deliberately broad: any failure to load the local file (missing,
            # truncated, unreadable) falls back to a fresh download. Unlike a
            # bare ``except:``, this lets KeyboardInterrupt/SystemExit through.
            import os

            import gdown

            print("Download pre-trained encoder weights...")
            file_id = '1BUtU42moYrOFbsMCE-LTTkUE-mrWnfG2'
            url = 'https://drive.google.com/uc?id=' + file_id

            # The download target directory may not exist on a fresh checkout.
            weights_dir = os.path.dirname(ckpt_path)
            if weights_dir:
                os.makedirs(weights_dir, exist_ok=True)

            gdown.download(url, ckpt_path, quiet=False)

            # Downloading alone leaves the encoder randomly initialised; the
            # freshly fetched checkpoint still has to be loaded. A failure
            # here propagates so training does not silently start from
            # scratch.
            load_checkpoint(self.encoder, ckpt_path, logger=None)

    def forward(self, x):
        """Run the network on ``x`` and return ``{'pred_d': depth}``.

        ``depth`` is the single-channel head output after
        ``sigmoid(...) * max_depth``.
        """
        conv1, conv2, conv3, conv4 = self.encoder(x)
        out = self.decoder(conv1, conv2, conv3, conv4)
        out_depth = self.last_layer_depth(out)
        out_depth = torch.sigmoid(out_depth) * self.max_depth

        return {'pred_d': out_depth}
class Decoder(nn.Module):
    """Coarse-to-fine decoder that fuses four encoder feature maps.

    Args:
        in_channels: Sequence whose first three entries are the channel
            counts of ``x_4``, ``x_3`` and ``x_2`` (in that order); each is
            projected to ``out_channels`` with a 1x1 convolution.
        out_channels: Channel width used throughout the decoder. ``x_1`` is
            fused without a projection, so it has to arrive with this many
            channels already.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        deep_ch, mid_ch, shallow_ch = in_channels[0], in_channels[1], in_channels[2]

        # 1x1 projections bringing every projected stage to a common width.
        self.bot_conv = nn.Conv2d(deep_ch, out_channels, kernel_size=1)
        self.skip_conv1 = nn.Conv2d(mid_ch, out_channels, kernel_size=1)
        self.skip_conv2 = nn.Conv2d(shallow_ch, out_channels, kernel_size=1)

        # One shared, parameter-free 2x bilinear upsampler.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        self.fusion1 = SelectiveFeatureFusion(out_channels)
        self.fusion2 = SelectiveFeatureFusion(out_channels)
        self.fusion3 = SelectiveFeatureFusion(out_channels)

    def forward(self, x_1, x_2, x_3, x_4):
        """Fuse the four feature maps, starting from ``x_4``.

        After each 2x upsampling the running map is fused with the next
        skip feature. Following the last fusion (with ``x_1``) the result is
        upsampled twice more, i.e. to 4x the spatial size of ``x_1``.
        """
        # Deepest stage: project, then enlarge to the resolution of x_3.
        fused = self.up(self.bot_conv(x_4))

        # Merge with the projected x_3, enlarge to the resolution of x_2.
        fused = self.up(self.fusion1(self.skip_conv1(x_3), fused))

        # Merge with the projected x_2, enlarge to the resolution of x_1.
        fused = self.up(self.fusion2(self.skip_conv2(x_2), fused))

        # x_1 is fused as-is, then the map is upsampled twice.
        fused = self.fusion3(x_1, fused)
        return self.up(self.up(fused))
class SelectiveFeatureFusion(nn.Module):
    """Blend a local and a global feature map with learned per-pixel weights.

    Both inputs are concatenated along the channel axis and passed through
    three 3x3 convolutions that end in a two-channel map. After a sigmoid,
    channel 0 weights ``x_local`` and channel 1 weights ``x_global``; the
    output is the sum of the two weighted inputs.

    Args:
        in_channel: Channel count of each of the two inputs.
    """

    def __init__(self, in_channel=64):
        super().__init__()

        wide = int(in_channel * 2)    # channels after concatenating both inputs
        narrow = int(in_channel / 2)  # bottleneck width ahead of the 2-channel map

        self.conv1 = nn.Sequential(
            nn.Conv2d(wide, in_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel, narrow, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(narrow),
            nn.ReLU(),
        )
        self.conv3 = nn.Conv2d(narrow, 2, kernel_size=3, stride=1, padding=1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x_local, x_global):
        """Return ``x_local * w0 + x_global * w1`` with weights from both inputs."""
        stacked = torch.cat([x_local, x_global], dim=1)
        weights = self.sigmoid(self.conv3(self.conv2(self.conv1(stacked))))

        # Length-one slices keep the channel axis so each weight map
        # broadcasts over all feature channels.
        local_w = weights[:, 0:1]
        global_w = weights[:, 1:2]
        return x_local * local_w + x_global * global_w