# Scraped file-view header (not part of the program): oxkitsune's picture,
# "initial commit", eea4530, raw, history blame, 2.93 kB.
import rerun as rr
import rerun.blueprint as rrb
import depth_pro
import torch
import cv2
from pathlib import Path
import gradio as gr
from gradio_rerun import Rerun
@rr.thread_local_stream("rerun_example_ml_depth_pro")
def run_ml_depth_pro(model, transform, frames):
    """Run depth inference on each frame and stream the Rerun recording.

    For every frame this logs the input image, a pinhole camera built from
    the predicted focal length, and the predicted depth map, then yields the
    bytes accumulated in the binary stream since the previous read.

    Args:
        model: Object exposing ``infer(image)`` that returns a mapping with
            ``"depth"`` and ``"focallength_px"`` tensors.
        transform: Callable applied to each frame before inference.
        frames: Iterable of image arrays; ``shape[0]`` is used as the height
            and ``shape[1]`` as the width.

    Yields:
        The result of ``stream.read()`` after each frame has been logged.

    Raises:
        ValueError: If ``model``, ``transform`` or ``frames`` is ``None``.
    """
    stream = rr.binary_stream()

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if model is None:
        raise ValueError("Model is None")
    if transform is None:
        raise ValueError("Transform is None")
    if frames is None:
        raise ValueError("Frames is None")

    # 3D view on top of the root, with depth and image side by side.
    blueprint = rrb.Blueprint(
        rrb.Spatial3DView(origin="/"),
        rrb.Horizontal(
            rrb.Spatial2DView(origin="/world/camera/depth", title="Depth"),
            rrb.Spatial2DView(origin="/world/camera/image", title="Image"),
        ),
        collapse_panels=True,
    )
    rr.send_blueprint(blueprint)

    for i, frame in enumerate(frames):
        rr.set_time_sequence("frame", i)
        rr.log("world/camera/image", rr.Image(frame))

        image = transform(frame)
        prediction = model.infer(image)
        depth = prediction["depth"].squeeze().detach().cpu().numpy()

        height, width = frame.shape[0], frame.shape[1]
        rr.log(
            "world/camera",
            rr.Pinhole(
                width=width,
                height=height,
                focal_length=prediction["focallength_px"].item(),
                # Principal point is placed at the image centre.
                principal_point=(width / 2, height / 2),
                # Image plane is drawn at the largest predicted depth.
                image_plane_distance=depth.max(),
            ),
        )

        # TODO: once Rerun 0.19 stable is available, also pass
        # depth_range=(depth.min(), depth.max()) to rr.DepthImage.
        # NOTE(review): meter=1 presumably means depth is in metres - confirm.
        rr.log("world/camera/depth", rr.DepthImage(depth, meter=1))

        # Hand the newly logged data to the caller so it can be streamed.
        yield stream.read()
# NOTE(review): `video_path` is not used below; the capture opens
# "hd-cat2.mp4" instead. Confirm which file is intended.
video_path = Path("hd-cat.mp4")

# Pick the first available backend: CUDA, then MPS, then CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# Load model and preprocessing transform
model, transform = depth_pro.create_model_and_transforms(device=device)
model.eval()

# Load video: read every frame up front, downscaled to 320x240 and converted
# from OpenCV's BGR channel order to RGB.
frames = []
video = cv2.VideoCapture("hd-cat2.mp4")
try:
    while True:
        read, frame = video.read()
        if not read:
            break
        frame = cv2.resize(frame, (320, 240))
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
finally:
    # Release the capture handle even if decoding a frame fails.
    video.release()
def _stream_preloaded_frames():
    """Stream Depth Pro results for the module-level model and frames.

    Gradio event `inputs` must be components, so the preloaded model,
    transform and frames are bound here instead of being passed as inputs.
    """
    yield from run_ml_depth_pro(model, transform, frames)


with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        with gr.Row():
            img = gr.Image(interactive=True, label="Image")
            with gr.Column():
                stream_ml_depth_pro = gr.Button("Stream Ml Depth Pro")

        with gr.Row():
            viewer = Rerun(
                streaming=True,
                panel_states={
                    "time": "collapsed",
                    "blueprint": "hidden",
                    "selection": "hidden",
                },
            )

        # No component inputs: the handler reads the preloaded data itself.
        stream_ml_depth_pro.click(_stream_preloaded_frames, outputs=[viewer])

if __name__ == "__main__":
    demo.launch()