# MIT License
#
# Copyright (c) 2024 Jiahao Shao
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import functools
import os
import tempfile
import time
import zipfile
from io import BytesIO

import spaces
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import xformers
from PIL import Image
from tqdm import tqdm
import mediapy as media
from huggingface_hub import login

from gradio_patches.examples import Examples
from chronodepth.unet_chronodepth import DiffusersUNetSpatioTemporalConditionModelChronodepth
from chronodepth.chronodepth_pipeline import ChronoDepthPipeline
from chronodepth.video_utils import resize_max_res, colorize_video_depth

MAX_FRAME = 60

default_seed = 2024
default_num_inference_steps = 5
default_n_tokens = 10
default_chunk_size = 5
default_video_processing_resolution = 768
default_decode_chunk_size = 8


@torch.no_grad()
def run_pipeline(pipe, video_rgb, generator, device):
    """
    Run the pipe on the input video.

    args:
        pipe: ChronoDepthPipeline object
        video_rgb: input video, torch.Tensor, shape [T, H, W, 3], range [0, 255]
        generator: torch.Generator
    returns:
        video_depth_pred: predicted depth, torch.Tensor, shape [T, H, W], range [0, 1]
    """
    if isinstance(video_rgb, torch.Tensor):
        video_rgb = video_rgb.cpu().numpy()

    original_height = video_rgb.shape[1]
    original_width = video_rgb.shape[2]

    # resize the video to the max resolution
    video_rgb = resize_max_res(video_rgb, default_video_processing_resolution)
    video_rgb = video_rgb.astype(np.float32) / 255.0

    pipe_out = pipe(
        video_rgb,
        num_inference_steps=default_num_inference_steps,
        decode_chunk_size=default_decode_chunk_size,
        motion_bucket_id=127,
        fps=7,
        noise_aug_strength=0.0,
        generator=generator,
        infer_mode="ours",
        sigma_epsilon=-4,
    )

    depth_frames_pred = pipe_out.frames
    depth_frames_pred = torch.from_numpy(depth_frames_pred).to(device)
    depth_frames_pred = F.interpolate(
        depth_frames_pred,
        size=(original_height, original_width),
        mode="bilinear",
        align_corners=False,
    )
    depth_frames_pred = depth_frames_pred.clamp(0, 1)
    depth_frames_pred = depth_frames_pred.squeeze(1)

    return depth_frames_pred


def process_video(
    pipe,
    path_input,
    num_inference_steps=default_num_inference_steps,
    out_max_frames=MAX_FRAME,
    progress=gr.Progress(),
):
    if path_input is None:
        raise gr.Error(
            "Missing video in the first pane: upload a file or use one from the gallery below."
        )

    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing video {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4")
    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.zip")

    generator = torch.Generator(device=pipe.device).manual_seed(default_seed)

    start_time = time.time()
    zipf = None
    try:
        # -------------------- Data --------------------
        video_name = path_input.split('/')[-1].split('.')[0]
        video_data = media.read_video(path_input)
        fps = video_data.metadata.fps
        video_length = len(video_data)
        video_rgb = np.array(video_data)

        duration_sec = video_length / fps
        out_duration_sec = out_max_frames / fps
        if duration_sec > out_duration_sec:
            gr.Warning(
                f"Only the first ~{int(out_duration_sec)} seconds will be processed; "
                f"use alternative setups such as ChronoDepth on GitHub for full processing."
            )
            video_rgb = video_rgb[:out_max_frames]

        zipf = zipfile.ZipFile(path_out_16bit, "w", zipfile.ZIP_DEFLATED)

        # -------------------- Inference --------------------
        depth_pred = run_pipeline(pipe, video_rgb, generator, pipe.device)  # range [0, 1]
        depth_pred = depth_pred.cpu().numpy()
        depth_colored_pred = colorize_video_depth(depth_pred)  # range [0, 1] -> [0, 255]

        # -------------------- Save results --------------------
        for i in tqdm(range(len(depth_pred))):
            archive_path = os.path.join(f"{name_base}_depth_16bit", f"{i:05d}.png")
            img_byte_arr = BytesIO()
            depth_16bit = Image.fromarray((depth_pred[i] * 65535.0).astype(np.uint16))
            depth_16bit.save(img_byte_arr, format="png")
            img_byte_arr.seek(0)
            zipf.writestr(archive_path, img_byte_arr.read())

        # Export the colorized depth to video
        media.write_video(path_out_vis, depth_colored_pred, fps=fps)
    finally:
        if zipf is not None:
            zipf.close()

    end_time = time.time()
    print(f"Processing time: {end_time - start_time} seconds")

    return (
        path_out_vis,
        [path_out_vis, path_out_16bit],
    )


def run_demo_server(pipe):
    process_pipe_video = spaces.GPU(
        functools.partial(process_video, pipe), duration=70
    )
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="ChronoDepth Video Depth Estimation",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
        """,
    ) as demo:
        gr.HTML(
            """

            <h1>⏰ChronoDepth: Learning Temporally Consistent Video Depth from Video Diffusion Priors</h1>
            <!-- badge-github-stars: GitHub stars badge shown here in the hosted demo -->
            <h3>ChronoDepth is the state-of-the-art video depth estimator for streaming videos in the wild.</h3>
            <p>
                PS: The maximum video length is limited to 60 frames for the demo.
                To process longer videos, please use ChronoDepth on GitHub.
            </p>

""" ) with gr.Row(): with gr.Column(): video_input = gr.Video( label="Input Video", sources=["upload"], ) with gr.Row(): video_submit_btn = gr.Button( value="Compute Depth", variant="primary" ) video_reset_btn = gr.Button(value="Reset") with gr.Column(): video_output_video = gr.Video( label="Output video depth (red-near, blue-far)", interactive=False, ) video_output_files = gr.Files( label="Depth outputs", elem_id="download", interactive=False, ) Examples( fn=process_pipe_video, examples=[ ["files/elephant.mp4"], ["files/kitti360_seq_0000.mp4"], ], inputs=[video_input], outputs=[video_output_video, video_output_files], cache_examples=True, directory_name="examples_video", ) video_submit_btn.click( fn=process_pipe_video, inputs=[video_input], outputs=[video_output_video, video_output_files], concurrency_limit=1, ) video_reset_btn.click( fn=lambda: (None, None, None), inputs=[], outputs=[video_input, video_output_video], concurrency_limit=1, ) demo.queue( api_open=False, ).launch( server_name="0.0.0.0", server_port=7860, ) def main(): CHECKPOINT = "jhshao/ChronoDepth-v1" if "HF_TOKEN_LOGIN" in os.environ: login(token=os.environ["HF_TOKEN_LOGIN"]) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Running on device: {device}") # -------------------- Model -------------------- unet = DiffusersUNetSpatioTemporalConditionModelChronodepth.from_pretrained( CHECKPOINT, low_cpu_mem_usage=True, torch_dtype=torch.float16, ) pipe = ChronoDepthPipeline.from_pretrained( "stabilityai/stable-video-diffusion-img2vid-xt", unet=unet, torch_dtype=torch.float16, variant="fp16", ) pipe.n_tokens = default_n_tokens pipe.chunk_size = default_chunk_size try: pipe.enable_xformers_memory_efficient_attention() except: pass # run without xformers pipe = pipe.to(device) run_demo_server(pipe) if __name__ == "__main__": main()