|
import os
import subprocess
from pathlib import Path

import gradio as gr
import torch
from PIL import Image

HF_DATASET_URL = "https://huggingface.co/datasets/roll-ai/FloVD-weights/resolve/main/ckpt"

WEIGHT_FILES = {
    "FVSM/FloVD_FVSM_Controlnet.pt": "FVSM/FloVD_FVSM_Controlnet.pt",
    "OMSM/selected_blocks.safetensors": "OMSM/selected_blocks.safetensors",
    "OMSM/pytorch_lora_weights.safetensors": "OMSM/pytorch_lora_weights.safetensors",
    "others/depth_anything_v2_metric_hypersim_vitb.pth": "others/depth_anything_v2_metric_hypersim_vitb.pth",
}
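
# Checkpoints are pulled from the roll-ai/FloVD-weights dataset on the Hugging Face Hub
# and mirrored locally under ./ckpt/ at the same relative paths.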
|
|
|
def download_weights():
    """Download any missing model checkpoints into ./ckpt/, skipping files already on disk."""
    print("Downloading model weights...")
    for rel_path in WEIGHT_FILES.values():
        save_path = Path("ckpt") / rel_path
        if not save_path.exists():
            save_path.parent.mkdir(parents=True, exist_ok=True)
            url = f"{HF_DATASET_URL}/{rel_path}"
            print(f"Downloading {url} -> {save_path}")
            subprocess.run(["wget", "-q", "-O", str(save_path), url], check=True)
        else:
            print(f"Already exists: {save_path}")
|
|
|
download_weights()
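
# Import the FloVD inference entry point only after the checkpoints exist under ./ckpt/.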
|
from inference.flovd_demo import generate_video
|
|
|
def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Save the input image, run FloVD video generation, and return the path to the output video."""
    try:
        os.makedirs("input_images", exist_ok=True)
        image_path = "input_images/input_image.png"
        image.save(image_path)

        generate_video(
            prompt=prompt,
            image_path=image_path,
            fvsm_path="./ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
            omsm_path="./ckpt/OMSM",
            output_path="./outputs",
            num_frames=49,
            fps=16,
            width=None,
            height=None,
            seed=42,
            guidance_scale=6.0,
            dtype=torch.float16,
            controlnet_guidance_end=0.4,
            use_dynamic_cfg=False,
            pose_type=pose_type,
            speed=float(speed),
            use_flow_integration=use_flow_integration,
            cam_pose_name=cam_pose_name,
            depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
        )
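
        # generate_video is expected to write the result under ./outputs/generated_videos/,
        # named from the truncated prompt and the camera trajectory name.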
|
        return f"./outputs/generated_videos/{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
    except Exception as e:
        print("Inference failed:")
        import traceback
        traceback.print_exc()
        return f"Error during inference: {e}"
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("## FloVD: Optical Flow + CogVideoX Video Generation")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
            image = gr.Image(type="pil", label="Input Image")
            pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
            cam_pose_name = gr.Textbox(label="Camera Trajectory Name", placeholder="e.g. zoom_in, tilt_up")
            speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Speed")
            use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
            submit = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    submit.click(
        fn=run_inference,
        inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
        outputs=output_video,
    )

demo.launch()
|
|