roll-ai committed on
Commit
7e06d6b
·
verified ·
1 Parent(s): 4e7b4da

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import shutil
import subprocess
import urllib.request
from pathlib import Path

import gradio as gr
import torch
from PIL import Image

# =========================================
# 1. Define Hugging Face weights and paths
# =========================================

# Base URL of the checkpoint folder in the Hugging Face dataset repository.
# A relative path from WEIGHT_FILES is appended to it to form a download URL.
HF_DATASET_URL = "https://huggingface.co/datasets/roll-ai/FloVD-weights/resolve/main/ckpt"

# Checkpoint files required by the demo.  download_weights() reads only the
# values, using each one both as the URL suffix and as the path below the
# local "ckpt" directory; every key currently equals its value.
WEIGHT_FILES = {
    "FVSM/FloVD_FVSM_Controlnet.pt": "FVSM/FloVD_FVSM_Controlnet.pt",
    "OMSM/selected_blocks.safetensors": "OMSM/selected_blocks.safetensors",
    "OMSM/pytorch_lora_weights.safetensors": "OMSM/pytorch_lora_weights.safetensors",
    "others/depth_anything_v2_metric_hypersim_vitb.pth": "others/depth_anything_v2_metric_hypersim_vitb.pth",
}

def download_weights(weight_files=None, base_url=None, dest_dir="ckpt"):
    """Download every checkpoint file that is not already on disk.

    Each file is streamed to a sibling ``<name>.part`` file and renamed to its
    final name only after the transfer completed, so an interrupted or failed
    download can never be mistaken for a valid checkpoint on the next run.
    The download uses the standard library instead of an external ``wget``
    binary.

    Args:
        weight_files: Mapping whose values are paths relative to both
            ``base_url`` and ``dest_dir``.  Defaults to ``WEIGHT_FILES``.
        base_url: URL prefix the relative paths are appended to.  Defaults to
            ``HF_DATASET_URL``.
        dest_dir: Local directory the files are stored under.

    Raises:
        OSError: If a download fails (``urllib.error.URLError`` and
            ``urllib.error.HTTPError`` are ``OSError`` subclasses) or the
            file cannot be written.
    """
    # Resolve the module-level defaults lazily so explicit arguments never
    # touch the globals.
    if weight_files is None:
        weight_files = WEIGHT_FILES
    if base_url is None:
        base_url = HF_DATASET_URL

    print("🔄 Downloading model weights...")
    for rel_path in weight_files.values():
        save_path = Path(dest_dir) / rel_path
        if save_path.exists():
            print(f"✅ Already exists: {save_path}")
            continue

        save_path.parent.mkdir(parents=True, exist_ok=True)
        url = f"{base_url}/{rel_path}"
        print(f"📥 Downloading {url} → {save_path}")

        partial_path = save_path.with_name(save_path.name + ".part")
        try:
            with urllib.request.urlopen(url, timeout=60) as response, \
                    open(partial_path, "wb") as out_file:
                # Stream in 1 MiB chunks; checkpoints are too large to buffer.
                shutil.copyfileobj(response, out_file, 1024 * 1024)
            # Publish the file under its real name only once it is complete.
            partial_path.replace(save_path)
        finally:
            # Still present only if the transfer or the rename failed.
            if partial_path.exists():
                partial_path.unlink()

# Runs at import time, before the pipeline below is loaded from the "ckpt"
# paths.
download_weights()

# =========================================
# 2. Import and load FloVD pipeline
# =========================================

# NOTE(review): this import sits after the download step; presumably that
# ordering is intentional — confirm before moving it to the top of the file.
from inference.flovd_demo import load_pipeline, generate_video

# Single pipeline instance shared by every run_inference() call.
pipeline = load_pipeline(
    fvsm_path="ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
    omsm_path="ckpt/OMSM",
    depth_path="ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# =========================================
# 3. Inference Function
# =========================================

def run_inference(image: "Image.Image | None", prompt: str, cam_traj_path: str):
    """Generate a video from an input image, a prompt and a camera trajectory.

    Args:
        image: Input image from the UI; ``None`` when nothing was uploaded.
        prompt: Text prompt forwarded to ``generate_video``.
        cam_traj_path: Camera trajectory file path forwarded to
            ``generate_video`` as ``cam_traj``.

    Returns:
        Whatever ``generate_video`` returns; the UI hands it to a ``gr.Video``
        output (presumably a path to the rendered video — confirm against
        ``inference.flovd_demo``).

    Raises:
        ValueError: If no image was supplied or the trajectory path is blank.
    """
    # Reject incomplete requests before the pipeline is invoked.
    if image is None:
        raise ValueError("Please provide an input image.")
    if not cam_traj_path or not cam_traj_path.strip():
        raise ValueError("Please provide a camera trajectory file path.")

    print("🚀 Running inference...")
    output_path = generate_video(
        image=image,
        prompt=prompt,
        cam_traj=cam_traj_path,
        pipeline=pipeline,
        num_frames=49,
        fps=16,
        controlnet_guidance_end=0.4,
        flow_scale=(60, 36),
    )
    return output_path

# =========================================
# 4. Gradio UI
# =========================================

# Default inputs: used for the single example row and to pre-fill the
# trajectory textbox.
example_image = "assets/manual_poses/example_image.jpg"
example_cam = "assets/cam_trajectory/dolly_zoom.txt"

# NOTE(review): the trajectory is taken as a free-text path and passed to
# run_inference unchanged; confirm that accepting arbitrary paths is
# acceptable for a public deployment.
demo = gr.Interface(
    fn=run_inference,
    # Order must match run_inference(image, prompt, cam_traj_path).
    inputs=[
        gr.Image(label="Input Image", type="pil"),
        gr.Textbox(label="Text Prompt", value="A cinematic dolly zoom shot of a futuristic cityscape"),
        gr.Textbox(label="Camera Trajectory File Path", value=example_cam),
    ],
    outputs=gr.Video(label="Generated Video"),
    title="FloVD-CogVideoX 🌠",
    description="Upload an image, enter a text prompt and a camera trajectory file path to generate a controlled video using CogVideoX + optical flow.",
    examples=[[example_image, "A beautiful sunrise over a mountain range", example_cam]],
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()