Spaces:
Runtime error
Runtime error
File size: 3,809 Bytes
b4eade4 9ebe7d2 b4eade4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint
from models.mit import mit_b4
class GLPDepth(nn.Module):
    """Monocular depth estimation network: MiT-b4 encoder + SFF decoder.

    The final depth map is ``sigmoid(logits) * max_depth``, so predictions
    lie in the open interval (0, max_depth).

    Args:
        max_depth (float): upper bound used to scale the sigmoid output.
        is_train (bool): when True, load ImageNet-pretrained encoder
            weights, downloading them first if they are not present.
    """

    def __init__(self, max_depth=10.0, is_train=False):
        super().__init__()
        self.max_depth = max_depth

        self.encoder = mit_b4()
        if is_train:
            # Single canonical checkpoint location (the original code tried
            # './mit_b4.pth' but downloaded to a different path and never
            # loaded the downloaded file, leaving the encoder random-init).
            ckpt_path = './models/weights/mit_b4.pth'
            try:
                load_checkpoint(self.encoder, ckpt_path, logger=None)
            except Exception:
                # Checkpoint missing or unreadable: fetch the pretrained
                # weights, then actually load them into the encoder.
                import os
                import gdown
                print("Download pre-trained encoder weights...")
                file_id = '1BUtU42moYrOFbsMCE-LTTkUE-mrWnfG2'
                url = 'https://drive.google.com/uc?id=' + file_id
                os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
                gdown.download(url, ckpt_path, quiet=False)
                load_checkpoint(self.encoder, ckpt_path, logger=None)

        channels_in = [512, 320, 128]  # channel widths of the 3 deepest encoder stages
        channels_out = 64

        self.decoder = Decoder(channels_in, channels_out)

        # Two-conv head producing a single-channel depth logit map.
        self.last_layer_depth = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels_out, 1, kernel_size=3, stride=1, padding=1))

    def forward(self, x):
        """Return ``{'pred_d': depth}`` with depth in (0, max_depth)."""
        conv1, conv2, conv3, conv4 = self.encoder(x)
        out = self.decoder(conv1, conv2, conv3, conv4)
        out_depth = self.last_layer_depth(out)
        out_depth = torch.sigmoid(out_depth) * self.max_depth
        return {'pred_d': out_depth}
class Decoder(nn.Module):
    """Progressive-upsampling decoder with selective feature fusion.

    The three deepest encoder maps are projected to a common channel
    width by 1x1 convolutions; the decoder then alternates bilinear 2x
    upsampling with attention-based fusion against the next shallower
    skip connection, finishing with two extra 2x upsamplings.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 projections onto the shared decoder width.
        self.bot_conv = nn.Conv2d(
            in_channels=in_channels[0], out_channels=out_channels, kernel_size=1)
        self.skip_conv1 = nn.Conv2d(
            in_channels=in_channels[1], out_channels=out_channels, kernel_size=1)
        self.skip_conv2 = nn.Conv2d(
            in_channels=in_channels[2], out_channels=out_channels, kernel_size=1)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.fusion1 = SelectiveFeatureFusion(out_channels)
        self.fusion2 = SelectiveFeatureFusion(out_channels)
        self.fusion3 = SelectiveFeatureFusion(out_channels)

    def forward(self, x_1, x_2, x_3, x_4):
        """Fuse features from deepest (x_4) to shallowest (x_1)."""
        feat = self.up(self.bot_conv(x_4))
        feat = self.up(self.fusion1(self.skip_conv1(x_3), feat))
        feat = self.up(self.fusion2(self.skip_conv2(x_2), feat))
        feat = self.up(self.up(self.fusion3(x_1, feat)))
        return feat
class SelectiveFeatureFusion(nn.Module):
    """Blend a local (skip) and a global (decoder) feature map.

    A small conv stack over the concatenated pair predicts a 2-channel
    sigmoid attention mask; channel 0 weights the local features and
    channel 1 the global ones, and the weighted maps are summed.
    """

    def __init__(self, in_channel=64):
        super().__init__()
        half = int(in_channel / 2)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=int(in_channel * 2),
                      out_channels=in_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channel),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel,
                      out_channels=half, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(half),
            nn.ReLU())
        self.conv3 = nn.Conv2d(in_channels=half,
                               out_channels=2, kernel_size=3, stride=1, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_local, x_global):
        """Return the attention-weighted sum of the two feature maps."""
        fused = torch.cat((x_local, x_global), dim=1)
        logits = self.conv3(self.conv2(self.conv1(fused)))
        attn = self.sigmoid(logits)
        # Slicing with 0:1 / 1:2 keeps the channel dim (same as
        # index-then-unsqueeze in the original).
        w_local = attn[:, 0:1, :, :]
        w_global = attn[:, 1:2, :, :]
        return x_local * w_local + x_global * w_global
|