File size: 3,220 Bytes
eea4530
 
 
b2ee272
eea4530
 
 
4d29a77
95377ef
eea4530
 
 
95377ef
eea4530
b2ee272
cb5d809
95377ef
 
b2ee272
 
 
 
95377ef
b2ee272
 
 
 
eea4530
4d29a77
eec6e0c
eea4530
 
 
afc0455
 
 
 
 
 
 
95377ef
0c06861
eea4530
 
 
 
eec6e0c
eea4530
eec6e0c
 
91480dd
eec6e0c
 
40c7c34
eec6e0c
95377ef
40c7c34
cb5d809
40c7c34
 
157bcb9
 
cb5d809
 
eec6e0c
4d29a77
eec6e0c
0c06861
91480dd
eec6e0c
 
c4d069c
eec6e0c
 
 
 
 
 
0c06861
eec6e0c
 
e8629d6
eec6e0c
 
95377ef
eec6e0c
 
 
 
eea4530
eec6e0c
eea4530
cb5d809
0fde623
cb5d809
 
eea4530
157bcb9
c4d069c
 
 
eec6e0c
 
 
eea4530
eec6e0c
 
 
eea4530
eec6e0c
e8629d6
 
 
 
 
 
 
eec6e0c
eea4530
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import rerun as rr
import rerun.blueprint as rrb
import depth_pro
import subprocess

import torch
import cv2
import numpy as np
import os
from pathlib import Path
import gradio as gr
from gradio_rerun import Rerun
import spaces

# Run the script to get pretrained models on first start-up.
if not os.path.exists("./checkpoints/depth_pro.pt"):
    print("downloading pretrained model")
    # check=True: a failed download raises CalledProcessError right here
    # instead of surfacing later as a confusing error while loading weights.
    subprocess.run(["bash", "get_pretrained_models.sh"], check=True)

# Prefer the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform
print("loading model...")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
# Inference only: put the model in evaluation mode.
model.eval()


@rr.thread_local_stream("rerun_example_ml_depth_pro")
def run_rerun(path_to_video):
    """Stream Depth Pro results for a video into the Rerun viewer.

    Generator used as a Gradio event handler. It logs data to a thread-local
    Rerun recording, and each ``yield`` hands the bytes returned by
    ``stream.read()`` to the streaming ``Rerun`` output component. It
    processes every third frame: logs the RGB image, then the estimated
    depth image and a pinhole camera built from the estimated focal length.

    Args:
        path_to_video: Filesystem path of the uploaded video as supplied by
            the ``gr.Video`` input; may be ``None`` if nothing was uploaded.

    Yields:
        Chunks of Rerun data read from ``rr.binary_stream()``.

    Raises:
        gr.Error: If no video was provided.

    The uploaded file is deleted and the OpenCV capture released when the
    generator finishes, raises, or is closed early.
    """
    if path_to_video is None:
        raise gr.Error("Please upload a video before clicking Visualize.")

    capture = None
    try:
        stream = rr.binary_stream()

        blueprint = rrb.Blueprint(
            rrb.Vertical(
                rrb.Spatial3DView(origin="/"),
                rrb.Horizontal(
                    rrb.Spatial2DView(
                        origin="/world/camera/depth",
                    ),
                    rrb.Spatial2DView(origin="/world/camera/image"),
                ),
            ),
            collapse_panels=True,
        )

        rr.send_blueprint(blueprint)
        yield stream.read()

        print("Loading video from", path_to_video)
        capture = cv2.VideoCapture(path_to_video)
        frame_idx = -1
        while True:
            read, frame = capture.read()
            if not read:
                break

            frame_idx += 1
            # Only every third frame is processed, to bound total run time.
            if frame_idx % 3 != 0:
                continue

            print("processing frame", frame_idx)

            # resize to avoid excessive time spent processing
            # NOTE(review): fixed 640x480 ignores the source aspect ratio.
            frame = cv2.resize(frame, (640, 480))
            # OpenCV decodes frames as BGR; the model and viewer get RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            rr.set_time_sequence("frame", frame_idx)
            rr.log("world/camera/image", rr.Image(frame))

            # Flush the RGB frame so it appears before depth estimation ends.
            yield stream.read()

            depth, focal_length = estimate_depth(frame)

            rr.log(
                "world/camera",
                rr.Pinhole(
                    width=frame.shape[1],
                    height=frame.shape[0],
                    focal_length=focal_length,
                    principal_point=(frame.shape[1] / 2, frame.shape[0] / 2),
                    image_plane_distance=depth.max(),
                    camera_xyz=rr.ViewCoordinates.FLU,
                ),
            )

            rr.log(
                "world/camera/depth",
                rr.DepthImage(depth, meter=1),
            )

            yield stream.read()
    finally:
        # Runs on normal completion, on errors, and when the generator is
        # closed early, so neither the capture handle nor the uploaded file
        # is leaked.
        if capture is not None:
            capture.release()
        if os.path.exists(path_to_video):
            os.remove(path_to_video)


@spaces.GPU(duration=20)
def estimate_depth(frame):
    """Run the Depth Pro model on a single image frame.

    Args:
        frame: Image array fed through the module-level ``transform``;
            ``run_rerun`` passes an RGB frame converted from OpenCV's output
            (presumably HxWx3 uint8 — confirm).

    Returns:
        A ``(depth, focal_length)`` tuple: the squeezed ``"depth"`` prediction
        as a NumPy array on the CPU, and the ``"focallength_px"`` prediction
        as a Python scalar (presumably pixels, per the key name).

    Decorated with ``spaces.GPU(duration=20)`` so the call is scheduled on a
    GPU slot by the Hugging Face ``spaces`` runtime.
    """
    batch = transform(frame).to(device)
    result = model.infer(batch)
    depth_map = result["depth"].squeeze().detach().cpu().numpy()
    return depth_map, result["focallength_px"].item()


with gr.Blocks() as demo:
    with gr.Row():
        # Left column: video upload plus the button that starts processing.
        with gr.Column(variant="compact"):
            video = gr.Video(
                interactive=True,
                include_audio=False,
                label="Video",
            )
            visualize = gr.Button("Visualize ML Depth Pro")
        # Right column: embedded Rerun viewer that receives streamed data.
        with gr.Column():
            viewer = Rerun(streaming=True)
    # Each click streams run_rerun's output for the selected video to the viewer.
    visualize.click(fn=run_rerun, inputs=[video], outputs=[viewer])


if __name__ == "__main__":
    # Launch the Gradio app when executed as a script.
    demo.launch()