Update app.py
app.py
CHANGED
@@ -66,7 +66,7 @@ def download_models():
download_models()

# DepthAnythingV2
-
+def get_dam2_model():
    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
@@ -83,22 +83,25 @@ if 'dam2' not in globals():
    dam2_checkpoint = f'checkpoints/depth_anything_v2_{encoder}.pth'
    dam2.load_state_dict(torch.load(dam2_checkpoint, map_location='cpu'))
    dam2 = dam2.to(DEVICE).eval()
+    return dam2

# GenStereo
-
+def get_genstereo_model():
    genwarp_cfg = dict(
        pretrained_model_path='checkpoints',
        checkpoint_name=CHECKPOINT_NAME,
        half_precision_weights=True
    )
    genstereo = GenStereo(cfg=genwarp_cfg, device=DEVICE)
+    return genstereo

# Adaptive Fusion
-
+def get_fusion_model():
    fusion_model = AdaptiveFusionLayer()
    fusion_checkpoint = join('checkpoints', CHECKPOINT_NAME, 'fusion_layer.pth')
    fusion_model.load_state_dict(torch.load(fusion_checkpoint, map_location='cpu'))
    fusion_model = fusion_model.to(DEVICE).eval()
+    return fusion_model

# Crop the image to the shorter side.
def crop(img: Image) -> Image:
@@ -190,6 +193,7 @@ with tempfile.TemporaryDirectory() as tmpdir:

    image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

+    dam2 = get_dam2_model()
    depth_dam2 = dam2.infer_image(image_bgr)
    depth = torch.tensor(depth_dam2).unsqueeze(0).unsqueeze(0).float().cuda()

@@ -202,6 +206,9 @@ with tempfile.TemporaryDirectory() as tmpdir:
    norm_disp = normalize_disp(depth)
    disp = norm_disp * scale_factor / 100 * IMAGE_SIZE

+    genstereo = get_genstereo_model()
+    fusion_model = get_fusion_model()
+
    renders = genstereo(
        src_image=image,
        src_disparity=disp,
@@ -231,4 +238,4 @@ with tempfile.TemporaryDirectory() as tmpdir:
)

if __name__ == '__main__':
-    demo.launch()
+    demo.launch()
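Note on the refactor (not part of the diff above): the new get_dam2_model(), get_genstereo_model(), and get_fusion_model() helpers rebuild their models and reload checkpoints every time they are called from the request path. If a single copy of each model per process is preferred, the getters could be memoized. The sketch below is a minimal illustration under that assumption; the *_cached names are hypothetical and only assume the helpers defined in app.py above are in scope.

```python
# Sketch only -- not part of this commit. Memoize the new getters so each
# checkpoint is loaded once per process instead of on every request.
from functools import lru_cache

@lru_cache(maxsize=1)
def get_dam2_model_cached():
    return get_dam2_model()        # DepthAnythingV2, loaded on first use only

@lru_cache(maxsize=1)
def get_genstereo_model_cached():
    return get_genstereo_model()   # GenStereo model built from genwarp_cfg

@lru_cache(maxsize=1)
def get_fusion_model_cached():
    return get_fusion_model()      # AdaptiveFusionLayer weights
```

Whether that trade-off is wanted depends on GPU memory: reloading per call, as the commit does, keeps memory free between requests, while caching avoids repeated checkpoint reads.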