import gc
import os

import numpy as np
import spaces
import gradio as gr
import torch
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.training_utils import set_seed

from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline
from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter
from normalcrafter.utils import read_video_frames, vis_sequence_normal, save_video
examples = [
    ["examples/example_01.mp4", 1024, -1, -1],
    ["examples/example_02.mp4", 1024, -1, -1],
    ["examples/example_03.mp4", 1024, -1, -1],
    ["examples/example_04.mp4", 1024, -1, -1],
    # ["examples/example_05.mp4", 1024, -1, -1],
    # ["examples/example_06.mp4", 1024, -1, -1],
]
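# Each row above is [video path, max resolution, process length, target FPS];
# the -1 values are passed straight to read_video_frames, where they
# presumably mean "keep the full clip" / "keep the source frame rate".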
pretrained_model_name_or_path = "Yanrui95/NormalCrafter"
weight_dtype = torch.float16

unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="unet",
    low_cpu_mem_usage=True,
)
vae = AutoencoderKLTemporalDecoder.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
)
vae.to(dtype=weight_dtype)
unet.to(dtype=weight_dtype)

pipe = NormalCrafterPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    unet=unet,
    vae=vae,
    torch_dtype=weight_dtype,
    variant="fp16",
)
pipe.to("cuda")
@spaces.GPU
def infer_normal(
    video: str,
    max_res: int = 1024,
    process_length: int = -1,
    target_fps: int = -1,
    # the arguments below use fixed defaults and are not exposed in the UI
    save_folder: str = "./demo_output",
    window_size: int = 14,
    time_step_size: int = 10,
    decode_chunk_size: int = 7,
    seed: int = 42,
    save_npz: bool = False,
):
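    # window_size / time_step_size describe sliding-window inference over the
    # frame sequence: 14-frame windows advanced by 10 frames, so consecutive
    # windows overlap by 4 frames (overlap blending is presumably handled
    # inside the pipeline). decode_chunk_size bounds how many latent frames
    # the VAE decodes at once, to limit peak GPU memory.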
    set_seed(seed)
    pipe.enable_xformers_memory_efficient_attention()

    frames, target_fps = read_video_frames(video, process_length, target_fps, max_res)
    # estimate the normal maps with the NormalCrafter pipeline
    with torch.inference_mode():
        res = pipe(
            frames,
            decode_chunk_size=decode_chunk_size,
            time_step_size=time_step_size,
            window_size=window_size,
        ).frames[0]
    # visualize the normal maps and save the results with the target FPS
    vis = vis_sequence_normal(res)
    save_path = os.path.join(save_folder, os.path.splitext(os.path.basename(video))[0])
    print(f"==> saving results to {save_path}")
    os.makedirs(save_folder, exist_ok=True)
    if save_npz:
        np.savez_compressed(save_path + ".npz", normal=res)
    save_video(vis, save_path + "_vis.mp4", fps=target_fps)
    save_video(frames, save_path + "_input.mp4", fps=target_fps)
    # free GPU memory before the next request
    gc.collect()
    torch.cuda.empty_cache()
    return [
        save_path + "_input.mp4",
        save_path + "_vis.mp4",
    ]
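
# Standalone usage (paths are illustrative):
#   infer_normal("examples/example_01.mp4", max_res=1024)
#   -> ["./demo_output/example_01_input.mp4", "./demo_output/example_01_vis.mp4"]
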
def construct_demo():
    with gr.Blocks(analytics_enabled=False) as normalcrafter_iface:
        gr.Markdown(
            """
            <div align='center'> <h1> NormalCrafter: Learning Temporally Consistent Video Normal from Video Diffusion Priors </h1> \
                <a style='font-size:18px;color: #000000'>If you find NormalCrafter useful, please help ⭐ the </a>\
                <a style='font-size:18px;color: #FF5DB0' href='https://github.com/Binyr/NormalCrafter'>[Github Repo]</a>\
                <a style='font-size:18px;color: #000000'>, which is important to Open-Source projects. Thanks!</a>\
                <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/xxx'> [ArXiv] </a>\
                <a style='font-size:18px;color: #000000' href='https://binyr.github.io/NormalCrafter/'> [Project Page] </a> </div>
            """
        )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(label="Input Video")
            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
                    output_video_1 = gr.Video(
                        label="Preprocessed Video",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=5,
                    )
                    output_video_2 = gr.Video(
                        label="Generated Normal Video",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=5,
                    )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    with gr.Accordion("Advanced Settings", open=False):
                        max_res = gr.Slider(
                            label="max resolution",
                            minimum=512,
                            maximum=1024,
                            value=1024,
                            step=64,
                        )
                        process_length = gr.Slider(
                            label="process length",
                            minimum=-1,
                            maximum=280,
                            value=60,
                            step=1,
                        )
                        process_target_fps = gr.Slider(
                            label="target FPS",
                            minimum=-1,
                            maximum=30,
                            value=15,
                            step=1,
                        )
                    generate_btn = gr.Button("Generate")
            with gr.Column(scale=2):
                pass
        gr.Examples(
            examples=examples,
            inputs=[
                input_video,
                max_res,
                process_length,
                process_target_fps,
            ],
            outputs=[output_video_1, output_video_2],
            fn=infer_normal,
            cache_examples="lazy",
        )
        generate_btn.click(
            fn=infer_normal,
            inputs=[
                input_video,
                max_res,
                process_length,
                process_target_fps,
            ],
            outputs=[output_video_1, output_video_2],
        )

    return normalcrafter_iface

if __name__ == "__main__":
    demo = construct_demo()
    demo.queue()
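    # queue() makes incoming requests line up instead of hitting the GPU concurrently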
    # demo.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False)
    demo.launch(share=True)