File size: 3,010 Bytes
eea4530
 
 
b2ee272
eea4530
 
 
95377ef
eea4530
 
 
95377ef
eea4530
b2ee272
95377ef
 
 
b2ee272
 
 
 
95377ef
b2ee272
 
 
 
eea4530
 
95377ef
 
eea4530
 
 
 
 
 
 
 
 
95377ef
 
 
 
eea4530
 
 
 
 
 
95377ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eea4530
95377ef
 
 
 
 
 
eea4530
95377ef
eea4530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95377ef
eea4530
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import rerun as rr
import rerun.blueprint as rrb
import depth_pro
import subprocess

import torch
import cv2
import os
from pathlib import Path
import gradio as gr
from gradio_rerun import Rerun
import spaces

# Fetch the pretrained checkpoint on first run (skipped when it already exists).
if not os.path.exists("checkpoints/depth_pro.pt"):
    print("downloading pretrained model")
    # check=True raises CalledProcessError when the download script exits
    # non-zero, instead of silently continuing without a checkpoint.
    subprocess.run(["bash", "get_pretrained_models.sh"], check=True)

# Prefer the first CUDA device when one is available, otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform, then switch the model to
# inference mode on the selected device.
print("loading model...")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()


@rr.thread_local_stream("rerun_example_ml_depth_pro")
@spaces.GPU(duration=20)
def run_ml_depth_pro(frame):
    """Run Depth Pro on one image and stream the result to the Rerun viewer.

    Logs the input image, a pinhole camera built from the predicted focal
    length, and the predicted depth map under ``world/camera``, then yields
    the data accumulated in the Rerun binary stream.

    Args:
        frame: Image array supplied by the Gradio image input. It is indexed
            as ``frame.shape[0]`` (height) and ``frame.shape[1]`` (width).

    Yields:
        The result of ``stream.read()`` for the recording produced here.

    Raises:
        gr.Error: If no image was provided (``frame`` is ``None``).
        RuntimeError: If the module-level model or transform is missing.
    """
    # Validate the actual input first so the user gets a readable message
    # instead of an attribute error on `frame.shape` further down.
    if frame is None:
        raise gr.Error("Please provide an image before streaming.")

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if model is None:
        raise RuntimeError("Model is None")
    if transform is None:
        raise RuntimeError("Transform is None")

    stream = rr.binary_stream()

    # Layout: one 3D view, plus the depth map and the input image side by side.
    blueprint = rrb.Blueprint(
        rrb.Spatial3DView(origin="/"),
        rrb.Horizontal(
            rrb.Spatial2DView(
                origin="/world/camera/depth",
            ),
            rrb.Spatial2DView(origin="/world/camera/image"),
        ),
        collapse_panels=True,
    )

    rr.send_blueprint(blueprint)

    # Only a single image is processed, so everything is logged at frame 0.
    rr.set_time_sequence("frame", 0)
    rr.log("world/camera/image", rr.Image(frame))

    # The model was moved to `device` at load time; make sure the input lives
    # on the same device (a no-op when the transform already placed it there).
    image = transform(frame).to(device)
    prediction = model.infer(image)
    depth = prediction["depth"].squeeze().detach().cpu().numpy()

    height, width = frame.shape[0], frame.shape[1]
    rr.log(
        "world/camera",
        rr.Pinhole(
            width=width,
            height=height,
            focal_length=prediction["focallength_px"].item(),
            # Principal point is taken to be the image centre.
            principal_point=(width / 2, height / 2),
            image_plane_distance=depth.max(),
        ),
    )

    rr.log(
        "world/camera/depth",
        # need 0.19 stable for this
        # rr.DepthImage(depth, meter=1, depth_range=(depth.min(), depth.max())),
        rr.DepthImage(depth, meter=1),
    )

    yield stream.read()


video_path = Path("hd-cat.mp4")


# Load video: decode every frame, downscale it to 320x240 and convert it from
# OpenCV's BGR channel order to RGB.
# NOTE(review): `video_path` names "hd-cat.mp4" while the capture below opens
# "hd-cat2.mp4" -- confirm which file is intended.
frames = []
video = cv2.VideoCapture("hd-cat2.mp4")
try:
    while True:
        read, frame = video.read()
        if not read:
            # End of stream, or the file could not be opened/decoded.
            break
        frame = cv2.resize(frame, (320, 240))
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
finally:
    # Always free the capture handle, even if decoding raises.
    video.release()

# Panel visibility passed to the embedded Rerun viewer.
viewer_panel_states = dict(
    time="collapsed",
    blueprint="hidden",
    selection="hidden",
)

# UI: a single "Streaming" tab holding the image input and trigger button on
# the first row and the Rerun viewer on the second.
with gr.Blocks() as demo, gr.Tab("Streaming"):
    with gr.Row():
        img = gr.Image(label="Image", interactive=True)
        with gr.Column():
            stream_ml_depth_pro = gr.Button("Stream Ml Depth Pro")

    with gr.Row():
        viewer = Rerun(panel_states=viewer_panel_states, streaming=True)

    # Clicking the button sends the image to the depth pipeline and streams
    # its output into the viewer.
    stream_ml_depth_pro.click(
        fn=run_ml_depth_pro,
        inputs=[img],
        outputs=[viewer],
    )


# Start the Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()