"""Gradio demo for FloVD (optical flow + CogVideoX) image-to-video generation."""

import contextlib
import os
import gradio as gr
import torch
import subprocess
from PIL import Image
from pathlib import Path
import io
import sys
import traceback

# =========================================
# 1. Define Hugging Face weights and paths
# =========================================
from huggingface_hub import hf_hub_download

HF_DATASET_REPO = "roll-ai/FloVD-weights"  # dataset repo ID


def download_weights(weight_files=None):
    """Download any checkpoint files missing from ./ckpt out of HF_DATASET_REPO.

    Args:
        weight_files: Mapping whose *values* are repo-relative file paths to
            fetch into ``ckpt/``. Defaults to the module-level ``WEIGHT_FILES``.

    Raises:
        NameError: if ``weight_files`` is omitted and ``WEIGHT_FILES`` is not
            defined at module level.
    """
    print("🔄 Downloading model weights via huggingface_hub...")
    if weight_files is None:
        # NOTE(review): WEIGHT_FILES is not defined anywhere in this file and
        # download_weights() is never called here. Define the mapping (and call
        # this before running inference) or pass ``weight_files`` explicitly.
        weight_files = WEIGHT_FILES
    for rel_path in weight_files.values():
        local_path = Path("ckpt") / rel_path
        if local_path.exists():
            print(f"✅ Already exists: {local_path}")
            continue
        print(f"📥 Downloading {rel_path}")
        hf_hub_download(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            filename=rel_path,
            local_dir="ckpt",
            # Deprecated (ignored) in recent huggingface_hub releases; kept so
            # older versions still write real files instead of symlinks.
            local_dir_use_symlinks=False,
        )


from inference.flovd_demo import generate_video


def _run_pipeline(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Save the input image, run ``generate_video`` and return the expected video path."""
    os.makedirs("input_images", exist_ok=True)
    image_path = "input_images/input_image.png"
    image.save(image_path)
    print(f"📸 Saved input image to {image_path}")

    generate_video(
        prompt=prompt,
        image_path=image_path,
        fvsm_path="./ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
        omsm_path="./ckpt/OMSM",
        output_path="./outputs",
        num_frames=49,
        fps=16,
        width=None,
        height=None,
        seed=42,
        guidance_scale=6.0,
        dtype=torch.float16,
        controlnet_guidance_end=0.4,
        use_dynamic_cfg=False,
        pose_type=pose_type,
        speed=float(speed),
        use_flow_integration=use_flow_integration,
        cam_pose_name=cam_pose_name,
        depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
    )

    # NOTE(review): this reconstructs the file name generate_video is assumed
    # to write; confirm it matches inference/flovd_demo.py.
    safe_prompt = prompt[:30].strip().replace(' ', '_')
    video_name = f"{safe_prompt}_{cam_pose_name or 'default'}.mp4"
    video_path = f"./outputs/generated_videos/{video_name}"
    print(f"✅ Inference complete. Video saved to {video_path}")
    return video_path


def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Run one FloVD generation request from the Gradio UI.

    Everything printed during the run, including the traceback of any
    failure, is captured and returned as the log text shown in the UI.

    Args:
        prompt: Text prompt forwarded to ``generate_video``.
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing was uploaded.
        pose_type: Camera pose source (the UI offers ``"manual"`` / ``"re10k"``).
        speed: Camera speed from the slider; converted to ``float``.
        use_flow_integration: Checkbox value, forwarded unchanged.
        cam_pose_name: Camera trajectory name; an empty value becomes
            ``"default"`` in the expected output file name.

    Returns:
        ``(video_path, logs)``. ``video_path`` is ``None`` when inference failed
        or the expected output file does not exist.
    """
    log_buffer = io.StringIO()
    video_path = None
    # redirect_stdout guarantees sys.stdout is restored even if something
    # escapes the except clause (e.g. KeyboardInterrupt).
    with contextlib.redirect_stdout(log_buffer):
        print("🚀 Starting inference...")
        if image is None:
            print("⚠️ No input image provided. Please upload an image and try again.")
        else:
            try:
                video_path = _run_pipeline(
                    prompt, image, pose_type, speed, use_flow_integration, cam_pose_name
                )
            except Exception:
                print("🔥 Inference failed with exception:")
                # print_exc() defaults to sys.stderr, which is NOT redirected;
                # send it to (redirected) stdout so it appears in the UI logs.
                traceback.print_exc(file=sys.stdout)
            else:
                if not os.path.exists(video_path):
                    print(f"⚠️ Expected output video not found at {video_path}")

    logs = log_buffer.getvalue()
    log_buffer.close()
    if video_path and os.path.exists(video_path):
        return video_path, logs
    return None, logs


# ========================
# Gradio Interface
# ========================
with gr.Blocks() as demo:
    gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
            image = gr.Image(type="pil", label="Input Image")
            pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
            cam_pose_name = gr.Textbox(label="Camera Trajectory Name", placeholder="e.g. zoom_in, tilt_up")
            speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Speed")
            use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
            submit = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            output_logs = gr.Textbox(label="Logs", lines=20, interactive=False)

    submit.click(
        fn=run_inference,
        inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
        outputs=[output_video, output_logs],
    )

demo.launch(show_error=True)