Update app.py
Browse files
app.py
CHANGED
@@ -1,45 +1,50 @@
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
-
import subprocess
|
5 |
from PIL import Image
|
6 |
from pathlib import Path
|
7 |
import io
|
8 |
import sys
|
9 |
import traceback
|
|
|
10 |
|
11 |
# =========================================
|
12 |
-
# 1. Define Hugging Face
|
13 |
# =========================================
|
14 |
-
from huggingface_hub import hf_hub_download
|
15 |
|
16 |
-
HF_DATASET_REPO = "roll-ai/FloVD-weights" # dataset repo
|
|
|
17 |
WEIGHT_FILES = {
|
18 |
-
"FVSM/FloVD_FVSM_Controlnet.pt": "FVSM/FloVD_FVSM_Controlnet.pt",
|
19 |
-
"OMSM/selected_blocks.safetensors": "OMSM/selected_blocks.safetensors",
|
20 |
-
"OMSM/pytorch_lora_weights.safetensors": "OMSM/pytorch_lora_weights.safetensors",
|
21 |
-
"others/depth_anything_v2_metric_hypersim_vitb.pth": "others/depth_anything_v2_metric_hypersim_vitb.pth"
|
22 |
}
|
|
|
23 |
def download_weights():
|
24 |
print("π Downloading model weights via huggingface_hub...")
|
25 |
-
for
|
26 |
-
local_path = Path("ckpt") /
|
27 |
if not local_path.exists():
|
28 |
-
print(f"π₯ Downloading {
|
29 |
hf_hub_download(
|
30 |
repo_id=HF_DATASET_REPO,
|
31 |
repo_type="dataset",
|
32 |
-
filename=
|
33 |
-
local_dir="ckpt"
|
34 |
-
local_dir_use_symlinks=False,
|
35 |
)
|
36 |
else:
|
37 |
print(f"β
Already exists: {local_path}")
|
|
|
38 |
download_weights()
|
|
|
|
|
|
|
|
|
|
|
39 |
from inference.flovd_demo import generate_video
|
40 |
|
41 |
def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
|
42 |
-
# Redirect stdout to capture logs
|
43 |
log_buffer = io.StringIO()
|
44 |
sys_stdout = sys.stdout
|
45 |
sys.stdout = log_buffer
|
@@ -78,20 +83,19 @@ def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pos
|
|
78 |
video_path = f"./outputs/generated_videos/{video_name}"
|
79 |
print(f"β
Inference complete. Video saved to {video_path}")
|
80 |
|
81 |
-
except Exception
|
82 |
print("π₯ Inference failed with exception:")
|
83 |
traceback.print_exc()
|
84 |
|
85 |
-
# Restore stdout and return logs
|
86 |
sys.stdout = sys_stdout
|
87 |
logs = log_buffer.getvalue()
|
88 |
log_buffer.close()
|
89 |
|
90 |
return (video_path if video_path and os.path.exists(video_path) else None), logs
|
91 |
|
92 |
-
#
|
93 |
-
# Gradio Interface
|
94 |
-
#
|
95 |
|
96 |
with gr.Blocks() as demo:
|
97 |
gr.Markdown("## π₯ FloVD: Optical Flow + CogVideoX Video Generation")
|
@@ -113,4 +117,5 @@ with gr.Blocks() as demo:
|
|
113 |
inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
|
114 |
outputs=[output_video, output_logs]
|
115 |
)
|
|
|
116 |
demo.launch(show_error=True)
|
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
import torch
|
|
|
4 |
from PIL import Image
|
5 |
from pathlib import Path
|
6 |
import io
|
7 |
import sys
|
8 |
import traceback
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
|
11 |
# =========================================
# 1. Define Hugging Face dataset + weights
# =========================================

# Dataset repo on the Hugging Face Hub that hosts the model weights.
HF_DATASET_REPO = "roll-ai/FloVD-weights"  # your dataset repo on HF

# Maps path-inside-the-dataset-repo -> path (relative to ./ckpt) where the
# rest of the app appears to expect the weight file.
WEIGHT_FILES = {
    "ckpt/FVSM/FloVD_FVSM_Controlnet.pt": "FVSM/FloVD_FVSM_Controlnet.pt",
    "ckpt/OMSM/selected_blocks.safetensors": "OMSM/selected_blocks.safetensors",
    "ckpt/OMSM/pytorch_lora_weights.safetensors": "OMSM/pytorch_lora_weights.safetensors",
    "ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth": "others/depth_anything_v2_metric_hypersim_vitb.pth",
}

def download_weights():
    """Fetch every weight file in WEIGHT_FILES, skipping files already on disk.

    hf_hub_download(local_dir=d, filename=f) writes the file to d/f, so with
    local_dir="ckpt" and "ckpt/..."-prefixed repo filenames the download lands
    at ckpt/ckpt/<...>. The existence check therefore tests that actual target
    path (the original only tested ckpt/<local_rel_path>, which never matched
    the download destination, forcing a re-download on every launch).
    """
    print("Downloading model weights via huggingface_hub...")
    for hf_path, local_rel_path in WEIGHT_FILES.items():
        # Where hf_hub_download will actually place the file: local_dir/filename.
        target_path = Path("ckpt") / hf_path
        # NOTE(review): WEIGHT_FILES values suggest the inference code expects
        # weights at ckpt/<local_rel_path> (e.g. ckpt/FVSM/...) — confirm
        # against inference.flovd_demo and relocate/symlink if so.
        expected_path = Path("ckpt") / local_rel_path
        if target_path.exists() or expected_path.exists():
            print(f"Already exists: {target_path}")
            continue
        print(f"Downloading {hf_path}")
        hf_hub_download(
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            filename=hf_path,
            local_dir="ckpt",
        )

download_weights()
|
40 |
+
|
41 |
+
# =========================================
|
42 |
+
# 2. Import the FloVD generation pipeline
|
43 |
+
# =========================================
|
44 |
+
|
45 |
from inference.flovd_demo import generate_video
|
46 |
|
47 |
def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
|
|
|
48 |
log_buffer = io.StringIO()
|
49 |
sys_stdout = sys.stdout
|
50 |
sys.stdout = log_buffer
|
|
|
83 |
video_path = f"./outputs/generated_videos/{video_name}"
|
84 |
print(f"β
Inference complete. Video saved to {video_path}")
|
85 |
|
86 |
+
except Exception:
|
87 |
print("π₯ Inference failed with exception:")
|
88 |
traceback.print_exc()
|
89 |
|
|
|
90 |
sys.stdout = sys_stdout
|
91 |
logs = log_buffer.getvalue()
|
92 |
log_buffer.close()
|
93 |
|
94 |
return (video_path if video_path and os.path.exists(video_path) else None), logs
|
95 |
|
96 |
+
# =========================================
|
97 |
+
# 3. Gradio App Interface
|
98 |
+
# =========================================
|
99 |
|
100 |
with gr.Blocks() as demo:
|
101 |
gr.Markdown("## π₯ FloVD: Optical Flow + CogVideoX Video Generation")
|
|
|
117 |
inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
|
118 |
outputs=[output_video, output_logs]
|
119 |
)
|
120 |
+
|
121 |
demo.launch(show_error=True)
|