rahul7star committed
Commit c7de241 · verified · 1 Parent(s): b60eb30

Update app_t2v.py

Files changed (1)
  1. app_t2v.py +168 -77
app_t2v.py CHANGED
@@ -1,123 +1,214 @@
- # PyTorch nightly for CUDA compatibility
  import os
  os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

- # Imports
  import spaces
  import torch
- from diffusers import WanPipeline, AutoencoderKLWan
- from diffusers.utils import export_to_video
  import gradio as gr
  import tempfile
- import random
  import numpy as np

- # Constants
- MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
- FIXED_FPS = 16
  MAX_SEED = np.iinfo(np.int32).max
- DEFAULT_HEIGHT = 720
- DEFAULT_WIDTH = 1280
- MAX_FRAMES = 81
-
- # Prompts
- default_prompt_t2v = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
- default_negative_prompt = (
-     "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
-     "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,"
-     "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
  )

- # Load pipeline
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

- vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32).to(device)

- pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=dtype).to(device)

- # Optional: warm-up
- _ = pipe(
-     prompt="warmup",
-     negative_prompt=default_negative_prompt,
-     height=512,
-     width=768,
-     num_frames=8,
-     num_inference_steps=2,
-     generator=torch.Generator(device=device).manual_seed(0)
- ).frames[0]

- # Space-aware duration helper
- def get_duration(prompt, negative_prompt, height, width, num_frames, guidance_scale, guidance_scale_2, steps, seed, randomize_seed, progress):
-     return int(steps * 15)

- @spaces.GPU(duration=get_duration)
- def generate_t2v(
      prompt,
      negative_prompt,
-     height,
-     width,
      num_frames,
      guidance_scale,
-     guidance_scale_2,
      steps,
      seed,
      randomize_seed,
      progress=gr.Progress(track_tqdm=True),
  ):
      current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
-     generator = torch.Generator(device=device).manual_seed(current_seed)

-     output_frames = pipe(
          prompt=prompt,
          negative_prompt=negative_prompt,
-         height=int(height),
-         width=int(width),
-         num_frames=int(num_frames),
          guidance_scale=float(guidance_scale),
-         guidance_scale_2=float(guidance_scale_2),
          num_inference_steps=int(steps),
-         generator=generator,
      ).frames[0]

      with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
-         export_to_video(output_frames, tmpfile.name, fps=FIXED_FPS)
-     return tmpfile.name, current_seed

- # Gradio UI
- with gr.Blocks() as demo:
-     gr.Markdown("## 🎬 Wan 2.2 T2V: Text-to-Video via Wan-AI")
      with gr.Row():
          with gr.Column():
-             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_t2v)
-             negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
-             height_slider = gr.Slider(360, 1024, step=16, value=DEFAULT_HEIGHT, label="Height")
-             width_slider = gr.Slider(360, 1920, step=16, value=DEFAULT_WIDTH, label="Width")
-             frames_slider = gr.Slider(8, MAX_FRAMES, value=MAX_FRAMES, step=1, label="Frames")
-
              with gr.Accordion("Advanced Settings", open=False):
-                 guidance_slider = gr.Slider(0.0, 20.0, step=0.5, value=4.0, label="Guidance Scale")
-                 guidance2_slider = gr.Slider(0.0, 20.0, step=0.5, value=3.0, label="Guidance Scale 2")
-                 steps_slider = gr.Slider(1, 60, step=1, value=40, label="Inference Steps")
-                 seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed", interactive=True)
-                 randomize_seed_check = gr.Checkbox(label="Randomize Seed", value=True)
-
-             generate_button = gr.Button("🎥 Generate Video", variant="primary")

          with gr.Column():
              video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
-             used_seed = gr.Number(label="Used Seed", interactive=False)
-
-     inputs = [
-         prompt_input, negative_prompt_input,
-         height_slider, width_slider,
-         frames_slider,
-         guidance_slider, guidance2_slider,
-         steps_slider, seed_slider, randomize_seed_check
      ]
-
-     generate_button.click(fn=generate_t2v, inputs=inputs, outputs=[video_output, used_seed])

  if __name__ == "__main__":
-     demo.queue().launch(mcp_server=True)

+ # PyTorch 2.8 (temporary hack)
  import os
  os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

+ # Actual demo code
  import spaces
  import torch
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
+ from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
+ from diffusers.utils.export_utils import export_to_video
  import gradio as gr
  import tempfile
  import numpy as np
+ from PIL import Image
+ import random
+
+ from optimization import optimize_pipeline_
+
+ #MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
+ MODEL_ID = "Runware/Wan2.2-T2V-A14B"

+ LANDSCAPE_WIDTH = 832
+ LANDSCAPE_HEIGHT = 480
  MAX_SEED = np.iinfo(np.int32).max
+
+ FIXED_FPS = 24
+ MIN_FRAMES_MODEL = 13
+ MAX_FRAMES_MODEL = 121
+ NUM_FRAMES_DEFAULT = 81
+
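+ # The A14B checkpoint ships two expert transformers (`transformer` and `transformer_2`,
+ # reportedly covering the high-noise and low-noise parts of the denoising schedule);
+ # both are loaded below in bf16 from a pre-converted Diffusers checkpoint directly onto the GPU.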
+ pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID,
+     transformer=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
+         subfolder='transformer',
+         torch_dtype=torch.bfloat16,
+         device_map='cuda',
+     ),
+     transformer_2=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
+         subfolder='transformer_2',
+         torch_dtype=torch.bfloat16,
+         device_map='cuda',
+     ),
+     torch_dtype=torch.bfloat16,
+ ).to('cuda')
+
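+ # `optimize_pipeline_` comes from this Space's local `optimization` module (not shown
+ # in this diff); presumably a one-time warm-up/compilation pass at the exact resolution
+ # and frame count used at inference time.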
+ optimize_pipeline_(pipe,
+     image=Image.new('RGB', (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT)),
+     prompt='prompt',
+     height=LANDSCAPE_HEIGHT,
+     width=LANDSCAPE_WIDTH,
+     num_frames=MAX_FRAMES_MODEL,
  )

+ default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
+ default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"

+ def resize_image(image: Image.Image) -> Image.Image:
+     if image.height > image.width:
+         transposed = image.transpose(Image.Transpose.ROTATE_90)
+         resized = resize_image_landscape(transposed)
+         return resized.transpose(Image.Transpose.ROTATE_270)
+     return resize_image_landscape(image)

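+ # resize_image() above maps landscape inputs to 832x480 and portrait inputs to 480x832
+ # (portrait images are rotated to landscape, processed, then rotated back).
+ # resize_image_landscape() below center-crops to the 832:480 aspect ratio before resizing,
+ # e.g. a 1920x1080 input (aspect 1.78 > 1.73) is cropped to 1872x1080 and then resized to 832x480.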
+ def resize_image_landscape(image: Image.Image) -> Image.Image:
+     target_aspect = LANDSCAPE_WIDTH / LANDSCAPE_HEIGHT
+     width, height = image.size
+     in_aspect = width / height
+     if in_aspect > target_aspect:
+         new_width = round(height * target_aspect)
+         left = (width - new_width) // 2
+         image = image.crop((left, 0, left + new_width, height))
+     else:
+         new_height = round(width / target_aspect)
+         top = (height - new_height) // 2
+         image = image.crop((0, top, width, top + new_height))
+     return image.resize((LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT), Image.LANCZOS)
+
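+ # get_duration() estimates the ZeroGPU allocation in seconds: one forward pass is assumed
+ # to cost about 8 s at the default 81 frames and to scale with (num_frames / 81) ** 1.5,
+ # and classifier-free guidance (guidance_scale > 1) doubles the passes per step.
+ # Example with the defaults (steps=28, num_frames=81, guidance_scale=1):
+ # 10 + 28 * 1 * 8 = 234 seconds.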
+ def get_duration(
+     input_image,
      prompt,
      negative_prompt,
      num_frames,
      guidance_scale,
      steps,
      seed,
      randomize_seed,
+     progress,
+ ):
+     forward_duration_base = 8
+     forward_duration = forward_duration_base * (num_frames / NUM_FRAMES_DEFAULT)**1.5
+     forward_count = 2 if guidance_scale > 1 else 1
+     return 10 + steps * forward_count * forward_duration
+
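+ # `duration` accepts a callable: it is invoked with the same arguments as the decorated
+ # function (hence the mirrored signature of get_duration) and its return value sets the
+ # number of seconds of GPU time to reserve for the call.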
+ @spaces.GPU(duration=get_duration)
+ def generate_video(
+     input_image,
+     prompt,
+     negative_prompt=default_negative_prompt,
+     num_frames=NUM_FRAMES_DEFAULT,
+     guidance_scale=1,
+     steps=28,
+     seed=42,
+     randomize_seed=False,
      progress=gr.Progress(track_tqdm=True),
  ):
+     """
+     Generate a video from an input image using the Wan 2.2 I2V (A14B) pipeline.
+
+     This function takes an input image and generates a video animation based on the
+     provided prompt and parameters, using the pipeline loaded above.
+
+     Args:
+         input_image (PIL.Image): The input image to animate. Will be resized to target dimensions.
+         prompt (str): Text prompt describing the desired animation or motion.
+         negative_prompt (str, optional): Negative prompt to avoid unwanted elements.
+             Defaults to default_negative_prompt (a list of unwanted visual artifacts).
+         num_frames (int, optional): Number of frames. Defaults to NUM_FRAMES_DEFAULT (81).
+             Range: MIN_FRAMES_MODEL (13) to MAX_FRAMES_MODEL (121).
+         guidance_scale (float, optional): Controls adherence to the prompt. Higher values = more adherence.
+             Defaults to 1.0. Range: 1.0-20.0.
+         steps (int, optional): Number of inference steps. More steps = higher quality but slower.
+             Defaults to 28. Range: 1-40.
+         seed (int, optional): Random seed for reproducible results. Defaults to 42.
+             Range: 0 to MAX_SEED (2147483647).
+         randomize_seed (bool, optional): Whether to use a random seed instead of the provided seed.
+             Defaults to False.
+         progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
+
+     Returns:
+         tuple: A tuple containing:
+             - video_path (str): Path to the generated video file (.mp4)
+             - current_seed (int): The seed used for generation (useful when randomize_seed=True)
+
+     Raises:
+         gr.Error: If input_image is None (no image uploaded).
+
+     Note:
+         - The input image is center-cropped and resized to 832x480 (landscape) or 480x832 (portrait).
+         - The function uses GPU acceleration via the @spaces.GPU decorator.
+     """
+     if input_image is None:
+         raise gr.Error("Please upload an input image.")
+
      current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+     resized_image = resize_image(input_image)

+     output_frames_list = pipe(
+         image=resized_image,
          prompt=prompt,
          negative_prompt=negative_prompt,
+         height=resized_image.height,
+         width=resized_image.width,
+         num_frames=num_frames,
          guidance_scale=float(guidance_scale),
          num_inference_steps=int(steps),
+         generator=torch.Generator(device="cuda").manual_seed(current_seed),
      ).frames[0]

      with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+         video_path = tmpfile.name
+
+     export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
+
+     return video_path, current_seed
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Wan 2.2 I2V (A14B): Image-to-Video")
+     gr.Markdown("Animates an uploaded image from a text prompt using the Wan 2.2 A14B image-to-video pipeline in 🧨 diffusers.")
      with gr.Row():
          with gr.Column():
+             input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
+             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
+             num_frames_input = gr.Slider(minimum=MIN_FRAMES_MODEL, maximum=MAX_FRAMES_MODEL, step=1, value=NUM_FRAMES_DEFAULT, label="Frames")
+
              with gr.Accordion("Advanced Settings", open=False):
+                 negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
+                 seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
+                 randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
+                 steps_slider = gr.Slider(minimum=1, maximum=40, step=1, value=28, label="Inference Steps")
+                 guidance_scale_input = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale")
+
+             generate_button = gr.Button("Generate Video", variant="primary")
          with gr.Column():
              video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
+
+     ui_inputs = [
+         input_image_component, prompt_input,
+         negative_prompt_input, num_frames_input,
+         guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
      ]
+     generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
+
+     gr.Examples(
+         examples=[
+             [
+                 "wan_i2v_input.JPG",
+                 "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
+             ],
+         ],
+         inputs=[input_image_component, prompt_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
+     )

  if __name__ == "__main__":
+     demo.queue().launch(mcp_server=True)