roll-ai committed
Commit ded78ff · verified · 1 Parent(s): 7e06d6b

Update app.py

Files changed (1)
  1. app.py +40 -44
app.py CHANGED
@@ -30,57 +30,53 @@ def download_weights():
     print(f"✅ Already exists: {save_path}")
 
 download_weights()
+import gradio as gr
+import torch
+import os
+from inference_script import generate_video  # Assuming your script is saved as inference_script.py
 
-# =========================================
-# 2. Import and load FloVD pipeline
-# =========================================
-
-from inference.flovd_demo import load_pipeline, generate_video
-
-pipeline = load_pipeline(
-    fvsm_path="ckpt/FVSM/FloVD_FVSM_Controlnet.pt",
-    omsm_path="ckpt/OMSM",
-    depth_path="ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
-    device="cuda" if torch.cuda.is_available() else "cpu"
-)
-
-# =========================================
-# 3. Inference Function
-# =========================================
+def run_inference(prompt, image, pose_type, speed, use_flow_integration, cam_pose_name):
+    os.makedirs("input_images", exist_ok=True)
+    image_path = "input_images/input_image.png"
+    image.save(image_path)
 
-def run_inference(image: Image.Image, prompt: str, cam_traj_path: str):
-    print("🚀 Running inference...")
-    output_path = generate_video(
-        image=image,
+    generate_video(
         prompt=prompt,
-        cam_traj=cam_traj_path,
-        pipeline=pipeline,
+        image_path=image_path,
+        fvsm_path="./ckpt/FVSM",  # Expected to be downloaded from HF dataset
+        omsm_path="./ckpt/OMSM",  # Expected to be downloaded from HF dataset
+        output_path="./outputs",
         num_frames=49,
         fps=16,
+        width=None,
+        height=None,
+        seed=42,
+        guidance_scale=6.0,
+        dtype=torch.float16,
         controlnet_guidance_end=0.4,
-        flow_scale=(60, 36)
+        use_dynamic_cfg=False,
+        pose_type=pose_type,
+        speed=float(speed),
+        use_flow_integration=use_flow_integration,
+        cam_pose_name=cam_pose_name,
+        depth_ckpt_path="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth"
     )
-    return output_path
-
-# =========================================
-# 4. Gradio UI
-# =========================================
+    return f"./outputs/generated_videos/{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
 
-example_image = "assets/manual_poses/example_image.jpg"
-example_cam = "assets/cam_trajectory/dolly_zoom.txt"
+with gr.Blocks() as demo:
+    gr.Markdown("## 🎥 FloVD: Optical Flow + CogVideoX Video Generation")
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(label="Prompt", value="A girl riding a bicycle through a park.")
+            image = gr.Image(type="pil", label="Input Image")
+            pose_type = gr.Radio(choices=["manual", "re10k"], value="manual", label="Camera Pose Type")
+            cam_pose_name = gr.Textbox(label="Camera Trajectory Name", placeholder="e.g. zoom_in, tilt_up")
+            speed = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.5, label="Speed")
+            use_flow_integration = gr.Checkbox(label="Use Flow Integration", value=False)
+            submit = gr.Button("Generate Video")
+        with gr.Column():
+            output_video = gr.Video(label="Generated Video")
 
-demo = gr.Interface(
-    fn=run_inference,
-    inputs=[
-        gr.Image(label="Input Image", type="pil"),
-        gr.Textbox(label="Text Prompt", value="A cinematic dolly zoom shot of a futuristic cityscape"),
-        gr.Textbox(label="Camera Trajectory File Path", value=example_cam),
-    ],
-    outputs=gr.Video(label="Generated Video"),
-    title="FloVD-CogVideoX 🌠",
-    description="Upload an image, enter a text prompt and a camera trajectory file path to generate a controlled video using CogVideoX + optical flow.",
-    examples=[[example_image, "A beautiful sunrise over a mountain range", example_cam]]
-)
+    submit.click(fn=run_inference, inputs=[prompt, image, pose_type, speed, use_flow_integration, cam_pose_name], outputs=output_video)
 
-if __name__ == "__main__":
-    demo.launch()
+demo.launch()
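
The new entry point imports generate_video from inference_script, whose definition is not part of this commit. As a rough sketch, here is a stub with the signature implied by the call site above; the parameter names and defaults are taken from the diff, the body is a placeholder rather than FloVD's actual pipeline, and it exists only to smoke-test the Gradio UI before the real script and checkpoints are wired up:

# Hypothetical stub of inference_script.generate_video, inferred from the
# call site in this commit; not FloVD's real implementation.
import os
import torch

def generate_video(prompt, image_path, fvsm_path, omsm_path, output_path,
                   num_frames=49, fps=16, width=None, height=None, seed=42,
                   guidance_scale=6.0, dtype=torch.float16,
                   controlnet_guidance_end=0.4, use_dynamic_cfg=False,
                   pose_type="manual", speed=0.5, use_flow_integration=False,
                   cam_pose_name=None, depth_ckpt_path=None):
    # The real script is expected to load the FVSM/OMSM checkpoints and run
    # the CogVideoX + optical-flow pipeline; here we only create the file
    # that run_inference's return path points at.
    video_dir = os.path.join(output_path, "generated_videos")
    os.makedirs(video_dir, exist_ok=True)
    name = f"{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
    open(os.path.join(video_dir, name), "wb").close()  # empty placeholder video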
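
Note also that run_inference does not receive the output path from generate_video; it rebuilds the filename from the first 30 characters of the prompt. A quick worked example with the UI's default prompt, assuming cam_pose_name is set to zoom_in:

prompt = "A girl riding a bicycle through a park."
cam_pose_name = "zoom_in"
path = f"./outputs/generated_videos/{prompt[:30].strip().replace(' ', '_')}_{cam_pose_name or 'default'}.mp4"
print(path)  # ./outputs/generated_videos/A_girl_riding_a_bicycle_throug_zoom_in.mp4

This return value only resolves to a real file if generate_video saves its output under exactly the same naming scheme, so the two sides must stay in sync.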