rahul7star committed
Commit 33bdb5b · verified · 1 Parent(s): 7f14f6f

Update app_t2v.py

Files changed (1)
  1. app_t2v.py +62 -52
app_t2v.py CHANGED
@@ -1,55 +1,58 @@
+# PyTorch nightly for CUDA compatibility
 import os
 os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
 
+# Imports
+import spaces
 import torch
+from diffusers import WanPipeline, AutoencoderKLWan
+from diffusers.utils import export_to_video
 import gradio as gr
 import tempfile
 import random
 import numpy as np
-import spaces
-from diffusers import WanPipeline, AutoencoderKLWan
-from diffusers.utils import export_to_video
 
 # Constants
 MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
-MAX_SEED = np.iinfo(np.int32).max
 FIXED_FPS = 16
-DEFAULT_NEGATIVE_PROMPT = (
+MAX_SEED = np.iinfo(np.int32).max
+DEFAULT_HEIGHT = 720
+DEFAULT_WIDTH = 1280
+MAX_FRAMES = 81
+
+# Prompts
+default_prompt_t2v = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
+default_negative_prompt = (
     "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
     "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,"
     "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
 )
 
-# Setup
-dtype = torch.float16 # using float16 for broader compatibility
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+# Load pipeline
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
-# Load model components on correct device
-vae = AutoencoderKLWan.from_pretrained(
-    MODEL_ID, subfolder="vae", torch_dtype=torch.float32
-).to(device)
+vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32).to(device)
 
-pipe = WanPipeline.from_pretrained(
-    MODEL_ID, vae=vae, torch_dtype=dtype
-).to(device)
+pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=dtype).to(device)
 
-# Warm-up call to reduce cold-start latency
+# Optional: warm-up
 _ = pipe(
     prompt="warmup",
-    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
+    negative_prompt=default_negative_prompt,
     height=512,
     width=768,
     num_frames=8,
     num_inference_steps=2,
-    generator=torch.Generator(device=device).manual_seed(0),
+    generator=torch.Generator(device=device).manual_seed(0)
 ).frames[0]
 
-# Estimate duration for Hugging Face Spaces GPU usage
-def get_duration(prompt, negative_prompt, height, width, num_frames, guidance_scale, guidance_scale_2, num_steps, seed, randomize_seed):
-    return int(num_steps * 15)
+# Space-aware duration helper
+def get_duration(prompt, negative_prompt, height, width, num_frames, guidance_scale, guidance_scale_2, steps, seed, randomize_seed, progress):
+    return int(steps * 15)
 
 @spaces.GPU(duration=get_duration)
-def generate_video(
+def generate_t2v(
     prompt,
     negative_prompt,
     height,
@@ -57,57 +60,64 @@ def generate_video(
     num_frames,
     guidance_scale,
     guidance_scale_2,
-    num_steps,
+    steps,
     seed,
-    randomize_seed
+    randomize_seed,
+    progress=gr.Progress(track_tqdm=True),
 ):
     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
     generator = torch.Generator(device=device).manual_seed(current_seed)
 
-    output = pipe(
+    output_frames = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
-        height=height,
-        width=width,
-        num_frames=num_frames,
-        guidance_scale=guidance_scale,
-        guidance_scale_2=guidance_scale_2,
-        num_inference_steps=num_steps,
+        height=int(height),
+        width=int(width),
+        num_frames=int(num_frames),
+        guidance_scale=float(guidance_scale),
+        guidance_scale_2=float(guidance_scale_2),
+        num_inference_steps=int(steps),
        generator=generator,
     ).frames[0]
 
     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
-        export_to_video(output, tmpfile.name, fps=FIXED_FPS)
+        export_to_video(output_frames, tmpfile.name, fps=FIXED_FPS)
     return tmpfile.name, current_seed
 
 # Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("## 🎬 Wan2.2 Text-to-Video Generator with Hugging Face Spaces GPU")
+    gr.Markdown("## 🎬 Wan 2.2 T2V: Text-to-Video via Wan-AI")
 
     with gr.Row():
         with gr.Column():
-            prompt = gr.Textbox(label="Prompt", value="Two anthropomorphic cats in comfy boxing gear fight intensely.")
-            negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT, lines=3)
-            height = gr.Slider(360, 1024, value=720, step=16, label="Height")
-            width = gr.Slider(360, 1920, value=1280, step=16, label="Width")
-            num_frames = gr.Slider(8, 81, value=81, step=1, label="Number of Frames")
-            num_steps = gr.Slider(10, 60, value=40, step=1, label="Inference Steps")
-            guidance_scale = gr.Slider(1.0, 10.0, value=4.0, step=0.5, label="Guidance Scale")
-            guidance_scale_2 = gr.Slider(1.0, 10.0, value=3.0, step=0.5, label="Guidance Scale 2")
-            seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
-            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-
-            generate_button = gr.Button("🎥 Generate Video")
+            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_t2v)
+            negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
+            height_slider = gr.Slider(360, 1024, step=16, value=DEFAULT_HEIGHT, label="Height")
+            width_slider = gr.Slider(360, 1920, step=16, value=DEFAULT_WIDTH, label="Width")
+            frames_slider = gr.Slider(8, MAX_FRAMES, value=MAX_FRAMES, step=1, label="Frames")
+
+            with gr.Accordion("Advanced Settings", open=False):
+                guidance_slider = gr.Slider(0.0, 20.0, step=0.5, value=4.0, label="Guidance Scale")
+                guidance2_slider = gr.Slider(0.0, 20.0, step=0.5, value=3.0, label="Guidance Scale 2")
+                steps_slider = gr.Slider(1, 60, step=1, value=40, label="Inference Steps")
+                seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed", interactive=True)
+                randomize_seed_check = gr.Checkbox(label="Randomize Seed", value=True)
+
+            generate_button = gr.Button("🎥 Generate Video", variant="primary")
 
         with gr.Column():
             video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
-            final_seed_display = gr.Number(label="Used Seed", interactive=False)
+            used_seed = gr.Number(label="Used Seed", interactive=False)
+
+    inputs = [
+        prompt_input, negative_prompt_input,
+        height_slider, width_slider,
+        frames_slider,
+        guidance_slider, guidance2_slider,
+        steps_slider, seed_slider, randomize_seed_check
+    ]
 
-    generate_button.click(
-        fn=generate_video,
-        inputs=[prompt, negative_prompt, height, width, num_frames, guidance_scale, guidance_scale_2, num_steps, seed, randomize_seed],
-        outputs=[video_output, final_seed_display],
-    )
+    generate_button.click(fn=generate_t2v, inputs=inputs, outputs=[video_output, used_seed])
 
 if __name__ == "__main__":
-    demo.queue().launch()
+    demo.queue().launch(mcp_server=True)
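
A note on the duration=get_duration pattern kept by this commit: on ZeroGPU Spaces, the spaces package allows the duration argument of @spaces.GPU to be a callable that is invoked with the same arguments as the decorated function and returns an estimated GPU budget in seconds. That is why get_duration mirrors the full signature of generate_t2v, including the newly added progress parameter, even though it only reads steps. Below is a minimal, hypothetical sketch of the idea; the render function and its 15-seconds-per-step cost model are illustrative assumptions, not part of app_t2v.py.

import spaces

def estimate_seconds(prompt, steps):
    # Assumed cost model: roughly 15 s of GPU time per inference step,
    # mirroring get_duration() in app_t2v.py.
    return int(steps * 15)

@spaces.GPU(duration=estimate_seconds)  # duration may be a callable on ZeroGPU
def render(prompt, steps):
    # Placeholder for the GPU work; the real app runs the Wan pipeline here.
    return f"{prompt!r} rendered in {steps} steps"

Since progress=gr.Progress(track_tqdm=True) is a keyword default rather than a UI component, it is deliberately left out of the inputs list wired to generate_button.click; Gradio injects it at call time.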