Spaces:
Sleeping
Sleeping
File size: 3,220 Bytes
eea4530 b2ee272 eea4530 4d29a77 95377ef eea4530 95377ef eea4530 b2ee272 cb5d809 95377ef b2ee272 95377ef b2ee272 eea4530 4d29a77 eec6e0c eea4530 afc0455 95377ef 0c06861 eea4530 eec6e0c eea4530 eec6e0c 91480dd eec6e0c 40c7c34 eec6e0c 95377ef 40c7c34 cb5d809 40c7c34 157bcb9 cb5d809 eec6e0c 4d29a77 eec6e0c 0c06861 91480dd eec6e0c c4d069c eec6e0c 0c06861 eec6e0c e8629d6 eec6e0c 95377ef eec6e0c eea4530 eec6e0c eea4530 cb5d809 0fde623 cb5d809 eea4530 157bcb9 c4d069c eec6e0c eea4530 eec6e0c eea4530 eec6e0c e8629d6 eec6e0c eea4530 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import rerun as rr
import rerun.blueprint as rrb
import depth_pro
import subprocess
import torch
import cv2
import numpy as np
import os
from pathlib import Path
import gradio as gr
from gradio_rerun import Rerun
import spaces
# --- One-time module setup: fetch weights, pick a device, load the model ---

# Download the pretrained Depth Pro checkpoint on first launch.
if not os.path.exists("./checkpoints/depth_pro.pt"):
    print("downloading pretrained model")
    # check=True: fail loudly here instead of crashing later with a
    # confusing "checkpoint not found" error when the download script fails.
    subprocess.run(["bash", "get_pretrained_models.sh"], check=True)

# Prefer the first CUDA GPU; fall back to CPU (e.g. while the Space sleeps).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform once at import time so every
# request reuses the same weights.
print("loading model...")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
@rr.thread_local_stream("rerun_example_ml_depth_pro")
def run_rerun(path_to_video):
    """Stream per-frame depth estimates of *path_to_video* to a Rerun viewer.

    Generator used as a Gradio streaming callback: each ``yield`` hands the
    Rerun binary stream read so far to the front-end viewer.

    Parameters
    ----------
    path_to_video : str
        Path to the uploaded video file. The file is deleted when the
        generator finishes (it is a Gradio-managed temp upload).
    """
    stream = rr.binary_stream()

    # Layout: 3D view on top, depth + RGB side by side below.
    blueprint = rrb.Blueprint(
        rrb.Vertical(
            rrb.Spatial3DView(origin="/"),
            rrb.Horizontal(
                rrb.Spatial2DView(
                    origin="/world/camera/depth",
                ),
                rrb.Spatial2DView(origin="/world/camera/image"),
            ),
        ),
        collapse_panels=True,
    )
    rr.send_blueprint(blueprint)
    yield stream.read()

    print("Loading video from", path_to_video)
    video = cv2.VideoCapture(path_to_video)
    try:
        frame_idx = -1
        while True:
            read, frame = video.read()
            if not read:
                break
            frame_idx += 1

            # Only every 3rd frame to keep GPU time within quota.
            if frame_idx % 3 != 0:
                continue

            print("processing frame", frame_idx)

            # resize to avoid excessive time spent processing
            frame = cv2.resize(frame, (640, 480))
            # OpenCV decodes BGR; model and viewer expect RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            rr.set_time_sequence("frame", frame_idx)
            rr.log("world/camera/image", rr.Image(frame))
            yield stream.read()

            depth, focal_length = estimate_depth(frame)
            rr.log(
                "world/camera",
                rr.Pinhole(
                    width=frame.shape[1],
                    height=frame.shape[0],
                    focal_length=focal_length,
                    principal_point=(frame.shape[1] / 2, frame.shape[0] / 2),
                    # Scale the frustum so the image plane sits at the far depth.
                    image_plane_distance=depth.max(),
                    camera_xyz=rr.ViewCoordinates.FLU,
                ),
            )
            rr.log(
                "world/camera/depth",
                rr.DepthImage(depth, meter=1),
            )
            yield stream.read()
    finally:
        # FIX: previously the capture handle was never released and the temp
        # file was only deleted on the clean-exit path. The finally block also
        # runs on GeneratorExit (client disconnect) and on exceptions from
        # estimate_depth, so nothing leaks.
        video.release()
        if os.path.exists(path_to_video):
            os.remove(path_to_video)
@spaces.GPU(duration=20)
def estimate_depth(frame):
    """Run Depth Pro on a single RGB frame.

    Parameters
    ----------
    frame : np.ndarray
        HxWx3 RGB image (uint8), as produced by run_rerun's preprocessing.

    Returns
    -------
    tuple[np.ndarray, float]
        (depth map in meters as a 2-D array, estimated focal length in pixels).
    """
    image = transform(frame)
    image = image.to(device)

    # FIX: inference only — skip autograd graph construction, which the
    # original built on every frame and immediately threw away via .detach().
    with torch.no_grad():
        prediction = model.infer(image)

    depth = prediction["depth"].squeeze().detach().cpu().numpy()
    focal_length = prediction["focallength_px"].item()

    return depth, focal_length
# --- Gradio UI: video upload on the left, streaming Rerun viewer on the right ---
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(variant="compact"):
            # Upload widget; audio is irrelevant to depth estimation.
            video = gr.Video(interactive=True, include_audio=False, label="Video")
            visualize = gr.Button("Visualize ML Depth Pro")
        with gr.Column():
            # streaming=True lets the viewer consume run_rerun's yielded chunks
            # incrementally instead of waiting for the whole video to finish.
            viewer = Rerun(
                streaming=True,
            )
        # Clicking the button feeds the uploaded video path into the
        # run_rerun generator and streams its output into the viewer.
        visualize.click(run_rerun, inputs=[video], outputs=[viewer])

if __name__ == "__main__":
    demo.launch()
|