Update app.py
app.py CHANGED
@@ -30,57 +30,53 @@ def download_weights():
         print(f"✅ Already exists: {save_path}")
 
 download_weights()
+import gradio as gr
+import torch
+import os
+from inference_script import generate_video  # Assuming your script is saved as inference_script.py
 
-
-
-
-
-from inference.flovd_demo import load_pipeline, generate_video
-
-pipeline = load_pipeline(
-    fvsm_path="ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
-    omsm_path="ckpt/OMSM",
-    depth_path="ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
-    device="cuda" if torch.cuda.is_available() else "cpu"
-)
-
-# =========================================
-# 3. Inference Function
-# =========================================
+def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
+    os.makedirs("input_images", exist_ok=True)
+    image_path = "input_images/input_image.png"
+    image.save(image_path)
 
-
-    print("🚀 Running inference...")
-    output_path = generate_video(
-        image=image,
+    generate_video(
         prompt=prompt,
-
-
+        image_path=image_path,
+        fvsm_path="./ckpt/FVSM",  # Expected to be downloaded from HF dataset
+        omsm_path="./ckpt/OMSM",  # Expected to be downloaded from HF dataset
+        output_path="./outputs",
         num_frames=49,
         fps=16,
+        width=None,
+        height=None,
+        seed=42,
+        guidance_scale=6.0,
+        dtype=torch.float16,
         controlnet_guidance_end=0.4,
-
+        use_dynamic_cfg=False,
+        pose_type=pose_type,
+        speed=float(speed),
+        use_flow_integration=use_flow_integration,
+        cam_pose_name=cam_pose_name,
+        depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth"
     )
-    return
-
-# =========================================
-# 4. Gradio UI
-# =========================================
+    return f"./outputs/generated_videos/{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
 
-
-
+with gr.Blocks() as demo:
+    gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
+            image = gr.Image(type="pil", label="Input Image")
+            pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
+            cam_pose_name = gr.Textbox(label="Camera Trajectory Name", placeholder="e.g. zoom_in, tilt_up")
+            speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Speed")
+            use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
+            submit = gr.Button("Generate Video")
+        with gr.Column():
+            output_video = gr.Video(label="Generated Video")
 
-
-    fn=run_inference,
-    inputs=[
-        gr.Image(label="Input Image", type="pil"),
-        gr.Textbox(label="Text Prompt", value="A cinematic dolly zoom shot of a futuristic cityscape"),
-        gr.Textbox(label="Camera Trajectory File Path", value=example_cam),
-    ],
-    outputs=gr.Video(label="Generated Video"),
-    title="FloVD-CogVideoX 🎥",
-    description="Upload an image, enter a text prompt and a camera trajectory file path to generate a controlled video using CogVideoX + optical flow.",
-    examples=[[example_image, "A beautiful sunrise over a mountain range", example_cam]]
-)
+    submit.click(fn=run_inference, inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name], outputs=output_video)
 
-
-demo.launch()
+demo.launch()
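A quick way to sanity-check this commit locally is to mirror run_inference outside the Gradio UI. The sketch below is a minimal harness under stated assumptions: inference_script.generate_video accepts exactly the keyword arguments shown in the diff, and it writes MP4s under ./outputs/generated_videos/ with the naming scheme that run_inference's return statement reconstructs. The sample image path and trajectory name are hypothetical. Importing app itself is avoided on purpose, since app.py runs download_weights() and demo.launch() at module scope.

# smoke_test.py -- a sketch, not part of the commit. Mirrors run_inference
# under the assumptions stated above.
import os

import torch
from PIL import Image

from inference_script import generate_video

prompt = "A girl riding a bicycle through a park."
cam_pose_name = "zoom_in"  # hypothetical trajectory name

# Same staging step run_inference performs on the uploaded PIL image.
os.makedirs("input_images", exist_ok=True)
image_path = "input_images/input_image.png"
Image.open("assets/example.png").save(image_path)  # hypothetical input file

generate_video(
    prompt=prompt,
    image_path=image_path,
    fvsm_path="./ckpt/FVSM",
    omsm_path="./ckpt/OMSM",
    output_path="./outputs",
    num_frames=49,
    fps=16,
    width=None,
    height=None,
    seed=42,
    guidance_scale=6.0,
    dtype=torch.float16,
    controlnet_guidance_end=0.4,
    use_dynamic_cfg=False,
    pose_type="manual",
    speed=0.5,
    use_flow_integration=False,
    cam_pose_name=cam_pose_name,
    depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
)

# run_inference rebuilds the output file name from the prompt and the
# trajectory name; check the same path here.
expected = f"./outputs/generated_videos/{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
print("exists:", os.path.exists(expected), expected)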