|
import os |
|
import gradio as gr |
|
import torch |
|
import subprocess |
|
from PIL import Image |
|
from pathlib import Path |
|
import io |
|
import sys |
|
import traceback |
|
|
|
|
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
# Hugging Face *dataset* repository that hosts the FloVD checkpoint files;
# download_weights() fetches from it with repo_type="dataset".
HF_DATASET_REPO = "roll-ai/FloVD-weights"


def download_weights(weight_files=None, local_dir="ckpt"):
    """Download any missing checkpoint files from HF_DATASET_REPO.

    Each file is fetched with ``hf_hub_download`` into ``local_dir``. A file is
    skipped when ``local_dir / rel_path`` already exists, so repeated calls only
    download what is missing.

    Args:
        weight_files: Mapping whose *values* are file paths relative to the
            dataset repository root (keys are ignored). Defaults to the
            module-level ``WEIGHT_FILES``.
        local_dir: Directory the files are downloaded into (default ``"ckpt"``).

    Returns:
        List of local ``Path`` objects, one per value of ``weight_files``, in
        iteration order.

    Raises:
        NameError: ``weight_files`` was omitted and no module-level
            ``WEIGHT_FILES`` is defined.
    """
    if weight_files is None:
        # NOTE(review): WEIGHT_FILES is not defined anywhere in this file as
        # shown. Define it, or pass weight_files explicitly.
        weight_files = globals().get("WEIGHT_FILES")
        if weight_files is None:
            raise NameError(
                "WEIGHT_FILES is not defined; define it at module level "
                "or pass weight_files=... explicitly"
            )

    print("🔄 Downloading model weights via huggingface_hub...")
    local_paths = []
    for rel_path in weight_files.values():
        local_path = Path(local_dir) / rel_path
        if local_path.exists():
            print(f"✅ Already exists: {local_path}")
        else:
            print(f"📥 Downloading {rel_path}")
            hf_hub_download(
                repo_id=HF_DATASET_REPO,
                repo_type="dataset",
                filename=rel_path,
                local_dir=str(local_dir),
                # Ask for real files instead of symlinks into the HF cache.
                # Recent huggingface_hub releases deprecate this argument.
                local_dir_use_symlinks=False,
            )
        local_paths.append(local_path)
    return local_paths
|
|
|
from inference.flovd_demo import generate_video |
|
|
|
def _run_generation(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Save the input image, call generate_video, and return the expected output path."""
    os.makedirs("input_images", exist_ok=True)
    # NOTE(review): fixed filename, so concurrent requests overwrite each other.
    image_path = "input_images/input_image.png"
    image.save(image_path)
    print(f"📸 Saved input image to {image_path}")

    generate_video(
        prompt=prompt,
        image_path=image_path,
        fvsm_path="./ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
        omsm_path="./ckpt/OMSM",
        output_path="./outputs",
        num_frames=49,
        fps=16,
        width=None,
        height=None,
        seed=42,
        guidance_scale=6.0,
        dtype=torch.float16,
        controlnet_guidance_end=0.4,
        use_dynamic_cfg=False,
        pose_type=pose_type,
        speed=float(speed),
        use_flow_integration=use_flow_integration,
        cam_pose_name=cam_pose_name,
        depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
    )

    # NOTE(review): this presumably mirrors generate_video's own output naming
    # (it is not returned by the call). Keep the two in sync.
    video_name = f"{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
    return f"./outputs/generated_videos/{video_name}"


def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    """Run FloVD generation for one UI request and capture its printed output.

    Args:
        prompt: Text prompt forwarded to generate_video.
        image: Input image from the ``gr.Image(type="pil")`` component, or
            None when nothing has been uploaded.
        pose_type: Camera pose source; forwarded unchanged.
        speed: Camera speed; converted with ``float()`` before forwarding.
        use_flow_integration: Forwarded unchanged to generate_video.
        cam_pose_name: Camera trajectory name. An empty value becomes
            ``"default"`` in the expected output filename.

    Returns:
        ``(video_path, logs)``. ``video_path`` is the expected .mp4 path if
        that file exists after the run, otherwise None. ``logs`` is everything
        printed during the run, including the traceback on failure.
    """
    from contextlib import redirect_stdout

    log_buffer = io.StringIO()
    video_path = None
    # redirect_stdout restores the previous sys.stdout even if a
    # BaseException escapes; the old manual swap had no `finally`.
    # NOTE(review): sys.stdout is process-global, so concurrent requests
    # would interleave their logs.
    with redirect_stdout(log_buffer):
        try:
            print("🚀 Starting inference...")
            if image is None:
                print("⚠️ No input image provided. Please upload an image.")
            else:
                video_path = _run_generation(
                    prompt, image, pose_type, speed, use_flow_integration, cam_pose_name
                )
                print(f"✅ Inference complete. Video saved to {video_path}")
        except Exception:
            print("💥 Inference failed with exception:")
            # print_exc() defaults to sys.stderr, which is NOT redirected.
            # Write into the buffer so the traceback reaches the UI logs.
            traceback.print_exc(file=log_buffer)
    logs = log_buffer.getvalue()
    log_buffer.close()

    return (video_path if video_path and os.path.exists(video_path) else None), logs
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: generation inputs in the left column; the generated video and the
# captured inference logs in the right column.
with gr.Blocks() as demo:
    gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
            image = gr.Image(type="pil", label="Input Image")
            pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
            cam_pose_name = gr.Textbox(label="Camera Trajectory Name", placeholder="e.g. zoom_in, tilt_up")
            speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Speed")
            use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
            submit = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            output_logs = gr.Textbox(label="Logs", lines=20, interactive=False)

    # `inputs` are passed positionally, so their order must match
    # run_inference's parameters. `outputs` receive its (video_path, logs) tuple.
    submit.click(
        fn=run_inference,
        inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
        outputs=[output_video, output_logs],
    )

demo.launch(show_error=True)
|
|