roll-ai committed on
Commit
302090a
·
verified ·
1 Parent(s): b7776da

Upload app.py

Files changed (1)
  1. app.py +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
import os
import gradio as gr
import torch
from PIL import Image
from pathlib import Path
import io
import sys
import traceback
from huggingface_hub import hf_hub_download
import spaces  # provides the @spaces.GPU decorator used below (Hugging Face ZeroGPU)
# For live system monitoring
import psutil
import GPUtil

# =========================================
# 1. Define Hugging Face dataset + weights
# =========================================

HF_DATASET_REPO = "roll-ai/FloVD-weights"

WEIGHT_FILES = {
    "ckpt/FVSM/FloVD_FVSM_Controlnet.pt": "FVSM/FloVD_FVSM_Controlnet.pt",
    "ckpt/OMSM/selected_blocks.safetensors": "OMSM/selected_blocks.safetensors",
    "ckpt/OMSM/pytorch_lora_weights.safetensors": "OMSM/pytorch_lora_weights.safetensors",
    "ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth": "others/depth_anything_v2_metric_hypersim_vitb.pth"
}
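# Keys are the file paths inside the HF dataset repo; values are the matching
# locations under ./ckpt/ that the rest of the app expects.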

print("\nDownloading model...", flush=True)

def download_weights():
    print("🔄 Downloading model weights via huggingface_hub...")
    for hf_path, local_rel_path in WEIGHT_FILES.items():
        local_path = Path("ckpt") / local_rel_path
        if not local_path.exists():
            print(f"📥 Downloading {hf_path}")
            hf_hub_download(
                repo_id=HF_DATASET_REPO,
                repo_type="dataset",
                filename=hf_path,
                local_dir="./"
            )
        else:
            print(f"✅ Already exists: {local_path}")

download_weights()
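# Note: hf_hub_download with local_dir="./" keeps the repo-relative path of
# `filename`, so "ckpt/FVSM/..." from the dataset lands at ./ckpt/FVSM/...,
# which is exactly where the local_path check above looks.
# A roughly equivalent one-shot alternative (sketch only, not used here):
#   from huggingface_hub import snapshot_download
#   snapshot_download(repo_id=HF_DATASET_REPO, repo_type="dataset",
#                     local_dir="./", allow_patterns=["ckpt/*"])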

def print_ckpt_structure(base_path="ckpt"):
    print(f"📂 Listing structure of: {base_path}", flush=True)
    for root, dirs, files in os.walk(base_path):
        level = root.replace(base_path, '').count(os.sep)
        indent = ' ' * 2 * level
        print(f"{indent}📁 {os.path.basename(root)}/", flush=True)
        sub_indent = ' ' * 2 * (level + 1)
        for f in files:
            print(f"{sub_indent}📄 {f}", flush=True)

print_ckpt_structure()
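# With the weights above in place, the printed tree should look roughly like:
#   📁 ckpt/
#     📁 FVSM/
#       📄 FloVD_FVSM_Controlnet.pt
#     📁 OMSM/
#       📄 selected_blocks.safetensors
#       📄 pytorch_lora_weights.safetensors
#     📁 others/
#       📄 depth_anything_v2_metric_hypersim_vitb.pth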

# =========================================
# 2. Import FloVD generation pipeline
# =========================================

from inference.flovd_demo import generate_video

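# generate_video is assumed to come from the FloVD inference code bundled with
# this Space (inference/flovd_demo.py); the import runs at startup, so missing
# dependencies surface immediately in the build logs.
# The @spaces.GPU decorator below asks Spaces (ZeroGPU) to attach a GPU for the
# duration of each call; it requires the `spaces` package imported at the top.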
@spaces.GPU
def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
    log_buffer = io.StringIO()
    sys_stdout = sys.stdout
    sys.stdout = log_buffer

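    # From here on, everything print()-ed is captured in log_buffer and
    # returned to the UI as the "Logs" textbox once stdout is restored below.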
    video_path = None
    try:
        print("🚀 Starting inference...", flush=True)
        os.makedirs("input_images", exist_ok=True)
        image_path = "input_images/input_image.png"

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image.astype("uint8"))

        image.save(image_path)
        print(f"📸 Saved input image to {image_path}", flush=True)

        generate_video(
            prompt=prompt,
            image_path=image_path,
            fvsm_path="./ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
            omsm_path="./ckpt/OMSM",
            output_path="./outputs",
            num_frames=49,
            fps=16,
            width=None,
            height=None,
            seed=42,
            guidance_scale=6.0,
            dtype=torch.float16,
            controlnet_guidance_end=0.4,
            use_dynamic_cfg=False,
            pose_type=pose_type,
            speed=float(speed),
            use_flow_integration=use_flow_integration,
            cam_pose_name=cam_pose_name,
            depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth"
        )
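        # 49 frames at 16 fps is roughly a 3-second clip; width/height of None
        # presumably lets the FloVD pipeline fall back to its own defaults.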

        # generate_video is expected to write the clip under
        # ./outputs/generated_videos/ using this prompt-derived name; if the
        # file is not found, None is returned below instead of a video.
        video_name = f"{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
        video_path = f"./outputs/generated_videos/{video_name}"
        print(f"✅ Inference complete. Video saved to {video_path}")

    except Exception:
        print("🔥 Inference failed with exception:")
        traceback.print_exc()

    sys.stdout = sys_stdout
    logs = log_buffer.getvalue()
    log_buffer.close()

    return (video_path if video_path and os.path.exists(video_path) else None), logs


# =========================================
# 3. Define FloVD Gradio Interface
# =========================================
with gr.Blocks() as video_tab:
    gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")

    prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
    image = gr.Image(label="Input Image")
    pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
    speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Camera Speed")
    use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
    cam_pose_name = gr.Textbox(label="Camera Trajectory", placeholder="e.g., zoom_in, custom_motion, etc.", lines=1)

    generate_btn = gr.Button("🎬 Generate Video")

    video_output = gr.Video(label="Generated Video")
    log_output = gr.Textbox(label="Logs", lines=20, interactive=False)

    generate_btn.click(
        fn=run_inference,
        inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
        outputs=[video_output, log_output]
    )
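    # gr.Image hands run_inference a numpy array by default (type="numpy"),
    # which is why the handler converts non-PIL input with Image.fromarray.
    # The two return values map positionally onto [video_output, log_output].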

# =========================================
# 4. Live System Monitor (Fixed)
# =========================================

def get_system_stats():
    cpu = psutil.cpu_percent()
    mem = psutil.virtual_memory()
    disk = psutil.disk_usage('/')
    try:
        gpus = GPUtil.getGPUs()
        gpu_info = "\n".join([
            f"GPU {i}: {gpu.name}, {gpu.memoryUsed}MB / {gpu.memoryTotal}MB, Util: {gpu.load * 100:.1f}%"
            for i, gpu in enumerate(gpus)
        ]) if gpus else "No GPU detected"
    except Exception as e:
        gpu_info = f"GPU info error: {e}"

    return (
        f"🧠 CPU Usage: {cpu}%\n"
        f"💾 RAM: {mem.used / 1e9:.2f} GB / {mem.total / 1e9:.2f} GB ({mem.percent}%)\n"
        f"🗄️ Disk: {disk.used / 1e9:.2f} GB / {disk.total / 1e9:.2f} GB ({disk.percent}%)\n"
        f"🎮 {gpu_info}"
    )
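# GPUtil reads GPU stats via nvidia-smi; on machines without an NVIDIA GPU
# (e.g., a ZeroGPU Space outside of a @spaces.GPU call) expect either the
# "No GPU detected" branch or the "GPU info error" fallback above.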

with gr.Blocks() as monitor_tab:
    gr.Markdown("## 📊 Live System Resource Monitor")
    stats_box = gr.Textbox(label="Live Stats", lines=10, interactive=False)

    def update_stats():
        return gr.update(value=get_system_stats())

    stats_btn = gr.Button("🔄 Refresh Stats")
    stats_btn.click(fn=update_stats, outputs=stats_box)
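    # gr.update(value=...) refreshes stats_box in place; returning the plain
    # string from update_stats would have the same effect here.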

# =========================================
# 5. Combine Tabs: FloVD + Monitor
# =========================================

with gr.Blocks() as app:
    with gr.Tab("🎥 Video Generator"):
        video_tab.render()
    with gr.Tab("📊 System Monitor"):
        monitor_tab.render()
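    # .render() embeds the Blocks defined above into this parent layout, so
    # only `app` (not video_tab / monitor_tab) needs to be launched.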

# =========================================
# 6. Launch App
# =========================================

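# Hugging Face Spaces expects the server on 0.0.0.0:7860 by default;
# debug=True and show_error=True surface tracebacks in the container logs and
# the UI, which helps when generate_video fails.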
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True)