roll-ai commited on
Commit
b7776da
·
verified ·
1 Parent(s): 7a456b1

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -192
app.py DELETED
@@ -1,192 +0,0 @@
1
- import os
2
- import gradio as gr
3
- import torch
4
- from PIL import Image
5
- from pathlib import Path
6
- import io
7
- import sys
8
- import traceback
9
- from huggingface_hub import hf_hub_download
10
- # For live system monitoring
11
- import psutil
12
- import GPUtil
13
-
14
- # =========================================
15
- # 1. Define Hugging Face dataset + weights
16
- # =========================================
17
-
18
- HF_DATASET_REPO = "roll-ai/FloVD-weights"
19
-
20
- WEIGHT_FILES = {
21
- "ckpt/FVSM/FloVD_FVSM_Controlnet.pt": "FVSM/FloVD_FVSM_Controlnet.pt",
22
- "ckpt/OMSM/selected_blocks.safetensors": "OMSM/selected_blocks.safetensors",
23
- "ckpt/OMSM/pytorch_lora_weights.safetensors": "OMSM/pytorch_lora_weights.safetensors",
24
- "ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth": "others/depth_anything_v2_metric_hypersim_vitb.pth"
25
- }
26
-
27
- print("\nDownloading model...", flush=True)
28
-
29
- def download_weights():
30
- print("🔄 Downloading model weights via huggingface_hub...")
31
- for hf_path, local_rel_path in WEIGHT_FILES.items():
32
- local_path = Path("ckpt") / local_rel_path
33
- if not local_path.exists():
34
- print(f"📥 Downloading {hf_path}")
35
- hf_hub_download(
36
- repo_id=HF_DATASET_REPO,
37
- repo_type="dataset",
38
- filename=hf_path,
39
- local_dir="./"
40
- )
41
- else:
42
- print(f"✅ Already exists: {local_path}")
43
-
44
- download_weights()
45
-
46
- def print_ckpt_structure(base_path="ckpt"):
47
- print(f"📂 Listing structure of: {base_path}", flush=True)
48
- for root, dirs, files in os.walk(base_path):
49
- level = root.replace(base_path, '').count(os.sep)
50
- indent = ' ' * 2 * level
51
- print(f"{indent}📁 {os.path.basename(root)}/", flush=True)
52
- sub_indent = ' ' * 2 * (level + 1)
53
- for f in files:
54
- print(f"{sub_indent}📄 {f}", flush=True)
55
-
56
- print_ckpt_structure()
57
-
58
- # =========================================
59
- # 2. Import FloVD generation pipeline
60
- # =========================================
61
-
62
- from inference.flovd_demo import generate_video
63
-
64
- @spaces.GPU
65
- def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
66
- log_buffer = io.StringIO()
67
- sys_stdout = sys.stdout
68
- sys.stdout = log_buffer
69
-
70
- video_path = None
71
- try:
72
- print("🚀 Starting inference...", flush=True)
73
- os.makedirs("input_images", exist_ok=True)
74
- image_path = "input_images/input_image.png"
75
-
76
- if not isinstance(image, Image.Image):
77
- image = Image.fromarray(image.astype("uint8"))
78
-
79
- image.save(image_path)
80
- print(f"📸 Saved input image to {image_path}", flush=True)
81
-
82
- generate_video(
83
- prompt=prompt,
84
- image_path=image_path,
85
- fvsm_path="./ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
86
- omsm_path="./ckpt/OMSM",
87
- output_path="./outputs",
88
- num_frames=49,
89
- fps=16,
90
- width=None,
91
- height=None,
92
- seed=42,
93
- guidance_scale=6.0,
94
- dtype=torch.float16,
95
- controlnet_guidance_end=0.4,
96
- use_dynamic_cfg=False,
97
- pose_type=pose_type,
98
- speed=float(speed),
99
- use_flow_integration=use_flow_integration,
100
- cam_pose_name=cam_pose_name,
101
- depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth"
102
- )
103
-
104
- video_name = f"{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
105
- video_path = f"./outputs/generated_videos/{video_name}"
106
- print(f"✅ Inference complete. Video saved to {video_path}")
107
-
108
- except Exception:
109
- print("🔥 Inference failed with exception:")
110
- traceback.print_exc()
111
-
112
- sys.stdout = sys_stdout
113
- logs = log_buffer.getvalue()
114
- log_buffer.close()
115
-
116
- return (video_path if video_path and os.path.exists(video_path) else None), logs
117
-
118
-
119
- # =========================================
120
- # 3. Define FloVD Gradio Interface
121
- # =========================================
122
- with gr.Blocks() as video_tab:
123
- gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")
124
-
125
- prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
126
- image = gr.Image(label="Input Image")
127
- pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
128
- speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Camera Speed")
129
- use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
130
- cam_pose_name = gr.Textbox(label="Camera Trajectory", placeholder="e.g., zoom_in, custom_motion, etc.", lines=1)
131
-
132
- generate_btn = gr.Button("🎬 Generate Video")
133
-
134
- video_output = gr.Video(label="Generated Video")
135
- log_output = gr.Textbox(label="Logs", lines=20, interactive=False)
136
-
137
- generate_btn.click(
138
- fn=run_inference,
139
- inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name],
140
- outputs=[video_output, log_output]
141
- )
142
-
143
- # =========================================
144
- # 4. Live System Monitor (Fixed)
145
- # =========================================
146
-
147
- def get_system_stats():
148
- cpu = psutil.cpu_percent()
149
- mem = psutil.virtual_memory()
150
- disk = psutil.disk_usage('/')
151
- try:
152
- gpus = GPUtil.getGPUs()
153
- gpu_info = "\n".join([
154
- f"GPU {i}: {gpu.name}, {gpu.memoryUsed}MB / {gpu.memoryTotal}MB, Util: {gpu.load * 100:.1f}%"
155
- for i, gpu in enumerate(gpus)
156
- ]) if gpus else "No GPU detected"
157
- except Exception as e:
158
- gpu_info = f"GPU info error: {e}"
159
-
160
- return (
161
- f"🧠 CPU Usage: {cpu}%\n"
162
- f"💾 RAM: {mem.used / 1e9:.2f} GB / {mem.total / 1e9:.2f} GB ({mem.percent}%)\n"
163
- f"🗄️ Disk: {disk.used / 1e9:.2f} GB / {disk.total / 1e9:.2f} GB ({disk.percent}%)\n"
164
- f"🎮 {gpu_info}"
165
- )
166
-
167
- with gr.Blocks() as monitor_tab:
168
- gr.Markdown("## 📊 Live System Resource Monitor")
169
- stats_box = gr.Textbox(label="Live Stats", lines=10, interactive=False)
170
-
171
- def update_stats():
172
- return gr.update(value=get_system_stats())
173
-
174
- stats_btn = gr.Button("🔄 Refresh Stats")
175
- stats_btn.click(fn=update_stats, outputs=stats_box)
176
-
177
- # =========================================
178
- # 5. Combine Tabs: FloVD + Monitor
179
- # =========================================
180
-
181
- with gr.Blocks() as app:
182
- with gr.Tab("🎥 Video Generator"):
183
- video_tab.render()
184
- with gr.Tab("📊 System Monitor"):
185
- monitor_tab.render()
186
-
187
- # =========================================
188
- # 6. Launch App
189
- # =========================================
190
-
191
- if __name__ == "__main__":
192
- app.launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True)