import os
import subprocess
import traceback
from pathlib import Path

import gradio as gr
import torch
from PIL import Image  # noqa: F401  (kept from the original script)

# =========================================
# 1. Define Hugging Face weights and paths
# =========================================
HF_DATASET_URL = "https://huggingface.co/datasets/roll-ai/FloVD-weights/resolve/main/ckpt"

# Checkpoints required by the demo, as paths relative to both HF_DATASET_URL
# (remote) and ./ckpt (local). Only the values are used by download_weights().
WEIGHT_FILES = {
    "FVSM/FloVD_FVSM_Controlnet.pt": "FVSM/FloVD_FVSM_Controlnet.pt",
    "OMSM/selected_blocks.safetensors": "OMSM/selected_blocks.safetensors",
    "OMSM/pytorch_lora_weights.safetensors": "OMSM/pytorch_lora_weights.safetensors",
    "others/depth_anything_v2_metric_hypersim_vitb.pth": "others/depth_anything_v2_metric_hypersim_vitb.pth",
}


def download_weights():
    """Download every checkpoint listed in WEIGHT_FILES into ./ckpt if missing.

    Each file is fetched with ``wget`` into a sibling ``.part`` file and is only
    renamed to its final name after ``wget`` succeeds. A failed or interrupted
    download therefore never leaves a truncated file behind that a later run
    would mistake for a complete checkpoint.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status.
    """
    print("🔄 Downloading model weights...")
    for rel_path in WEIGHT_FILES.values():
        save_path = Path("ckpt") / rel_path
        # An empty file (e.g. left by an older non-atomic run) counts as missing.
        if save_path.is_file() and save_path.stat().st_size > 0:
            print(f"✅ Already exists: {save_path}")
            continue

        save_path.parent.mkdir(parents=True, exist_ok=True)
        url = f"{HF_DATASET_URL}/{rel_path}"
        tmp_path = save_path.with_name(save_path.name + ".part")
        print(f"📥 Downloading {url} → {save_path}")
        try:
            subprocess.run(["wget", "-q", "-O", str(tmp_path), url], check=True)
            # Publish the file only once it is complete.
            os.replace(tmp_path, save_path)
        finally:
            # No-op after a successful replace; removes partial data otherwise.
            tmp_path.unlink(missing_ok=True)


download_weights()

# Imported only after the weights are in place (preserves the original order).
from inference.flovd_demo import generate_video  # noqa: E402


def _expected_output_path(prompt, cam_pose_name):
    """Return the path where this demo expects generate_video to write the video.

    NOTE(review): this mirrors a naming scheme assumed to be used inside
    inference.flovd_demo.generate_video (first 30 prompt chars, stripped,
    spaces -> underscores, then the camera pose name or 'default'). Confirm
    against that module; the caller checks that the file actually exists.
    """
    stem = prompt[:30].strip().replace(" ", "_")
    return f"./outputs/generated_videos/{stem}_{cam_pose_name or 'default'}.mp4"


def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Generate a camera-controlled video from an input image and a text prompt.

    Bound to the "Generate Video" button. The argument order matches the
    ``inputs`` list passed to ``submit.click``.

    Args:
        prompt: Text prompt for the video.
        image: PIL image from the ``gr.Image(type="pil")`` input, or None if
            nothing was uploaded.
        pose_type: "manual" or "re10k" (the radio choices below).
        speed: Slider value; converted to float before use.
        use_flow_integration: Checkbox value forwarded to generate_video.
        cam_pose_name: Camera trajectory name from the textbox (may be empty).

    Returns:
        Path to the generated ``.mp4``, which is consumed by the ``gr.Video`` output.

    Raises:
        gr.Error: if no image was uploaded, generation raised, or the expected
            output file was not produced. Gradio shows the message in the UI
            rather than trying to load an error string as a video file.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")

    try:
        os.makedirs("input_images", exist_ok=True)
        image_path = "input_images/input_image.png"
        image.save(image_path)

        generate_video(
            prompt=prompt,
            image_path=image_path,
            fvsm_path="./ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
            omsm_path="./ckpt/OMSM",
            output_path="./outputs",
            num_frames=49,
            fps=16,
            width=None,
            height=None,
            seed=42,
            guidance_scale=6.0,
            dtype=torch.float16,
            controlnet_guidance_end=0.4,
            use_dynamic_cfg=False,
            pose_type=pose_type,
            speed=float(speed),
            use_flow_integration=use_flow_integration,
            cam_pose_name=cam_pose_name,
            depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
        )
    except Exception as e:
        # Top-level UI boundary: log the full traceback server-side, then
        # surface a readable message to the user.
        print("🔥 Inference failed:")
        traceback.print_exc()
        raise gr.Error(f"⚠️ Error during inference: {e}") from e

    output_path = _expected_output_path(prompt, cam_pose_name)
    if not os.path.isfile(output_path):
        raise gr.Error(f"⚠️ Generation finished but no video was found at {output_path}")
    return output_path


with gr.Blocks() as demo:
    gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
            image = gr.Image(type="pil", label="Input Image")
            pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
            cam_pose_name = gr.Textbox(label="Camera Trajectory Name", placeholder="e.g. zoom_in, tilt_up")
            speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Speed")
            use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
            submit = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    submit.click(
        fn=run_inference,
        inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
        outputs=output_video,
    )

demo.launch()