import os
import io
import sys
import traceback
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from huggingface_hub import hf_hub_download

# =========================================
# 1. Define Hugging Face dataset + weights
# =========================================
HF_DATASET_REPO = "roll-ai/FloVD-weights"  # your dataset repo on HF

# Maps each file's path inside the HF dataset repo to its path relative to the
# local "ckpt/" directory. download_weights() fetches with local_dir="./", which
# replicates the repo layout, so every value must equal its key minus the
# "ckpt/" prefix -- otherwise the "already exists" check can never match and the
# file is re-fetched on every start.
WEIGHT_FILES = {
    "ckpt/FVSM/FloVD_FVSM_Controlnet.pt": "FVSM/FloVD_FVSM_Controlnet.pt",
    "ckpt/OMSM/selected_blocks.safetensors": "OMSM/selected_blocks.safetensors",
    "ckpt/OMSM/pytorch_lora_weights.safetensors": "OMSM/pytorch_lora_weights.safetensors",
    "ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth": "others/depth_anything_v2_metric_hypersim_vitb.pth",
}

print("")
print("Downloading model...", flush=True)


def download_weights():
    """Fetch every file listed in WEIGHT_FILES that is missing under ./ckpt.

    Files come from the Hugging Face *dataset* repo ``HF_DATASET_REPO`` and are
    written relative to the current directory (``local_dir="./"``), so the repo
    path ``ckpt/<x>`` ends up at ``./ckpt/<x>``. Files already present locally
    are skipped.
    """
    print("🔄 Downloading model weights via huggingface_hub...")
    for hf_path, local_rel_path in WEIGHT_FILES.items():
        local_path = Path("ckpt") / local_rel_path
        if local_path.exists():
            print(f"✅ Already exists: {local_path}")
            continue
        print(f"📥 Downloading {hf_path}")
        hf_hub_download(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            filename=hf_path,
            local_dir="./",
        )


download_weights()


def print_ckpt_structure(base_path="ckpt"):
    """Print a tree view of *base_path* (directories and files), two spaces per level.

    Depth is derived with ``os.path.relpath`` so that a trailing separator on
    *base_path*, or a directory name repeating *base_path* deeper in the tree,
    does not skew the indentation.
    """
    print(f"📂 Listing structure of: {base_path}", flush=True)
    for root, _dirs, files in os.walk(base_path):
        rel = os.path.relpath(root, base_path)
        # relpath of the top directory is "."; each path component below adds a level.
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = "  " * level
        # normpath strips a trailing separator so basename() is never empty.
        print(f"{indent}📁 {os.path.basename(os.path.normpath(root))}/", flush=True)
        sub_indent = "  " * (level + 1)
        for f in files:
            print(f"{sub_indent}📄 {f}", flush=True)


# Call it
print_ckpt_structure()
# =========================================
# 2. Import the FloVD generation pipeline
# =========================================
from inference.flovd_demo import generate_video


def _generate_flovd_video(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Save the input image, run generate_video, and return the expected output path.

    Returns None (after printing a message) when no input image was supplied.
    Exceptions raised by saving or generation propagate to the caller.
    """
    print("🚀 Starting inference...")
    if image is None:
        # Gradio passes None when the user did not upload an image.
        print("⚠️ No input image provided. Please upload an image and try again.")
        return None

    os.makedirs("input_images", exist_ok=True)
    image_path = "input_images/input_image.png"
    image.save(image_path)
    print(f"📸 Saved input image to {image_path}")

    generate_video(
        prompt=prompt,
        image_path=image_path,
        fvsm_path="./ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
        omsm_path="./ckpt/OMSM",
        output_path="./outputs",
        num_frames=49,
        fps=16,
        width=None,
        height=None,
        seed=42,
        guidance_scale=6.0,
        dtype=torch.float16,
        controlnet_guidance_end=0.4,
        use_dynamic_cfg=False,
        pose_type=pose_type,
        speed=float(speed),
        use_flow_integration=use_flow_integration,
        cam_pose_name=cam_pose_name,
        depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
    )

    # NOTE(review): this file name is assumed to mirror generate_video's output
    # naming convention -- confirm against inference/flovd_demo.py.
    video_name = f"{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
    video_path = f"./outputs/generated_videos/{video_name}"
    if os.path.exists(video_path):
        print(f"✅ Inference complete. Video saved to {video_path}")
    else:
        print(f"⚠️ Inference finished but no video was found at {video_path}")
    return video_path


def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Gradio callback: run FloVD generation and capture its console output.

    Everything printed to stdout during the run -- including the traceback of a
    failed run -- is captured and returned as the log text for the UI.

    Args:
        prompt: Text prompt passed to generate_video.
        image: Input image from the ``gr.Image(type="pil")`` component, or None.
        pose_type: Camera pose type selected in the UI ("manual" or "re10k").
        speed: Camera speed from the slider; converted with ``float``.
        use_flow_integration: Checkbox value forwarded to generate_video.
        cam_pose_name: Camera trajectory name; "" falls back to "default" in the
            expected output file name.

    Returns:
        ``(video_path_or_None, logs)``: the video path only if that file exists.
    """
    log_buffer = io.StringIO()
    original_stdout = sys.stdout
    sys.stdout = log_buffer
    video_path = None
    try:
        video_path = _generate_flovd_video(
            prompt, image, pose_type, speed, use_flow_integration, cam_pose_name
        )
    except Exception:
        print("🔥 Inference failed with exception:")
        # print_exc() defaults to sys.stderr, which is not captured; write the
        # traceback into the log buffer so it shows up in the UI.
        traceback.print_exc(file=log_buffer)
    finally:
        # Always restore stdout, even for BaseException (e.g. KeyboardInterrupt).
        sys.stdout = original_stdout

    logs = log_buffer.getvalue()
    log_buffer.close()
    return (video_path if video_path and os.path.exists(video_path) else None), logs


# =========================================
# 3. Gradio App Interface
# =========================================
with gr.Blocks() as demo:
    gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
            image = gr.Image(type="pil", label="Input Image")
            pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
            cam_pose_name = gr.Textbox(label="Camera Trajectory Name", placeholder="e.g. zoom_in, tilt_up")
            speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Speed")
            use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
            submit = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            output_logs = gr.Textbox(label="Logs", lines=20, interactive=False)

    submit.click(
        fn=run_inference,
        inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
        outputs=[output_video, output_logs],
    )

demo.launch(show_error=True)