roll-ai committed
Commit 3ca8d26 · verified · 1 Parent(s): 3f4e721

Update app.py

Files changed (1)
  1. app.py +25 -20
app.py CHANGED
@@ -1,45 +1,50 @@
  import os
  import gradio as gr
  import torch
- import subprocess
  from PIL import Image
  from pathlib import Path
  import io
  import sys
  import traceback
+ from huggingface_hub import hf_hub_download

  # =========================================
- # 1. Define Hugging Face weights and paths
+ # 1. Define Hugging Face dataset + weights
  # =========================================
- from huggingface_hub import hf_hub_download

- HF_DATASET_REPO = "roll-ai/FloVD-weights" # dataset repo ID
+ HF_DATASET_REPO = "roll-ai/FloVD-weights" # your dataset repo on HF
+
  WEIGHT_FILES = {
-     "FVSM/FloVD_FVSM_Controlnet.pt": "FVSM/FloVD_FVSM_Controlnet.pt",
-     "OMSM/selected_blocks.safetensors": "OMSM/selected_blocks.safetensors",
-     "OMSM/pytorch_lora_weights.safetensors": "OMSM/pytorch_lora_weights.safetensors",
-     "others/depth_anything_v2_metric_hypersim_vitb.pth": "others/depth_anything_v2_metric_hypersim_vitb.pth"
+     "ckpt/FVSM/FloVD_FVSM_Controlnet.pt": "FVSM/FloVD_FVSM_Controlnet.pt",
+     "ckpt/OMSM/selected_blocks.safetensors": "OMSM/selected_blocks.safetensors",
+     "ckpt/OMSM/pytorch_lora_weights.safetensors": "OMSM/pytorch_lora_weights.safetensors",
+     "ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth": "others/depth_anything_v2_metric_hypersim_vitb.pth"
  }
+
  def download_weights():
      print("🔄 Downloading model weights via huggingface_hub...")
-     for rel_path in WEIGHT_FILES.values():
-         local_path = Path("ckpt") / rel_path
+     for hf_path, local_rel_path in WEIGHT_FILES.items():
+         local_path = Path("ckpt") / local_rel_path
          if not local_path.exists():
-             print(f"📥 Downloading {rel_path}")
+             print(f"📥 Downloading {hf_path}")
              hf_hub_download(
                  repo_id=HF_DATASET_REPO,
                  repo_type="dataset",
-                 filename=rel_path,
-                 local_dir="ckpt",
-                 local_dir_use_symlinks=False,
+                 filename=hf_path,
+                 local_dir="ckpt"
              )
          else:
              print(f"✅ Already exists: {local_path}")
+
  download_weights()
+
+ # =========================================
+ # 2. Import the FloVD generation pipeline
+ # =========================================
+
  from inference.flovd_demo import generate_video

  def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
-     # Redirect stdout to capture logs
      log_buffer = io.StringIO()
      sys_stdout = sys.stdout
      sys.stdout = log_buffer
@@ -78,20 +83,19 @@ def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
          video_path = f"./outputs/generated_videos/{video_name}"
          print(f"✅ Inference complete. Video saved to {video_path}")

-     except Exception as e:
+     except Exception:
          print("🔥 Inference failed with exception:")
          traceback.print_exc()

-     # Restore stdout and return logs
      sys.stdout = sys_stdout
      logs = log_buffer.getvalue()
      log_buffer.close()

      return (video_path if video_path and os.path.exists(video_path) else None), logs

- # ========================
- # Gradio Interface
- # ========================
+ # =========================================
+ # 3. Gradio App Interface
+ # =========================================

  with gr.Blocks() as demo:
      gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")
@@ -113,4 +117,5 @@ with gr.Blocks() as demo:
          inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
          outputs=[output_video, output_logs]
      )
+
  demo.launch(show_error=True)
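
The main change in this commit is how the weights are fetched: WEIGHT_FILES now maps paths inside the roll-ai/FloVD-weights dataset repo to local paths under ckpt/, and hf_hub_download is called with local_dir only, since recent huggingface_hub releases (roughly ≥ 0.23) write real files to <local_dir>/<filename> and treat local_dir_use_symlinks as deprecated. A minimal standalone sketch of that call pattern follows; fetch_weight is a hypothetical helper for illustration, not a function in the Space.

from pathlib import Path
from huggingface_hub import hf_hub_download

HF_DATASET_REPO = "roll-ai/FloVD-weights"

def fetch_weight(filename: str, local_dir: str = "ckpt") -> Path:
    """Hypothetical helper: download one weight file from the dataset repo if missing."""
    target = Path(local_dir) / filename
    if target.exists():
        print(f"✅ Already exists: {target}")
        return target
    print(f"📥 Downloading {filename}")
    local_path = hf_hub_download(
        repo_id=HF_DATASET_REPO,
        repo_type="dataset",   # the weights live in a dataset repo, not a model repo
        filename=filename,     # path inside the repo, e.g. "FVSM/FloVD_FVSM_Controlnet.pt"
        local_dir=local_dir,   # the file is written to <local_dir>/<filename>
    )
    return Path(local_path)

In this form the on-disk layout mirrors the repo layout, so the existence check and the download target stay in sync.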
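run_inference captures its logs by pointing sys.stdout at an io.StringIO buffer and restoring it after the work (or the exception handler) finishes. The same effect can be had with contextlib.redirect_stdout, which restores stdout automatically even when the wrapped call raises; the sketch below is illustrative and capture_logs is not part of the Space. Note that traceback.print_exc() writes to stderr by default, so the sketch formats the traceback explicitly to keep it in the captured log.

import io
import traceback
from contextlib import redirect_stdout

def capture_logs(fn, *args, **kwargs):
    """Hypothetical wrapper: run fn and return (result, captured stdout text)."""
    buf = io.StringIO()
    result = None
    with redirect_stdout(buf):  # stdout is restored on exit, even on error
        try:
            result = fn(*args, **kwargs)
        except Exception:
            print("🔥 Inference failed with exception:")
            print(traceback.format_exc())  # format_exc keeps the traceback in the buffer
    return result, buf.getvalue()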
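The diff only shows the tail of the event wiring (the inputs/outputs lists); the component definitions on old lines 98-112 are unchanged and therefore hidden. As a rough sketch of how such a list plugs into a Gradio click event: the component names below match the diff, but the component types, choices, and labels are assumptions rather than the Space's actual code, and run_inference is a stand-in.

import gradio as gr

def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    # Stand-in for the real function; returns (video path or None, log text).
    return None, f"prompt={prompt!r}, pose_type={pose_type}, speed={speed}"

with gr.Blocks() as demo:
    gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")
    prompt = gr.Textbox(label="Prompt")                          # assumed component types
    image = gr.Image(type="pil", label="Input image")
    pose_type = gr.Dropdown(["pose_a", "pose_b"], label="Pose type")  # placeholder choices
    speed = gr.Slider(0.1, 2.0, value=1.0, label="Speed")
    use_flow_integration = gr.Checkbox(label="Use flow integration")
    cam_pose_name = gr.Textbox(label="Camera pose name")
    output_video = gr.Video(label="Generated video")
    output_logs = gr.Textbox(label="Logs", lines=15)
    run_btn = gr.Button("Generate")
    run_btn.click(
        run_inference,
        inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
        outputs=[output_video, output_logs],
    )

demo.launch(show_error=True)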